1 import numpy as np 2 import tensorflow as tf 3 from tensorflow.keras.layers import Dense, SimpleRNN 4 import matplotlib.pyplot as plt 5 import os 6 7 8 input_word = 'abcde' 9 w_to_id = {'a':0, 'b':1, 'c':2, 'd':3, 'e':4} 10 id_to_onehot = {0:[1., 0., 0., 0., 0.], 1:[0., 1., 0., 0., 0.], 2:[0., 0., 1., 0., 0., ], 3:[0., 0., 0., 1., 0.], 11 4:[0., 0., 0., 0., 1.]} 12 13 14 x_train = [[id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']]], 15 [id_to_onehot[w_to_id['b']], id_to 16 [id_to_onehot[w_to_id['d']], id__onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]], 17 [id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']]],to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']]], 18 [id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']]]] 19 y_train = [w_to_id['e'], w_to_id['a'], w_to_id['b'], w_to_id['c'], w_to_id['d']] 20 21 22 print(x_train) 23 print(y_train) 24 25 26 np.random.seed(7) 27 np.random.shuffle(x_train) 28 np.random.seed(7) 29 np.random.shuffle(y_train) 30 tf.random.set_seed(7) 31 32 33 # 使x_train符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。 34 # 此处整个数据集送入,送入样本数为len(x_train);输入4个字母出结果,循环核时间展开步数为4; 表示为独热码有5个输入特征,每个时间步输入特征个数为5 35 x_train = np.reshape(x_train, (len(x_train), 4, 5)) 36 y_train = np.array(y_train) 37 38 39 model = tf.keras.models.Sequential([ 40 SimpleRNN(3), 41 Dense(5, activation='softmax') 42 ]) 43 44 model.compile(optimizer=tf.keras.optimizers.Adam(0.01), 45 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), 46 metrics=['sparse_categorical_accuracy']) 47 48 checkpoint_save_path = './checkpoint/rnn_onehot_4pre1.ckpt' 49 50 if os.path.exists(checkpoint_save_path + '.index'): 51 print('-----------load the model-------------------') 52 model.load_weigts(checkpoint_save_path) 53 54 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, 55 save_weights_only=True, 56 save_best_only=True, 57 monitor='loss') 58 59 history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback]) 60 61 model.summary() 62 63 64 with open('./weights.txt', 'w') as f: 65 for v in model.trainable_variables: 66 f.write(str(v.name) +'\n') 67 f.write(str(v.shape) + '\n') 68 f.write(str(v.numpy()) + '\n') 69 70 71 72 acc = history.history['sparse_categorical_accuracy'] 73 loss = history.history['loss'] 74 75 plt.subplot(1, 2, 1) 76 plt.plot(acc, label='Training Accuracy') 77 plt.title('Training Accuracy') 78 plt.legend() 79 80 plt.subplot(1, 2, 2) 81 plt.plot(loss, label='Training Loss') 82 plt.title('Training Loss') 83 plt.legend() 84 plt.show() 85 86 87 88 preNum = int(input("input the number of test alphabet:")) 89 for i in range(preNum): 90 alphabet1 = input("input test alphabet:") 91 alphabet = [id_to_onehot[w_to_id[a]] for a in alphabet1] 92 #使alphabet符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 93 #每个时间步输入特征个数]。此处验证效果送入了1个样本,送入样本数为1;输入4个字母出结果, 94 #所以循环核时间展开步数为4; 表示为独热码有5个输入特征,每个时间步输入特征个数为5 95 alphabet = np.reshape(alphabet, (1, 4, 5)) 96 result = model.predict([alphabet]) 97 pred = tf.argmax(result, axis=1) 98 pred = int(pred) 99 tf.print(alphabet1 + '->' + input_word[pred])
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:第六讲 循环神经网路——SImpleRNN_onehot_4pred1 - Python技术站