from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np

class SingleNN(object):
    """Single-hidden-layer dense classifier trained on Fashion-MNIST.

    NOTE: ``model`` is a *class* attribute. It is built once, when this class
    body executes (i.e. at module import time), and is shared by every
    ``SingleNN`` instance: compiling/training through one instance changes
    the model seen by all of them. It is accessed as ``SingleNN.model``.
    """

    # 28x28 input -> flatten -> 128-unit ReLU layer -> 10-way softmax output.
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    def __init__(self):
        """Load the Fashion-MNIST train/test splits and scale the images."""
        (self.x_train, self.y_train), (self.x_test, self.y_test) = \
            keras.datasets.fashion_mnist.load_data()
        # Normalize: divide by 255 so pixel values (presumably 0-255 integers
        # from load_data) fall into [0, 1].
        self.x_train = self.x_train / 255.0
        self.x_test = self.x_test / 255.0

    def singlenn_compile(self, learning_rate=0.01):
        """Compile the shared model with SGD, sparse categorical
        cross-entropy loss and an accuracy metric.

        :param learning_rate: SGD step size (default 0.01, as before).
        :return: None
        """
        SingleNN.model.compile(
            # The ``lr`` keyword is kept from the original code.
            # NOTE(review): newer TF releases name it ``learning_rate``;
            # switch once the targeted TF version is confirmed.
            optimizer=keras.optimizers.SGD(lr=learning_rate),
            # Sparse variant: labels are integer class ids, not one-hot.
            loss=keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy']
        )

    def singlenn_fit(self, epochs=5, log_dir="./graph"):
        """Train the shared model on the training split.

        :param epochs: number of passes over the training data (default 5).
        :param log_dir: directory the TensorBoard callback writes logs
            (including the model graph) to (default "./graph").
        :return: the ``History`` object returned by ``Model.fit``.
        """
        board = keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True)
        return SingleNN.model.fit(self.x_train, self.y_train,
                                  epochs=epochs, callbacks=[board])

    def single_evalute(self):
        """Evaluate the shared model on the test split.

        Prints the loss and accuracy (as before) and also returns them.

        :return: ``(test_loss, test_acc)``
        """
        test_loss, test_acc = SingleNN.model.evaluate(self.x_test, self.y_test)
        print(test_loss, test_acc)
        return test_loss, test_acc

    # Correctly spelled alias; the original (misspelled) name is kept so
    # existing callers keep working.
    single_evaluate = single_evalute

    def single_predict(self, weights_path="./ckpt/SingleNN.h5"):
        """Predict class probabilities for the test split.

        If ``weights_path`` exists, those weights are loaded into the shared
        model first; otherwise the model's current in-memory weights are used
        (which are untrained unless ``singlenn_fit`` ran earlier).

        :param weights_path: HDF5 weights file to load if present
            (default "./ckpt/SingleNN.h5").
        :return: array of per-class scores, one row per test image.
        """
        if os.path.exists(weights_path):
            SingleNN.model.load_weights(weights_path)

        return SingleNN.model.predict(self.x_test)

if __name__ == '__main__':
    # End-to-end run: load data, compile, train, then report test metrics.
    network = SingleNN()          # loads and normalizes Fashion-MNIST
    network.singlenn_compile()    # SGD + sparse categorical cross-entropy
    network.singlenn_fit()        # trains, logging to TensorBoard
    network.single_evalute()      # prints test loss and accuracy
    # Optional follow-up: persist the trained weights with
    # SingleNN.model.save_weights("./ckpt/SingleNN.h5") so that
    # single_predict() can reload them, then take np.argmax(..., axis=1)
    # of its output to get predicted class ids.