Keras提供了许多常用数据集,例如MNIST、CIFAR-10等,以及训练好的模型,如VGG16、ResNet50等。在使用这些数据集和模型时,我们需要知道它们所存放的位置。
数据集存放位置
Keras数据集默认存放在用户目录下的".keras/datasets"文件夹中。当我们第一次调用某个数据集时,Keras会自动下载并解压至该文件夹中。例如我们调用MNIST数据集时,Keras会自动从官网下载mnist.npz文件,然后存放至上述文件夹中。
我们可以使用下面的代码来获取MNIST数据集的存放位置:
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
mnist_dataset_path = mnist.get_file('mnist.npz', origin='https://s3.amazonaws.com/img-datasets/mnist.npz')
其中,mnist.get_file()函数可以检查文件是否已经存在,如果不存在则自动下载文件并返回文件路径。我们可以将其保存在一个变量中,以方便后续使用。
类似的,CIFAR-10数据集以及其它常用数据集也可以通过该方法获取其存放位置,具体请参考官方文档。
模型存放位置
Keras提供的一些常见模型,如VGG16、ResNet50等,可以通过直接调用keras.applications模块中的对应函数来使用。这些模型的预训练权重文件默认存放在用户目录下的".keras/models"文件夹中。当我们第一次调用某个模型函数时,Keras会自动下载其对应的预训练权重,并存放至该文件夹中。例如我们调用VGG16模型时,Keras会自动从官网下载vgg16_weights_tf_dim_ordering_tf_kernels.h5文件,并存放至上述文件夹中。
我们可以使用下面的代码来获取VGG16模型的预训练权重文件的存放位置:
from keras.applications.vgg16 import VGG16
vgg16_model_weights_path = VGG16(weights='imagenet').weights_path
其中,VGG16函数的weights参数设为'imagenet',表示使用预训练权重。VGG16函数返回的对象有一个weights_path属性,即为其权重文件的路径。
类似的,ResNet50、InceptionV3等常用模型也可以通过该方法获取其预训练权重文件的存放位置,具体请参考官方文档。
示例一:获取CIFAR-10数据集存放位置
from keras.datasets import cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
cifar10_dataset_path = cifar10.get_file('cifar-10-batches-py.tar.gz',
origin='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
extract=True)
此代码将自动下载并解压CIFAR-10数据集,并将其存放在mnist_dataset_path变量所指的文件夹中。
示例二:获取ResNet50模型的预训练权重文件的存放位置
from keras.applications.resnet50 import ResNet50
resnet50_model_weights_path = ResNet50(weights='imagenet').weights_path
此代码将自动下载ResNet50模型的预训练权重文件,并将其存放在resnet50_model_weights_path指定的路径中。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Keras自动下载的数据集/模型存放位置介绍 - Python技术站