# 基于 IMDB 数据集的简单文本分类(情感二分类)任务
# 模型结构:一层 Embedding 层 + 一层 LSTM 层 + 一层全连接(Dense)层
# 基于 Keras 2.1.1 和 TensorFlow 1.4.0
代码:
'''Trains an LSTM model on the IMDB sentiment classification task.

The dataset is actually too small for LSTM to be of any advantage
compared to simpler, much faster methods such as TF-IDF + LogReg.

# Notes

- RNNs are tricky. Choice of batch size is important,
  choice of loss and optimizer is critical, etc.
  Some configurations won't converge.
- LSTM loss decrease patterns during training can be quite different
  from what you see with CNNs/MLPs/etc.
'''
from __future__ import print_function

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding
from keras.layers import LSTM
from keras.datasets import imdb

# Vocabulary size: only the `max_features` most frequent words keep their
# own index; rarer words are replaced by load_data's out-of-vocabulary index.
# This is also the Embedding layer's input_dim below.
max_features = 20000
# Maximum review length in tokens. pad_sequences defaults to padding='pre'
# and truncating='pre': shorter reviews are left-padded with 0, and longer
# reviews lose their *leading* words, i.e. the LAST `maxlen` words are kept.
maxlen = 80
batch_size = 32
# Number of full passes over the training set.
epochs = 15

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

print('Build model...')
model = Sequential()
# Map each word index to a dense 128-dimensional vector.
model.add(Embedding(max_features, 128))
# Single LSTM layer with input and recurrent dropout; only its final output
# is passed on (return_sequences is left at its default, False).
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
# One sigmoid unit: predicted probability of label 1 (binary sentiment).
model.add(Dense(1, activation='sigmoid'))
model.summary()

# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

print('Train...')
# NOTE: the test set doubles as validation data here, so the per-epoch
# val_loss / val_acc values are test-set metrics. If epochs or other
# hyperparameters are tuned by watching them, the final test accuracy is no
# longer an unbiased estimate; hold out part of the training data instead
# (e.g. validation_split=0.2).
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))
# With metrics=['accuracy'], evaluate() returns [loss, accuracy];
# `score` is the binary cross-entropy loss on the test set.
score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)
结果(训练 15 个 epoch 后在测试集上的评估结果;代码未固定随机种子,每次运行的数值会略有差异):
Test accuracy: 0.81248
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Keras lstm 文本分类示例 - Python技术站