最近在学习 Deep Learning,在网上找到了一份自编码器(AutoEncoder)的代码,运行了一下,效果还比较好,分享如下。由于代码出处已无从考证,故不作特别说明。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Train a fully connected autoencoder on MNIST and plot its reconstructions.

Created on Mon Jan 1 12:45:57 2018

@author: pc

The script loads the MNIST training set through torchvision, shows one
sample digit, then trains ``AutoEncoder`` with Adam and mean-squared-error
loss. Every 100 steps it prints the training loss and redraws the
reconstructions of the first ``N_TEST_IMG`` training images underneath the
original images.
"""
import torch
import torch.nn as nn
# Variable is no longer needed (tensors track gradients themselves since
# PyTorch 0.4); the import is kept so code appended to this script still works.
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# Axes3D and cm are not used below; they belong to the 3D visualization step
# announced by the last comment of this script, which this listing omits.
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

# Uncomment for reproducible runs:
# torch.manual_seed(1)

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 100
LR = 0.005              # learning rate
DOWNLOAD_MNIST = True
N_TEST_IMG = 5          # number of digits shown in the comparison figure

# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    # Converts a PIL.Image or numpy.ndarray to torch.FloatTensor of shape
    # (C x H x W) and normalizes it to the range [0.0, 1.0]
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,                        # download it if you don't have it
)

# plot one example
# `.data` / `.targets` replace the deprecated `.train_data` / `.train_labels`.
print(train_data.data.size())       # (60000, 28, 28)
print(train_data.targets.size())    # (60000)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % int(train_data.targets[2]))
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape
# will be (100, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    """Symmetric fully connected autoencoder for flattened 28x28 images.

    The encoder maps 784 inputs through 128, 64 and 12 units down to a
    3-value code; the decoder mirrors those layers back to 784 outputs and
    ends with a sigmoid.
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features which can be visualized in plt
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),       # compress to a range (0, 1)
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch ``x`` of flattened images."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot

# original data (first row) for viewing; raw pixels are 0..255, so scale them
# to [0, 1] to match what ToTensor() feeds the network during training
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    # The labels are ignored: the autoencoder is trained to reproduce its input.
    for step, (x, _) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28 * 28)   # batch y, shape (batch, 28*28)

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # mean square error
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

        if step % 100 == 0:
            # loss is a 0-dim tensor; `.item()` extracts the Python float
            # (`loss.data[0]` raises IndexError on PyTorch >= 0.4).
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

            # plotting decoded image (second row); no gradients are needed
            # for this forward pass
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot
# NOTE(review): the 3D visualization of the 3-value codes is not part of this
# listing; the script ends here.
以上代码基于 PyTorch 实现。
运行效果图:
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:深度学习之自编码器 示例 - Python技术站