代码:
# -- coding: gbk -- from sklearn.datasets import load_breast_cancer from pylab import * from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.svm import LinearSVC from sklearn.datasets import make_blobs import mglearn def LogisticRegression二分类线性模型(): cancer = load_breast_cancer() X_train, X_test, y_train, y_test = train_test_split( cancer.data, cancer.target, stratify=cancer.target, random_state=42) '''构建模型''' logreg = LogisticRegression().fit(X_train, y_train) #print(logreg.predict()) '''评测''' print("Training set score: {:.3f}".format(logreg.score(X_train, y_train))) print("Test set score: {:.3f}".format(logreg.score(X_test, y_test))) '''增加C拟合灵活度——————更高训练集精度''' logreg100 = LogisticRegression(C=100).fit(X_train, y_train) print("Training set score: {:.3f}".format(logreg100.score(X_train, y_train))) print("Test set score: {:.3f}".format(logreg100.score(X_test, y_test))) def LinearSVC一对其余分类器(): X, y = make_blobs(random_state=42) linear_svm = LinearSVC().fit(X, y) ''' coef_的形状是(3, 2),说明coef_每行包含三个类别之一的系数向量, 每列包含某个特征(这个数据集有2个特征)对应的系数值。 现在intercept_是一维数组,保存每个类别的截距。 ''' print(linear_svm.coef_) # 特征 print(linear_svm.intercept_) # 截距 mglearn.discrete_scatter(X[:, 0], X[:, 1], y) line = np.linspace(-15, 15) print(line) for coef, intercept, color in zip(linear_svm.coef_, linear_svm.intercept_, ['b', 'r', 'g']): plt.plot(line, -(line * coef[0] + intercept) / coef[1], c=color) plt.ylim(-10, 15) plt.xlim(-10, 8) plt.xlabel("Feature 0") plt.ylabel("Feature 1") plt.legend(['Class 0', 'Class 1', 'Class 2', 'Line class 0', 'Line class 1', 'Line class 2'], loc=(1.01, 0.3)) plt.show() if __name__ =='__main__': cancer = load_breast_cancer() X_train, X_test, y_train, y_test = train_test_split( cancer.data, cancer.target, stratify=cancer.target, random_state=42) logreg = LogisticRegression().fit(X_train, y_train) y_pred=logreg.predict(X_test) print(np.mean(y_pred==y_test))
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:简单机器学习——最简单分类算法(LogisticRegression二分类线性模型、LinearSVC一对其余分类器) - Python技术站