from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input,decode_predictions from tensorflow.python.keras.preprocessing.image import load_img,img_to_array def predict(): model = VGG16() print(model.summary()) #预测一张图片的类别 #加载图片并输入到模型当中 #(224,224)是VGG的输入要求 image = load_img("./tiger.png",target_size=(224,224)) image = img_to_array(image) #输入到卷积神经网络当中,需要四维结构 image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2])) print(image.shape) #预测之前做图片的数据处理,归一化处理等 image = preprocess_input(image) y_predictions = model.predict(image) label = decode_predictions(y_predictions) print(label) if __name__ == '__main__': predict()
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFlow keras vgg16net的使用 - Python技术站