# Requirement: TensorFlow==1.8.0
# -*- coding: utf-8 -*-
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import numpy as np
import pandas as pd
from keras.models import Sequential # 链式构建模型
from keras.layers import Dense # 全连接层
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score # 交叉验证
from sklearn.model_selection import KFold # 数据分割,1个作为test,k-1个作为train
from sklearn.preprocessing import LabelEncoder
from keras.models import model_from_json # 模型保存
# Seed NumPy's global RNG for reproducibility; `seed` is reused by later code.
seed = 13
np.random.seed(seed)

# Load the iris table.
# NOTE(review): the slicing assumes column 0 is skipped (presumably a row
# index), columns 1-4 are the numeric features and column 5 is the species
# name -- confirm against iris.csv.
df = pd.read_csv('iris.csv')
table = df.values
X = table[:, 1:5].astype(float)
Y = table[:, 5]

# Text labels -> integer class ids -> one-hot rows.
encoder = LabelEncoder()
Y_encoder = encoder.fit_transform(Y)
Y_onehot = np_utils.to_categorical(Y_encoder)
def baseline_model(input_dim=4, hidden_units=7, n_classes=3,
                   activation='tanh', loss='mean_squared_error',
                   optimizer='sgd'):
    """Build and compile a one-hidden-layer classifier for KerasClassifier.

    The defaults reproduce the original 4-7-3 network (input=4, hidden=7,
    output=3):
      * a tanh hidden Dense layer;
      * a softmax output layer;
      * mean squared error loss and SGD;
      * an accuracy metric.

    Args:
        input_dim: number of input features per sample.
        hidden_units: width of the hidden Dense layer.
        n_classes: number of output units (one per class).
        activation: activation function of the hidden layer.
        loss: loss passed to ``model.compile``. ``'categorical_crossentropy'``
            is the usual pairing with a softmax output; MSE is kept as the
            default so existing results do not change.
        optimizer: optimizer passed to ``model.compile``.

    Returns:
        A compiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(hidden_units, input_dim=input_dim, activation=activation))
    model.add(Dense(n_classes, activation='softmax'))
    model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
    return model
# Wrap the Keras build function as a scikit-learn estimator so that
# cross_val_score can clone and train a fresh model for each fold.
# NOTE: verbose=1 makes Keras print training progress for every fold; use
# verbose=0 for quieter runs.
estimator = KerasClassifier(build_fn=baseline_model, epochs=20, batch_size=1, verbose=1)
# Evaluate with shuffled 10-fold cross-validation; random_state=seed keeps
# the fold assignment repeatable.
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
# One score per fold (the wrapper's accuracy).
result = cross_val_score(estimator, X, Y_onehot, cv=kfold)
print("Accuracy of cross validation, mean %.2f, std %.2f" % (result.mean(), result.std()))
# Refit on the full data set, then persist the architecture (JSON) and the
# weights (model.h5) as two separate files.
estimator.fit(X, Y_onehot)
model_json = estimator.model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)
estimator.model.save_weights("model.h5")
print("save model to disk")
# Load the model back from disk and use it for prediction.
# Rebuild the architecture from JSON; the context manager closes the file
# even if reading fails.
with open("model.json", "r") as json_file:
    loaded_model_json = json_file.read()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights("model.h5")
print("loaded model from disk")
# Softmax output of the last layer: one row of class probabilities per sample.
predicted = loaded_model.predict(X)
print("predicted probability" + str(predicted))
# Index of the most probable class for each sample. This matches
# Sequential.predict_classes for a multi-unit output, but it reuses the
# probabilities above instead of running a second forward pass. It also
# keeps working on Keras versions where predict_classes was removed.
# The ids are LabelEncoder codes; encoder.inverse_transform maps them back
# to names.
predicted_label = np.argmax(predicted, axis=-1)
print("predicted label:" + str(predicted_label))
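# 本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:使用keras构建简单的网络分类鸢尾花 - Python技术站
# (Unless otherwise stated, articles on this site are original; when reposting, please cite the source: "Building a simple Keras network to classify iris flowers" - Python技术站)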