Pytorch实现WGAN用于动漫头像生成

下面是详细讲解“Pytorch实现WGAN用于动漫头像生成”的完整攻略。

概述

本攻略将介绍如何使用Pytorch实现WGAN算法,用于生成动漫头像。WGAN全称为Wasserstein GAN,是针对传统GAN中存在的固有问题,如模式崩溃(mode collapse)等进行改进而提出的一种生成对抗网络算法。

本教程将分为以下两个部分:

  1. 数据准备:包括数据集的下载和预处理
  2. 模型训练:包括模型的搭建和训练

数据准备

数据集下载

本次实验选择的是动漫头像数据集,数据集的下载地址为:https://drive.google.com/drive/folders/17441N9JNzoa8HSo1OCl7rexRQ9pY2R5O。

将数据集下载完成后,解压到任意位置即可。

数据预处理

在数据预处理阶段,我们需要将数据集中的图片转换为统一大小(128 x 128)且格式一致的图像。

from PIL import Image
import os

size = 128, 128
data_dir = 'data/anime/'
filename_list = os.listdir(data_dir)

for filename in filename_list:
    if filename.endswith('.png') or filename.endswith('.jpg'):
        # 读取图片并转换大小
        img = Image.open(os.path.join(data_dir, filename)).convert('RGB')
        img = img.resize(size)
        # 格式转换
        img.save(os.path.join(data_dir, filename.split('/')[-1].split('.')[0]+'.jpg'))
        os.remove(os.path.join(data_dir, filename))

以上代码将会批量处理数据集中的图片,将其转换为统一大小和格式,并将png图片转换为jpg格式。

模型训练

模型搭建

