下面是详细讲解“Pytorch实现WGAN用于动漫头像生成”的完整攻略。
概述
本攻略将介绍如何使用Pytorch实现WGAN算法,用于生成动漫头像。WGAN全称为Wasserstein GAN,是针对传统GAN中存在的固有问题,如模式崩溃(mode collapse)等进行改进而提出的一种生成对抗网络算法。
本教程将分为以下两个部分:
- 数据准备:包括数据集的下载和预处理
- 模型训练:包括模型的搭建和训练
数据准备
数据集下载
本次实验选择的是动漫头像数据集,数据集的下载地址为:https://drive.google.com/drive/folders/17441N9JNzoa8HSo1OCl7rexRQ9pY2R5O。
将数据集下载完成后,解压到任意位置即可。
数据预处理
在数据预处理阶段,我们需要将数据集中的图片转换为统一大小(128 x 128)且格式一致的图像。
from PIL import Image
import os
size = 128, 128
data_dir = 'data/anime/'
filename_list = os.listdir(data_dir)
for filename in filename_list:
if filename.endswith('.png') or filename.endswith('.jpg'):
# 读取图片并转换大小
img = Image.open(os.path.join(data_dir, filename)).convert('RGB')
img = img.resize(size)
# 格式转换
img.save(os.path.join(data_dir, filename.split('/')[-1].split('.')[0]+'.jpg'))
os.remove(os.path.join(data_dir, filename))
以上代码将会批量处理数据集中的图片,将其转换为统一大小和格式,并将png图片转换为jpg格式。
模型训练
模型搭建
在此教程中,我们选择Pytorch作为深度学习框架,并构建WGAN网络模型。WGAN中包含了两个神经网络:判别器和生成器,两者之间通过对抗学习进行权值的更新。
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
spectral_norm(nn.Conv2d(3, 64, 3, 1, 1)),
nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
nn.LeakyReLU(0.2, True),
nn.Conv2d(512, 1, 4, 1, 0)
)
def forward(self, x):
return self.main(x).view(-1)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(128, 1024, 4, 1, 0, bias=False),
nn.BatchNorm2d(1024),
nn.ReLU(True),
nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
self.latent_dim = 128
def forward(self, x):
x = x.view(-1, self.latent_dim, 1, 1)
return self.main(x)
以上代码中,Discriminator和Generator分别表示判别器和生成器的网络结构。其中,Discriminator和Generator都使用了卷积神经网络,以提取图像特征。
模型训练
在模型训练阶段,我们需要对两个网络进行交替训练,并生成合适的随机数据用于模型的生成。
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm
img_size = (128, 128)
batch_size = 64
n_critic = 5
n_epochs = 200
lr = 0.0001
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 数据预处理
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])
dataset = ImageFolder('data/anime', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 初始化模型
d_net = Discriminator().to(device)
g_net = Generator().to(device)
d_opt = torch.optim.RMSprop(d_net.parameters(), lr=lr)
g_opt = torch.optim.RMSprop(g_net.parameters(), lr=lr)
# 训练模型
for epoch in range(n_epochs):
# 进度条
data_iter = iter(dataloader)
with tqdm(total=len(dataloader), ncols=80) as pbar:
for i in range(len(dataloader)):
real_imgs = next(data_iter)[0]
real_imgs = real_imgs.to(device)
z = torch.randn(real_imgs.size(0), g_net.latent_dim).to(device)
# 训练判别器
for j in range(n_critic):
fake_imgs = g_net(z)
d_net.zero_grad()
d_loss = d_net(fake_imgs).mean() - d_net(real_imgs).mean()
eps = torch.rand(real_imgs.size(0), 1, 1, 1).to(device)
x_hat = eps * real_imgs + (1 - eps) * fake_imgs.clone().detach()
d_hat = d_net(x_hat)
d_reg = 10 * ((d_hat.norm(2, dim=1) - 1) ** 2).mean()
d_total_loss = d_loss + d_reg
d_total_loss.backward()
d_opt.step()
# 训练生成器
g_net.zero_grad()
g_loss = -d_net(g_net(z)).mean()
g_loss.backward()
g_opt.step()
# 进度条更新
pbar.set_description("Epoch: %d/%d" % (epoch+1, n_epochs))
pbar.set_postfix({'Loss_D': d_loss.item(), 'Loss_G': g_loss.item()})
pbar.update(1)
以上代码中,我们首先对数据集进行预处理,并将其划分为批次进行训练。在模型训练过程中,我们需要交替训练判别器和生成器,并对每个批次的数据进行误差的计算和更新。需要注意的是,在训练判别器过程中,我们使用了WGAN算法中的梯度花式惩罚(gradient penalty)。
生成模型
在训练完成后,我们可以通过生成器进行样本的生成。
import torchvision.utils as vutils
# 生成10个样本
samples = torch.randn(10, Generator().latent_dim).to(device)
gen_imgs = g_net(samples)
gen_imgs = 0.5 * (gen_imgs + 1.0)
# 显示生成的样本
vutils.save_image(gen_imgs, 'samples.jpg', normalize=True, nrow=5)
以上代码将会生成10个通过训练好的模型生成的样本,并将结果保存到samples.jpg中。
至此,Pytorch实现WGAN用于动漫头像生成的攻略已经完整展示完成。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch实现WGAN用于动漫头像生成 - Python技术站