from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input,decode_predictions
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array


def predict():
    model = VGG16()
    print(model.summary())
    #预测一张图片的类别
    #加载图片并输入到模型当中
    #(224,224)是VGG的输入要求
    image = load_img("./tiger.png",target_size=(224,224))
    image = img_to_array(image)

    #输入到卷积神经网络当中,需要四维结构
    image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2]))
    print(image.shape)

    #预测之前做图片的数据处理,归一化处理等
    image = preprocess_input(image)
    y_predictions = model.predict(image)

    label = decode_predictions(y_predictions)

    print(label)

if __name__ == '__main__':
    predict()