Pytorch实现WGAN用于动漫头像生成

yizhihongxing

下面是详细讲解“Pytorch实现WGAN用于动漫头像生成”的完整攻略。

概述

本攻略将介绍如何使用Pytorch实现WGAN算法,用于生成动漫头像。WGAN全称为Wasserstein GAN,是针对传统GAN中存在的固有问题,如模式崩溃(mode collapse)等进行改进而提出的一种生成对抗网络算法。

本教程将分为以下两个部分:

  1. 数据准备:包括数据集的下载和预处理
  2. 模型训练:包括模型的搭建和训练

数据准备

数据集下载

本次实验选择的是动漫头像数据集,数据集的下载地址为:https://drive.google.com/drive/folders/17441N9JNzoa8HSo1OCl7rexRQ9pY2R5O。

将数据集下载完成后,解压到任意位置即可。

数据预处理

在数据预处理阶段,我们需要将数据集中的图片转换为统一大小(128 x 128)且格式一致的图像。

from PIL import Image
import os

size = 128, 128
data_dir = 'data/anime/'
filename_list = os.listdir(data_dir)

for filename in filename_list:
    if filename.endswith('.png') or filename.endswith('.jpg'):
        # 读取图片并转换大小
        img = Image.open(os.path.join(data_dir, filename)).convert('RGB')
        img = img.resize(size)
        # 格式转换
        img.save(os.path.join(data_dir, filename.split('/')[-1].split('.')[0]+'.jpg'))
        os.remove(os.path.join(data_dir, filename))

以上代码将会批量处理数据集中的图片,将其转换为统一大小和格式,并将png图片转换为jpg格式。

模型训练

模型搭建