在此教程中,我们选择Pytorch作为深度学习框架,并构建WGAN网络模型。WGAN中包含了两个神经网络:判别器和生成器,两者之间通过对抗学习进行权值的更新。

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(3, 64, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 1, 4, 1, 0)
        )

    def forward(self, x):
        return self.main(x).view(-1)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(128, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
        self.latent_dim = 128

    def forward(self, x):
        x = x.view(-1, self.latent_dim, 1, 1)
        return self.main(x)

以上代码中,Discriminator和Generator分别表示判别器和生成器的网络结构。其中,Discriminator和Generator都使用了卷积神经网络,以提取图像特征。

模型训练

在模型训练阶段,我们需要对两个网络进行交替训练,并生成合适的随机数据用于模型的生成。

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm

img_size = (128, 128)
batch_size = 64
n_critic = 5
n_epochs = 200
lr = 0.0001
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])
dataset = ImageFolder('data/anime', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型
d_net = Discriminator().to(device)
g_net = Generator().to(device)
d_opt = torch.optim.RMSprop(d_net.parameters(), lr=lr)
g_opt = torch.optim.RMSprop(g_net.parameters(), lr=lr)

# 训练模型
for epoch in range(n_epochs):
    # 进度条
    data_iter = iter(dataloader)
    with tqdm(total=len(dataloader), ncols=80) as pbar:
        for i in range(len(dataloader)):
            real_imgs = next(data_iter)[0]
            real_imgs = real_imgs.to(device)
            z = torch.randn(real_imgs.size(0), g_net.latent_dim).to(device)

            # 训练判别器
            for j in range(n_critic):
                fake_imgs = g_net(z)

                d_net.zero_grad()
                d_loss = d_net(fake_imgs).mean() - d_net(real_imgs).mean()
                eps = torch.rand(real_imgs.size(0), 1, 1, 1).to(device)
                x_hat = eps * real_imgs + (1 - eps) * fake_imgs.clone().detach()
                d_hat = d_net(x_hat)
                d_reg = 10 * ((d_hat.norm(2, dim=1) - 1) ** 2).mean()
                d_total_loss = d_loss + d_reg
                d_total_loss.backward()
                d_opt.step()

            # 训练生成器
            g_net.zero_grad()
            g_loss = -d_net(g_net(z)).mean()
            g_loss.backward()
            g_opt.step()

            # 进度条更新
            pbar.set_description("Epoch: %d/%d" % (epoch+1, n_epochs))
            pbar.set_postfix({'Loss_D': d_loss.item(), 'Loss_G': g_loss.item()})
            pbar.update(1)

以上代码中,我们首先对数据集进行预处理,并将其划分为批次进行训练。在模型训练过程中,我们需要交替训练判别器和生成器,并对每个批次的数据进行误差的计算和更新。需要注意的是,在训练判别器过程中,我们使用了WGAN算法中的梯度花式惩罚(gradient penalty)。

生成模型

在训练完成后,我们可以通过生成器进行样本的生成。

import torchvision.utils as vutils

# 生成10个样本
samples = torch.randn(10, Generator().latent_dim).to(device)
gen_imgs = g_net(samples)
gen_imgs = 0.5 * (gen_imgs + 1.0)

# 显示生成的样本
vutils.save_image(gen_imgs, 'samples.jpg', normalize=True, nrow=5)

以上代码将会生成10个通过训练好的模型生成的样本,并将结果保存到samples.jpg中。

至此,Pytorch实现WGAN用于动漫头像生成的攻略已经完整展示完成。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch实现WGAN用于动漫头像生成 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • 空间域图像增强:卷积和空间域滤波

    1、什么是卷积?   卷积:数学中关于两个函数的一种无穷积分运算,是通过两个函数f 和g 生成第三个函数的一种数学算子,表征函数f 经过翻转和平移与g 的重叠部分的累积。 2、什么是空间卷积? ž  线性空间滤波又称为空间卷积   在执行线性空间滤波时,我们必须理解两个相关的含义,相关和卷积。相关是掩膜w在下图1图像f中移动的过程。卷积是相同的过程,只是在图…

    2023年4月6日
    00
  • 积性函数求和:构造狄利克雷卷积将值域限定于powerful number

    前情提要:$O(n^{0.75}/\log n)$ 时间的积性函数求和。当 $n \ge 10^{12}$ 的时候需要十几秒出解。 如果积性函数的性质更好,那么我们可以更快地求和。 假设积性函数 $f$ 和易于求和的积性函数 $g$ 满足 $f(p)=g(p)$,且 $f=g*h$, $g*h$ 表示 $g, h$ 的狄利克雷卷积,也就是 $f(n)=\su…

    卷积神经网络 2023年4月7日
    00
  • Deep Learning系统实训之三:卷积神经网络

      边界填充(padding):卷积过程中,越靠近图片中间位置的像素点越容易被卷积计算多次,越靠近边缘的像素点被卷积计算的次数越少,填充就是为了使原来边缘像素点的位置变得相对靠近中部,而我们又不想让填充的数据影响到我们的计算结果,故填充值选择均用0来填充。 池化层不需要参数、只是对特征图进行压缩操作,以减少计算量:池化几乎不用平均池化,多用最大池化操作,对于…

    2023年4月8日
    00
  • 深度学习面试题10:二维卷积(Full卷积、Same卷积、Valid卷积、带深度的二维卷积)

      二维Full卷积   二维Same卷积   二维Valid卷积   三种卷积类型的关系   具备深度的二维卷积   具备深度的张量与多个卷积核的卷积   参考资料 二维卷积的原理和一维卷积类似,也有full卷积、same卷积和valid卷积。 举例:3*3的二维张量x和2*2的二维张量K进行卷积 二维Full卷积 Full卷积的计算过程是:K沿着x从左到…

    2023年4月7日
    00
  • 如何用Python 实现景区安防系统

    如何用Python实现景区安防系统 介绍 随着旅游业的发展,景区越来越受到人们的欢迎。同时,景区的安全问题也备受关注。为了保障游客的人身财产安全,景区管理部门需要建立一套完善的安防系统。本文将介绍如何用Python实现景区安防系统。 前置条件 在开始实现景区安防系统之前,我们需要准备以下硬件设备: 摄像头(可采用网络摄像头或USB摄像头) 树莓派(作为中心控…

    卷积神经网络 2023年5月15日
    00
  • 卷积神经网络概念与原理

    一、卷积神经网络的基本概念          受Hubel和Wiesel对猫视觉皮层电生理研究启发,有人提出卷积神经网络(CNN),Yann Lecun 最早将CNN用于手写数字识别并一直保持了其在该问题的霸主地位。近年来卷积神经网络在多个方向持续发力,在语音识别、人脸识别、通用物体识别、运动分析、自然语言处理甚至脑电波分析方面均有突破。        卷积…

    2023年4月8日
    00
  • 卷积定理的证明

      今天终于搞明白了卷积定理的证明,以前一直拿来就用的“时域卷积等于频域点积”终于得以揭秘:   直接证明一下连续情况好了,很容易推广到离散域(我不会):   傅里叶变换的定义是:     FT(f) = integrate [-inf,+inf] f(t)*e^(-i*w*t) dt 卷积的定义是(先用@冒充一下卷积的算符qwq,学完latex一定改): …

    卷积神经网络 2023年4月6日
    00
  • 机器学习三 卷积神经网络作业

    本来这门课程http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML16.html 作业是用卷积神经网络做半监督学习,这个还没完全解决,于是先从基础的开始,用keras 实现cifar10。 以下是代码 1 # -*- coding: utf-8 -*- 2 __author__ = ‘Administrator’ 3 4…

    2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部