代码如下:

from __future__ import print_function

import os
import sys

import numpy as np
import tflearn
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tflearn.data_utils import to_categorical
from tflearn.layers.core import dropout
from tflearn.layers.normalization import batch_normalization


class EarlyStoppingCallback(tflearn.callbacks.Callback):
    """Stop tflearn training once training and validation accuracy both reach a threshold.

    Training is aborted by raising ``StopIteration`` from ``on_epoch_end``;
    the code that calls ``model.fit`` is expected to catch it.
    """

    def __init__(self, val_acc_thresh):
        """
        Args:
            val_acc_thresh: accuracy threshold (e.g. 0.998) that both the
                validation accuracy and the training accuracy must reach
                before training is stopped.
        """
        super(EarlyStoppingCallback, self).__init__()
        self.val_acc_thresh = val_acc_thresh

    def on_epoch_end(self, training_state):
        """Stop training when both accuracies have reached ``val_acc_thresh``.

        This is the last callback tflearn's trainer invokes in each epoch, so
        raising here ends training without losing any epoch's results.

        Raises:
            StopIteration: when ``training_state.val_acc`` and
                ``training_state.acc_value`` are both >= ``val_acc_thresh``.
        """
        val_acc = training_state.val_acc
        acc = training_state.acc_value
        # NOTE(review): these may be None (presumably when no validation set
        # or metric is configured); never stop on missing data, and avoid the
        # None-vs-float comparison that raises TypeError on Python 3.
        if val_acc is None or acc is None:
            return
        if val_acc >= self.val_acc_thresh and acc >= self.val_acc_thresh:
            # Announce termination only when we actually stop (previously this
            # was printed after every epoch).
            print("Terminating training at the end of epoch", training_state.epoch)
            raise StopIteration

    def on_train_end(self, training_state):
        """Report the final training accuracy.

        tflearn calls this when training finishes, including right after the
        early stop above, so it is a good place to record extra information
        that tflearn does not store itself.
        """
        print("Successfully left training! Final model accuracy:", training_state.acc_value)

# Column names of each record in feature_with_dnn_todo.dat, in file order.
# Column 0 is the label; the remaining 24 columns are numeric features.
cols = ["label", "flow_cnt", "len(srcip_arr)", "len(dstip_arr)", "subdomain_num",
        "uniq_subdomain_ratio", "np.average(dns_request_len_arr)",
        "np.average(dns_reply_len_arr)", "np.average(subdomain_tag_num_arr)",
        "np.average(subdomain_len_arr)", "np.average(subdomain_weird_len_arr)",
        "np.average(subdomain_entropy_arr)", "A_rr_type_ratio", "incommon_rr_type_rato",
        "valid_ipv4_ratio", "uniq_valid_ipv4_ratio", "request_reply_ratio",
        "np.max(dns_request_len_arr)", "np.max(dns_reply_len_arr)",
        "np.max(subdomain_tag_num_arr)", "np.max(subdomain_len_arr)",
        "np.max(subdomain_weird_len_arr)", "np.max(subdomain_entropy_arr)",
        "avg_distance", "std_distance"]

# Alternative (exclusion-based) feature selection kept from earlier
# experiments; not used by parse_line below.
unwanted_cols = set(["uniq_subdomain_ratio", "incommon_rr_type_rato",
                     "np.max(dns_reply_len_arr)", "request_reply_ratio",
                     "uniq_valid_ipv4_ratio", "A_rr_type_ratio"])

# Columns kept by parse_line by default: the label plus the features fed to the network.
wanted_cols = set(['label', 'flow_cnt', 'len(srcip_arr)', 'len(dstip_arr)', 'subdomain_num',
                   'np.average(dns_request_len_arr)', 'np.average(dns_reply_len_arr)',
                   'A_rr_type_ratio', 'valid_ipv4_ratio', 'request_reply_ratio',
                   'np.max(dns_request_len_arr)', 'np.max(dns_reply_len_arr)'])


def parse_line(s, wanted=None):
    """Parse one ``(label,[f1,...,fN])`` record into a list of floats.

    Only the columns whose name (per ``cols``) is in ``wanted`` are kept, in
    their original file order.

    Args:
        s: one text line, e.g. ``"(2.0,[39.0,1.0,...])"``.
        wanted: set of column names to keep; defaults to ``wanted_cols``.

    Returns:
        list of floats; the label comes first when it is selected.

    Raises:
        ValueError: if the record does not have exactly ``len(cols)`` fields,
            or a selected field is not a number.
    """
    if wanted is None:
        wanted = wanted_cols
    s = s.replace("(", "").replace(")", "").replace("[", "").replace("]", "")
    fields = s.split(",")
    # Reject malformed rows up front: a short row would silently produce a
    # shorter feature vector, a long one used to raise a bare IndexError.
    if len(fields) != len(cols):
        raise ValueError("expected %d fields, got %d: %r" % (len(cols), len(fields), s))
    return [float(value) for name, value in zip(cols, fields) if name in wanted]


def _load_dataset(path):
    """Return one parsed feature row per non-blank line of ``path``."""
    with open(path) as f:
        return [parse_line(line) for line in f if line.strip()]


def _build_model(input_dim):
    """Build the fully connected 2-class softmax classifier wrapped in a tflearn DNN."""
    net = tflearn.input_data(shape=[None, input_dim])
    net = batch_normalization(net)
    net = tflearn.fully_connected(net, input_dim)
    net = tflearn.fully_connected(net, 128, activation='tanh')
    net = dropout(net, 0.5)
    net = tflearn.fully_connected(net, 2, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='categorical_crossentropy', name='target')
    return tflearn.DNN(net)


def main():
    """Train, save and evaluate the classifier on feature_with_dnn_todo.dat."""
    path = "feature_with_dnn_todo.dat"
    rows = _load_dataset(path)
    if not rows:
        sys.exit("No records found in %s" % path)

    # Label 2 is treated as the positive class (1); every other label becomes 0.
    org_labels = [1 if int(row[0]) == 2 else 0 for row in rows]
    labels = to_categorical(org_labels, nb_classes=2)
    data = [row[1:] for row in rows]  # drop the label column
    input_dim = len(data[0])
    print("X len:", len(data), "Y len:", len(labels))

    # Split the integer labels in the same call so they stay row-aligned with
    # testX for the evaluation at the end.
    trainX, testX, trainY, testY, _, test_labels = train_test_split(
        data, labels, org_labels, test_size=0.2, random_state=42)
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = _build_model(input_dim)
    # Stop once training and validation accuracy both reach 99.8%.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.998)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500,
                  batch_size=8, show_metric=True, callbacks=early_stopping_cb)
    except StopIteration:
        # EarlyStoppingCallback raises StopIteration to end training early.
        print("pass")

    filename = 'tf_model/dns_tunnel2_998.tflearn'
    # Saving fails if the parent directory does not exist yet.
    model_dir = os.path.dirname(filename)
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    model.save(filename)
    model.load(filename)  # predictions below use the reloaded checkpoint

    # Score only the held-out rows: predicting on all rows would mix in the
    # 80% used for training and overstate the metrics.
    # NOTE(review): testX also serves as the early-stopping validation set, so
    # these numbers are still somewhat optimistic; a separate validation split
    # would give an unbiased estimate.
    # p[0] is the softmax probability of class 0.
    y_predict = [0 if p[0] >= 0.5 else 1 for p in model.predict(testX)]
    print(classification_report(test_labels, y_predict))
    print(confusion_matrix(test_labels, y_predict))


if __name__ == "__main__":
    main()

结果:

('Terminating training at the end of epoch', 175)
Training Step: 309936  | total loss: 0.00695 | time: 4.371s
| Adam | epoch: 176 | loss: 0.00695 - acc: 0.9988 | val_loss: 0.00661 - val_acc: 0.9991 -- iter: 14084/14084
--
('Terminating training at the end of epoch', 176)
('Successfully left training! Final model accuracy:', 0.9987633228302002)
pass
             precision    recall  f1-score   support

          0       1.00      1.00      1.00     16529
          1       0.97      0.99      0.98      1076

avg / total       1.00      1.00      1.00     17605

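从混淆矩阵看，效果还是非常不错的！（注意：上面的分类报告和混淆矩阵是对全部 17605 条样本（其中 80% 即 14084 条为训练样本）做预测得到的，会高估模型在未见数据上的表现；更可靠的做法是只在 20% 的测试集上评估。）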
[[16497    32]
 [    8  1068]]

输入数据样例:

(2.0,[39.0,1.0,2.0,38.0,0.974358974359,85.0,86.6666666667,3.0,30.0,0.0,3.84923785837,1.0,0.0,0.512820512821,0.025641025641,0.00150829562594,85.0,169.0,3.0,30.0,0.0,3.98989809546,2.54054054054,1.15301237879])
(2.0,[4437.0,3.0,10.0,13.0,0.00292990759522,48.554428668,45.3955375254,1.92307692308,91.3846153846,0.0,3.69230769231,0.972954699121,0.0,0.0,0.0,2.32087487699e-05,138.0,138.0,2.0,100.0,0.0,4.0,15.25,30.5753849799])
(2.0,[115.0,4.0,8.0,11.0,0.095652173913,99.2260869565,47.0347826087,2.0,74.7272727273,0.0,4.24137616275,0.0,0.0,0.0,0.0,0.000438173692052,131.0,131.0,2.0,82.0,0.0,4.3128598958,7.9,14.1594491418])
(2.0,[177.0,2.0,8.0,11.0,0.0621468926554,88.3389830508,35.6327683616,2.0,66.0,0.0,4.17962650637,0.0,0.0,0.0,0.0,0.000319774878486,115.0,115.0,2.0,66.0,0.0,4.17962650637,2.0,0.0])
(2.0,[38.0,7.0,6.0,23.0,0.605263157895,59.0263157895,120.473684211,1.0,20.5652173913,0.0,3.55684374229,0.657894736842,0.0,0.0263157894737,0.0263157894737,0.00222915737851,65.0,267.0,1.0,26.0,0.0,3.97366068969,14.7727272727,3.20414246338])
(2.0,[232.0,4.0,8.0,18.0,0.0775862068966,94.5301724138,39.9224137931,2.0,71.3333333333,0.0,4.19859571366,0.0,0.0,0.0,0.0,0.000227987779855,131.0,131.0,2.0,82.0,0.0,4.28968752349,5.47058823529,11.241298057])
(2.0,[90.0,3.0,8.0,12.0,0.133333333333,97.6,63.7222222222,2.0,74.0,0.0,4.23623035806,0.0,0.0,0.0,0.0,0.000569216757741,131.0,131.0,2.0,82.0,0.0,4.3128598958,7.36363636364,13.6066342594])
(2.0,[419.0,1.0,2.0,355.0,0.847255369928,72.9403341289,88.2816229117,3.0,30.0,0.0,3.80441789011,1.0,0.0,0.980906921241,0.00238663484487,0.000163601858517,74.0,90.0,3.0,30.0,0.0,4.05656476213,1.86440677966,0.654172884041])
(2.0,[132.0,2.0,8.0,12.0,0.0909090909091,83.446969697,38.446969697,2.0,66.0,0.0,4.15523801434,0.0,0.0,0.0,0.0,0.000453926463913,115.0,115.0,2.0,66.0,0.0,4.15523801434,2.0,0.0])
(2.0,[12399.0,9.0,8.0,48.0,0.00387127994193,131.489636261,63.534236632,2.0,86.5416666667,0.0,4.29632333151,0.92402613114,0.0,0.0,0.0,3.06684495259e-06,143.0,143.0,2.0,94.0,0.0,4.37237921923,7.34042553191,13.9897783289])
(2.0,[13659.0,11.0,11.0,55.0,0.00402664909583,131.545574347,65.8218756864,2.0,88.3272727273,0.0,4.34545972513,0.933670107621,0.0,0.0,0.0,2.78275427e-06,145.0,145.0,2.0,96.0,0.0,4.48022025041,8.31481481481,15.5072552602])
(2.0,[187.0,2.0,5.0,94.0,0.502673796791,88.1229946524,139.229946524,1.98936170213,43.9042553191,0.0,4.27189155149,0.502673796791,0.0,0.0,0.0,0.000303416469446,111.0,701.0,2.0,67.0,0.0,4.56541251219,21.5161290323,7.83926277973])
(2.0,[13651.0,11.0,8.0,50.0,0.00366273533075,131.740458574,66.4286132884,1.98,76.26,0.0,4.30942940291,0.955461138378,0.0,0.0,0.0,2.78026611595e-06,145.0,145.0,2.0,96.0,0.0,4.43135478727,11.6734693878,19.406907833])
(2.0,[13867.0,6.0,8.0,48.0,0.00346145525348,131.98341386,66.6828441624,1.97916666667,83.8541666667,0.0,4.28707673609,0.946347443571,0.0,0.0,0.0,2.73192096662e-06,143.0,143.0,2.0,94.0,0.0,4.3688366088,5.53191489362,11.7361979849])
(2.0,[12882.0,10.0,8.0,58.0,0.00450240645862,130.423381463,63.3864306785,1.96551724138,76.7068965517,0.0,3.93103448276,0.938674118926,0.0,0.0,0.0,2.97598853411e-06,143.0,143.0,2.0,94.0,0.0,4.0,8.98245614035,16.8912841929])
(2.0,[258.0,3.0,2.0,76.0,0.294573643411,77.0,80.9263565891,3.0,29.0,0.0,3.75053197533,0.492248062016,0.0,0.259689922481,0.00387596899225,0.000251686298198,77.0,630.0,3.0,29.0,0.0,3.89246375375,2.74666666667,1.87682947784])
(2.0,[14147.0,12.0,8.0,52.0,0.00367569095921,131.023397187,64.6592210363,1.96153846154,79.8461538462,0.0,4.3284491183,0.922032939846,0.0,0.0,0.0,2.69747106693e-06,143.0,143.0,2.0,94.0,0.0,4.42489759102,11.0588235294,19.7974169089])
(2.0,[13970.0,9.0,8.0,70.0,0.0050107372942,131.400501074,66.2702219041,1.98571428571,82.3571428571,0.0,4.29402071493,0.919183965641,0.0,0.0,0.0,2.72380853805e-06,143.0,143.0,2.0,94.0,0.0,4.36589057319,7.26086956522,13.9778833316])
(2.0,[13431.0,8.0,8.0,48.0,0.00357382175564,131.08234681,65.0862184499,1.97916666667,73.8958333333,0.0,4.26146070383,0.951604497059,0.0,0.0,0.0,2.83999416097e-06,133.0,133.0,2.0,84.0,0.0,4.41617048851,7.82978723404,15.2752325593])
(2.0,[13196.0,7.0,8.0,50.0,0.00378902697787,131.38898151,65.7071082146,2.0,84.28,0.0,4.3113658841,0.921718702637,0.0,0.0,0.0,2.88382399676e-06,143.0,143.0,2.0,94.0,0.0,4.39200898112,7.0612244898,12.5004622988])