在此教程中,我们选择Pytorch作为深度学习框架,并构建WGAN网络模型。WGAN中包含了两个神经网络:判别器和生成器,两者之间通过对抗学习进行权值的更新。

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(3, 64, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 1, 4, 1, 0)
        )

    def forward(self, x):
        return self.main(x).view(-1)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(128, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
        self.latent_dim = 128

    def forward(self, x):
        x = x.view(-1, self.latent_dim, 1, 1)
        return self.main(x)

以上代码中,Discriminator和Generator分别表示判别器和生成器的网络结构。其中,Discriminator和Generator都使用了卷积神经网络,以提取图像特征。

模型训练

在模型训练阶段,我们需要对两个网络进行交替训练,并生成合适的随机数据用于模型的生成。

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm

img_size = (128, 128)
batch_size = 64
n_critic = 5
n_epochs = 200
lr = 0.0001
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])
dataset = ImageFolder('data/anime', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型
d_net = Discriminator().to(device)
g_net = Generator().to(device)
d_opt = torch.optim.RMSprop(d_net.parameters(), lr=lr)
g_opt = torch.optim.RMSprop(g_net.parameters(), lr=lr)

# 训练模型
for epoch in range(n_epochs):
    # 进度条
    data_iter = iter(dataloader)
    with tqdm(total=len(dataloader), ncols=80) as pbar:
        for i in range(len(dataloader)):
            real_imgs = next(data_iter)[0]
            real_imgs = real_imgs.to(device)
            z = torch.randn(real_imgs.size(0), g_net.latent_dim).to(device)

            # 训练判别器
            for j in range(n_critic):
                fake_imgs = g_net(z)

                d_net.zero_grad()
                d_loss = d_net(fake_imgs).mean() - d_net(real_imgs).mean()
                eps = torch.rand(real_imgs.size(0), 1, 1, 1).to(device)
                x_hat = eps * real_imgs + (1 - eps) * fake_imgs.clone().detach()
                d_hat = d_net(x_hat)
                d_reg = 10 * ((d_hat.norm(2, dim=1) - 1) ** 2).mean()
                d_total_loss = d_loss + d_reg
                d_total_loss.backward()
                d_opt.step()

            # 训练生成器
            g_net.zero_grad()
            g_loss = -d_net(g_net(z)).mean()
            g_loss.backward()
            g_opt.step()

            # 进度条更新
            pbar.set_description("Epoch: %d/%d" % (epoch+1, n_epochs))
            pbar.set_postfix({'Loss_D': d_loss.item(), 'Loss_G': g_loss.item()})
            pbar.update(1)

以上代码中,我们首先对数据集进行预处理,并将其划分为批次进行训练。在模型训练过程中,我们需要交替训练判别器和生成器,并对每个批次的数据进行误差的计算和更新。需要注意的是,在训练判别器过程中,我们使用了WGAN算法中的梯度花式惩罚(gradient penalty)。

生成模型

在训练完成后,我们可以通过生成器进行样本的生成。

import torchvision.utils as vutils

# 生成10个样本
samples = torch.randn(10, Generator().latent_dim).to(device)
gen_imgs = g_net(samples)
gen_imgs = 0.5 * (gen_imgs + 1.0)

# 显示生成的样本
vutils.save_image(gen_imgs, 'samples.jpg', normalize=True, nrow=5)

以上代码将会生成10个通过训练好的模型生成的样本,并将结果保存到samples.jpg中。

至此,Pytorch实现WGAN用于动漫头像生成的攻略已经完整展示完成。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch实现WGAN用于动漫头像生成 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • 吴裕雄–天生自然 Tensorflow卷积神经网络:花朵图片识别

    import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageChops from skimage import color,data,transform,io #获取所有数据文件夹名称 fileList = os.listdir(“F:\\d…

    2023年4月8日
    00
  • 【笔记】 卷积

    HDU 4609 3-idiots题目链接题解 这个题考察了如何转化成多项式乘法,然后去重和计数很有意思 HDU 1402 A*B problem plus题目链接 将整数转化成向量,最后得到的卷积后的向量处理一下每一位的进位就是结果 BZOJ 2194 快速傅立叶之二题目链接 FFT 能解决形如 c[k] =sigma(a[p]*b[k-p]) (0&lt…

    2023年4月5日
    00
  • CNN中feature map、卷积核、卷积核的个数、filter、channel的概念解释

    参考链接: https://blog.csdn.net/xys430381_1/article/details/82529397作者写的很好,解决了很多基础问题。 feather map理解 这个是输入经过卷积操作后输出的结果,一般都是二维的多张图片,在论文图上都是以是多张二维图片排列在一起的(像个豆腐皮一样),它们其中的每一个都被称为\(feature \…

    2023年4月8日
    00
  • python神经网络Batch Normalization底层原理详解

    下面是关于Python神经网络Batch Normalization底层原理详解的完整攻略。 Batch Normalization的原理 Batch Normalization是一种用于神经网络的技术,旨在加速训练过程并提高模型的准确性。Batch Normalization通过对每个批次的输入进行归一化来实现这一点,从而使网络更加稳定和可靠。 Batch…

    卷积神经网络 2023年5月16日
    00
  • 深度学习笔记 (一) 卷积神经网络基础 (Foundation of Convolutional Neural Networks)

    一、卷积 卷积神经网络(Convolutional Neural Networks)是一种在空间上共享参数的神经网络。使用数层卷积,而不是数层的矩阵相乘。在图像的处理过程中,每一张图片都可以看成一张“薄饼”,其中包括了图片的高度、宽度和深度(即颜色,用RGB表示)。 在不改变权重的情况下,把这个上方具有k个输出的小神经网络对应的小块滑遍整个图像,可以得到一个…

    2023年4月8日
    00
  • PyTorch 模型 onnx 文件导出及调用详情

    介绍: PyTorch是一个基于Python的科学计算库,它有诸多优异的特性,其中一个重要的特性是它的高效特定GPU加速的张量计算(tensor computation)操作。PyTorch 1.0版本(2018年12月)已经发布,包括了对多平台、多端到端场景的支持,同时完善了跨平台支持。 我们可以使用PyTorch训练模型,然后将训练好的模型导出为ONNX…

    卷积神经网络 2023年5月15日
    00
  • PyG搭建GCN模型实现节点分类GCNConv参数详解

    下面是关于使用PyG搭建GCN模型实现节点分类以及GCNConv参数详解的攻略,包含两个示例说明。 示例1:使用PyG搭建GCN模型实现节点分类 以下是一个使用PyG搭建GCN模型实现节点分类的示例: import torch import torch.nn.functional as F from torch_geometric.datasets impo…

    卷积神经网络 2023年5月16日
    00
  • mxnet卷积神经网络训练MNIST数据集测试

        import numpy as np import mxnet as mx import logging logging.getLogger().setLevel(logging.DEBUG) batch_size = 100 mnist = mx.test_utils.get_mnist() train_iter = mx.io.NDArrayIt…

    卷积神经网络 2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部