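This tutorial explains DCGANs through a concrete example. We will train a generative adversarial network (GAN) to generate new celebrities after showing it pictures of many real celebrities. Most of the code here comes from the dcgan implementation in pytorch/examples, and this document gives a thorough explanation of the implementation and sheds light on how and why the model works. No prior knowledge of GANs is required, but a first-timer may need to spend some time reasoning about what is actually happening under the hood. It also helps to have a GPU or two to save training time. Let's start from the beginning.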
2. Overview
2.1. What is a GAN (Generative Adversarial Network)
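GANs are a framework for teaching a deep learning model to capture the distribution of the training data so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets.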
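A GAN is made of two distinct models, a generator and a discriminator. The job of the generator is to produce fake images that look like the training images. The job of the discriminator is to look at an image and decide whether it is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by generating better and better fakes, while the discriminator works to become a better detective and correctly classify the real and fake images.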
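The equilibrium of this game is reached when the generator produces fakes that look as if they came directly from the training data, and the discriminator is left to always guess, with 50% confidence, whether the generator's output is real or fake.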
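Now, let's define some notation to be used throughout the tutorial, starting with the discriminator. Let $x$ be data representing an image. $D(x)$ is the discriminator network, which outputs the (scalar) probability that $x$ came from the training data rather than from the generator. Since we are dealing with images, the input to $D(x)$ is an image of CHW size 3x64x64. Intuitively, $D(x)$ should be high when $x$ comes from the training data and low when $x$ comes from the generator. $D(x)$ can also be thought of as a traditional binary classifier.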
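For the generator's notation, let $z$ be a latent space vector sampled from a standard normal distribution. $G(z)$ is the generator function, which maps the latent vector $z$ to data space. The goal of $G$ is to estimate the distribution the training data comes from ($p_{data}$), so that it can generate fake samples from that estimated distribution ($p_g$).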
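So, $D(G(z))$ is the probability (a scalar) that the output of the generator $G$ is a real image. As described in Goodfellow's paper, $D$ and $G$ play a minimax game in which $D$ tries to maximize the probability that it correctly classifies reals and fakes ($\log D(x)$), and $G$ tries to minimize the probability that $D$ will predict its outputs are fake ($\log(1-D(G(z)))$). In the paper, the GAN loss function is: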
$$\min_{G} \max_{D} V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big]$$
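In theory, the solution to this minimax game is $p_g = p_{data}$, at which point the discriminator guesses randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in practice models do not always train to this point.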
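As a quick numerical check of the equilibrium claim, the sketch below estimates $V(D,G)$ from discriminator outputs in plain Python; `value_fn` is a hypothetical helper used only for illustration. When the discriminator outputs $1/2$ everywhere (what the optimal $D$ does once $p_g = p_{data}$), the value is $-\log 4$, the global optimum derived in Goodfellow's paper. A discriminator that confidently separates reals from fakes gets a larger value, which is why $D$ is the maximizing player.

```python
import math

def value_fn(d_real, d_fake):
    """Monte-Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Equilibrium: D cannot tell reals from fakes and always outputs 0.5
v_equilibrium = value_fn([0.5] * 64, [0.5] * 64)
# A confident, correct D: 0.9 on real samples, 0.1 on fakes
v_confident = value_fn([0.9] * 64, [0.1] * 64)

print(v_equilibrium, -math.log(4))  # both about -1.3863
print(v_confident)                  # about -0.2107, larger than -log 4
```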
2.2. What is a DCGAN (Deep Convolutional Generative Adversarial Network)
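A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional layers in the discriminator and transposed convolutional layers in the generator. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.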
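The discriminator is made up of strided convolution layers, batch normalization layers, and LeakyReLU activations. Its input is a 3x64x64 image and its output is a scalar probability that the input came from the real data distribution.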
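The generator is made up of transposed convolution layers, batch normalization layers, and ReLU activations. Its input is a latent vector $z$ drawn from a standard normal distribution, and its output is a 3x64x64 RGB image. The strided transposed convolution layers transform the latent vector into a volume with the same shape as a real image. In the paper, the authors also give some tips on how to set up the optimizers, how to compute the loss functions, and how to initialize the model weights, all of which are explained in the following sections.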
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
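Fixing the seed matters because both the weight initialization and the latent vectors come from PyTorch's random number generator. A minimal CPU-only sketch (the helper `sample_latents` is hypothetical) showing that re-seeding with `torch.manual_seed` reproduces exactly the same `torch.randn` draws, while a different seed does not:

```python
import torch

def sample_latents(seed, n=4, nz=100):
    # Re-seed PyTorch's global RNG, then draw latent vectors shaped like G's input
    torch.manual_seed(seed)
    return torch.randn(n, nz, 1, 1)

a = sample_latents(999)
b = sample_latents(999)
c = sample_latents(1000)
print(torch.equal(a, b))  # same seed, identical draws
print(torch.equal(a, c))  # different seed, different draws
```

On a GPU, bit-exact reproducibility of a whole training run can additionally depend on nondeterministic cuDNN kernels, which this sketch does not cover.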
Output:
Random Seed: 999
3. Inputs
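- dataroot: the path to the root of the dataset folder
- workers: the number of worker processes the DataLoader uses to load data
- batch_size: the batch size used in training. The DCGAN paper uses 128
- image_size: the spatial size of the training images. The default is 64x64; if another size is needed, the structures of $D$ and $G$ must be changed
- nc: the number of channels in the input images. For color images this is 3
- nz: the length of the latent vector
- ngf: relates to the depth of the feature maps carried through the generator
- ndf: sets the depth of the feature maps propagated through the discriminator
- num_epochs: the total number of training epochs. Training for more epochs will probably give better results but will also take longer
- lr: the learning rate. The DCGAN paper uses 0.0002
- beta1: the beta1 hyperparameter of the Adam optimizers. The paper uses 0.5
- ngpu: the number of available GPUs. If it is 0, the code runs in CPU mode; if it is greater than 0, it runs on that number of GPUs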
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
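The script imports argparse, but the values above are hard-coded. If you run the code as a standalone script rather than a notebook, one possible sketch is to expose the same hyperparameters as command-line flags. The flag names are hypothetical (they simply mirror the variable names) and the defaults copy the values above.

```python
import argparse

def build_parser():
    # Flag names mirror the variables above; defaults are the tutorial's values
    p = argparse.ArgumentParser(description="DCGAN on CelebA (sketch)")
    p.add_argument("--dataroot", default="data/celeba")
    p.add_argument("--workers", type=int, default=2)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--image_size", type=int, default=64)
    p.add_argument("--nc", type=int, default=3)
    p.add_argument("--nz", type=int, default=100)
    p.add_argument("--ngf", type=int, default=64)
    p.add_argument("--ndf", type=int, default=64)
    p.add_argument("--num_epochs", type=int, default=5)
    p.add_argument("--lr", type=float, default=0.0002)
    p.add_argument("--beta1", type=float, default=0.5)
    p.add_argument("--ngpu", type=int, default=1, help="0 runs on CPU")
    return p

# A real script would call build_parser().parse_args() to read sys.argv
opt = build_parser().parse_args(["--num_epochs", "10", "--ngpu", "0"])
print(opt.num_epochs, opt.ngpu, opt.lr)
```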
4. Data
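In this tutorial we use the Celeb-A Faces dataset, which can be downloaded from the Celeb-A website or from Google Drive. The dataset downloads as a compressed file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into it. Then set dataroot to the directory you just created. The resulting directory structure should be: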
/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...
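This is an important step, because we will use the ImageFolder dataset class, which requires subdirectories inside the dataset's root folder. Now we can create the dataset and the dataloader, set the device to run on, and finally visualize some of the training data.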
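Before creating the dataset, a quick layout check can save a confusing error: extracting the JPEGs directly into dataroot leaves ImageFolder with no class subdirectories. The helper below is a hypothetical, standard-library-only sanity check (not part of torchvision) that counts images per immediate subdirectory; here it runs against a throwaway layout built with `tempfile`.

```python
import os
import tempfile

def count_images_per_subdir(root, exts=(".jpg", ".jpeg", ".png")):
    """Return {subdirectory: number_of_images} for the immediate subdirectories of root.

    ImageFolder treats each immediate subdirectory as one class, so an empty
    dict means the images are not where ImageFolder expects them.
    """
    counts = {}
    for entry in sorted(os.listdir(root)):
        path = os.path.join(root, entry)
        if os.path.isdir(path):
            counts[entry] = sum(1 for f in os.listdir(path) if f.lower().endswith(exts))
    return counts

# Build a tiny fake layout: celeba/img_align_celeba/{000001,000002}.jpg
with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "celeba")
    os.makedirs(os.path.join(root, "img_align_celeba"))
    for name in ("000001.jpg", "000002.jpg"):
        open(os.path.join(root, "img_align_celeba", name), "wb").close()
    layout = count_images_per_subdir(root)

print(layout)  # {'img_align_celeba': 2}
```

Against the real data you would call `count_images_per_subdir(dataroot)` and expect a single `img_align_celeba` entry.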
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
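The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step computes (input - mean) / std per channel, so the [0, 1] range produced by ToTensor becomes [-1, 1], the same range as the generator's final Tanh. A small sketch on plain tensors (no dataset needed; the helper names are hypothetical) showing the mapping and its inverse:

```python
import torch

mean, std = 0.5, 0.5

def normalize(x):
    # Same arithmetic as transforms.Normalize with mean 0.5 and std 0.5: (x - mean) / std
    return (x - mean) / std

def denormalize(y):
    # Inverse mapping from [-1, 1] back to [0, 1]
    return y * std + mean

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # values as produced by ToTensor()
scaled = normalize(pixels)                    # -1, -0.5, 0, 1
restored = denormalize(scaled)                # back to the original pixel values
print(scaled, restored)
```

For display, the code above instead passes normalize=True to make_grid, which rescales each grid to (0, 1) using its own min and max values, so the exact inverse is not required there.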
5. Implementation
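With the input parameters set and the dataset prepared, we can move on to the implementation. We start with the weight initialization strategy, then discuss the generator, the discriminator, the loss functions, and the training loop in detail.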
5.1. Weight Initialization
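In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, transposed convolutional, and batch normalization layers to meet this criterion (for batch normalization layers, the code draws the scale from a normal distribution with mean 1 and standard deviation 0.02 and sets the bias to 0). The function is applied to each model immediately after it is created.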
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv2d and ConvTranspose2d: weights ~ N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm: scale ~ N(1, 0.02), bias = 0
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
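To see what weights_init actually does, the sketch below applies it to a small throwaway model (layer sizes chosen arbitrarily for illustration) and checks the resulting statistics. Module.apply calls the function on every submodule recursively as well as on the module itself, which is why a single netG.apply(weights_init) reaches every layer. The function is repeated so the block runs on its own.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same logic as above: Conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), bias = 0
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
# Toy model with the same layer types as G (sizes are arbitrary)
model = nn.Sequential(
    nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
)
model.apply(weights_init)

conv_w = model[0].weight
bn = model[1]
print(conv_w.mean().item(), conv_w.std().item())      # close to 0.0 and 0.02
print(bn.weight.mean().item(), bn.bias.abs().max().item())  # close to 1.0, and exactly 0.0
```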
5.2. Generator
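The generator $G$ maps the latent space vector $z$ to data space. Since our data are images, converting $z$ to data space ultimately means creating an RGB image with the same size as the training images (i.e. 3x64x64).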
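In practice, this is done with a series of strided two-dimensional transposed convolution layers; every layer except the last is paired with a 2d batch normalization layer and a ReLU activation. The output of the generator is passed through a tanh function so that it returns to the input data range of [-1, 1].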
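It is worth noting the batch normalization layers placed after the transposed convolution layers, since they are a key contribution of the DCGAN paper. These layers help gradients flow during training. The DCGAN paper includes a diagram of this generator.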
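Notice how the inputs set in the Inputs section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the input vector $z$, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (3 for RGB images). Below is the code for the generator.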
# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z (nz x 1 x 1), going into a transposed convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
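The `# state size` comments can be checked by hand. For nn.ConvTranspose2d with dilation 1 and output_padding 0, PyTorch's output-size formula reduces to H_out = (H_in - 1) * stride - 2 * padding + kernel_size. This pure-Python sketch walks the generator's five layers, with kernel, stride, and padding taken from the code above:

```python
def conv_transpose_out(h_in, kernel, stride, padding):
    # Output height/width of ConvTranspose2d with dilation=1 and output_padding=0
    return (h_in - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for each ConvTranspose2d in Generator.main
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # z enters as an nz x 1 x 1 volume
sizes = []
for kernel, stride, padding in gen_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

With kernel 4, stride 2, and padding 1 the formula simplifies to 2 * H_in, so each of the last four layers doubles the spatial size. This is also why a different image_size requires adding or removing layers.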
Now we can instantiate the generator and apply the weights_init function. Check the printed model to see how the generator object is structured.
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
# Print the model
print(netG)
Output:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
5.3. Discriminator
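As mentioned, the discriminator $D$ is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake).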
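Here, $D$ takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation. This architecture can be extended with more layers if the problem requires it, but the use of strided convolution, BatchNorm, and LeakyReLU is significant. The DCGAN paper notes that downsampling with strided convolution rather than pooling is good practice, because it lets the network learn its own pooling function. Batch norm and leaky ReLU also promote healthy gradient flow, which is critical for the learning process of both $G$ and $D$.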
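One way to see why LeakyReLU helps gradient flow: for negative inputs a plain ReLU has zero gradient, while LeakyReLU(0.2), the slope used in the discriminator, keeps a gradient of 0.2. A minimal autograd sketch comparing the two on the same inputs:

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)

# Gradient of sum(LeakyReLU(x)) with respect to each input element
nn.LeakyReLU(0.2)(x).sum().backward()
leaky_grad = x.grad.clone()

# Reset the accumulated gradient and repeat with a plain ReLU
x.grad = None
nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()

print(leaky_grad)  # 0.2 for the negative inputs, 1.0 for the positive ones
print(relu_grad)   # 0.0 for the negative inputs, 1.0 for the positive ones
```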
# Discriminator Code
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # state size. 1 x 1 x 1 (one probability per image)
        )

    def forward(self, input):
        return self.main(input)
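The discriminator's `# state size` comments follow from the nn.Conv2d output-size formula with dilation 1: H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1. With kernel 4, stride 2, and padding 1, each layer halves the spatial size (the job pooling would otherwise do), and the final unpadded 4x4 convolution collapses 4x4 to 1x1:

```python
def conv_out(h_in, kernel, stride, padding):
    # Output height/width of Conv2d with dilation=1 (floor division, as in PyTorch)
    return (h_in + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each Conv2d in Discriminator.main
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input image is nc x 64 x 64
sizes = []
for kernel, stride, padding in disc_layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [32, 16, 8, 4, 1]
```

The 1x1 output is why the training loop can call .view(-1) on netD's output to get one probability per image in the batch.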
Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.
# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)
# Print the model
print(netD)
Output:
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
5.4. Loss Functions and Optimizers
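With $D$ and $G$ set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss), which is defined in PyTorch as: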
$$\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]$$
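Notice how this function provides the calculation of both log components of the objective function (i.e. $\log(D(x))$ and $\log(1-D(G(z)))$). We can choose which part of the BCE equation to use through the $y$ input: with $y_n = 1$ only $-\log x_n$ remains, and with $y_n = 0$ only $-\log(1 - x_n)$ remains. This happens in the training loop coming up shortly, but it is important to understand how we can select the component we want to compute just by changing $y$ (i.e. the ground-truth labels). By default, nn.BCELoss averages the $l_n$ over the batch.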
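A quick numerical check of this label trick with nn.BCELoss, using made-up probabilities in place of real discriminator outputs:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # default reduction='mean' averages l_n over the batch
probs = torch.tensor([0.9, 0.6, 0.2])  # pretend discriminator outputs

loss_target_one = criterion(probs, torch.ones(3))    # y_n = 1 keeps only -log(x_n)
loss_target_zero = criterion(probs, torch.zeros(3))  # y_n = 0 keeps only -log(1 - x_n)

# Compare against the two log terms of the BCE formula computed by hand
manual_one = (-torch.log(probs)).mean()
manual_zero = (-torch.log(1 - probs)).mean()
print(loss_target_one.item(), manual_one.item())
print(loss_target_zero.item(), manual_zero.item())
```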
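Next, we define the real label as 1 and the fake label as 0. These labels are used when computing the losses of $D$ and $G$; this is also the convention used in the original GAN paper.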
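Finally, we set up two separate optimizers, one for $D$ and one for $G$. As specified in the DCGAN paper, both are Adam optimizers with a learning rate of 0.0002 and Beta1 = 0.5. To track the generator's learning progress, we generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). During the training loop, we periodically feed this fixed_noise into $G$, and over the iterations we will see images form out of the noise.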
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
5.5. Training
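Finally, now that all parts of the GAN framework are defined, we can train it. Be mindful that training GANs is somewhat of an art form: incorrect hyperparameter settings can lead to mode collapse, often with little indication of what went wrong.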
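Here, we closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we construct separate mini-batches for real and fake images, and we adjust $G$'s objective function to maximize $\log D(G(z))$. Training is split into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.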
5.5.1. Part 1 - Train the Discriminator
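Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's words, we wish to "update the discriminator by ascending its stochastic gradient".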
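In practice, we want to maximize $\log(D(x)) + \log(1-D(G(z)))$. Because of the separate mini-batch suggestion from ganhacks, we compute this in two steps. First, we construct a batch of real samples from the training set, forward it through $D$, compute the loss (the BCE term $-\log D(x)$), and compute the gradients in a backward pass. Second, we construct a batch of fake samples with the current generator, forward this batch through $D$, compute the loss (the BCE term $-\log(1-D(G(z)))$), and accumulate the gradients with another backward pass. Finally, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer.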
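The two separate backward() calls work because PyTorch accumulates (sums) gradients into .grad until they are cleared, which is also why the loop calls netD.zero_grad() at the start of every iteration. A toy sketch, with a single linear layer plus sigmoid standing in for $D$ purely for illustration, showing that two backward passes give the same gradient as one backward pass on the summed loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # toy stand-in for the discriminator
criterion = nn.BCELoss()
real = torch.randn(4, 8)
fake = torch.randn(4, 8)
ones, zeros = torch.ones(4), torch.zeros(4)

# Two backward passes, as in the training loop: gradients accumulate in .grad
d.zero_grad()
criterion(d(real).view(-1), ones).backward()
criterion(d(fake).view(-1), zeros).backward()
accumulated = d[0].weight.grad.clone()

# One backward pass on the summed loss
d.zero_grad()
(criterion(d(real).view(-1), ones) + criterion(d(fake).view(-1), zeros)).backward()
summed = d[0].weight.grad.clone()

print(torch.allclose(accumulated, summed))  # the two gradients match
```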
5.5.2. Part 2 - Train the Generator
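As stated in the original paper, we want to train the generator by minimizing $\log(1-D(G(z)))$ so that it produces better fakes. However, as mentioned, Goodfellow showed that this does not provide sufficient gradients, especially early in the learning process. The fix is to maximize $\log(D(G(z)))$ instead.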
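The gradient argument can be made concrete. Write $D(G(z)) = \sigma(a)$, where $a$ is the discriminator's pre-sigmoid logit on a fake. Then $\frac{d}{da}\log(1-\sigma(a)) = -\sigma(a)$, which vanishes when $D$ confidently rejects the fake ($\sigma(a) \to 0$, typical early in training), whereas $\frac{d}{da}\left(-\log \sigma(a)\right) = -(1-\sigma(a))$ stays close to $-1$ there. A pure-Python sketch evaluating both (the helper names are hypothetical):

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def generator_grads(a):
    """Gradients w.r.t. the logit a, where D(G(z)) = sigmoid(a).

    Returns (saturating, non_saturating):
      saturating     = d/da of log(1 - D(G(z)))  (original objective, minimized by G)
      non_saturating = d/da of -log D(G(z))      (loss G minimizes instead)
    """
    d = sigmoid(a)
    return -d, -(1.0 - d)

for a in (-6.0, -3.0, 0.0):
    sat, non_sat = generator_grads(a)
    print(f"D(G(z))={sigmoid(a):.4f}  saturating={sat:.4f}  non-saturating={non_sat:.4f}")
```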
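In code, we implement this as follows: classify the generator output from Part 1 with the discriminator, compute $G$'s loss using the real labels as ground truth, compute $G$'s gradients in a backward pass, and finally update $G$'s parameters with an optimizer step. Using the real labels as ground truth for this loss may seem counter-intuitive, but it lets us use the $\log(x)$ part of BCELoss (rather than the $\log(1-x)$ part), which is exactly what we want.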
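Finally, we do some statistics reporting, and every 500 iterations (plus the very last iteration) we push our fixed_noise batch through the generator to visually track $G$'s training progress. The reported training statistics are: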
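- Loss_D - the discriminator loss, computed as the sum of the losses on the all-real and all-fake batches, i.e. $-\left[\log D(x) + \log(1 - D(G(z)))\right]$ with each term averaged over its batch.
- Loss_G - the generator loss, computed as $-\log D(G(z))$ averaged over the batch.
- D(x) - the average output (across the batch) of the discriminator for the all-real batch. It should start close to 1 and then, in theory, converge toward 0.5 as $G$ gets better.
- D(G(z)) - the average discriminator output for the all-fake batch. The first number is measured before $D$ is updated and the second after $D$ is updated. In theory, both should approach 0.5 as $G$ gets better.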
Note: this step can take a while, depending on how many epochs you run and whether you removed some data from the dataset.
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D (detach so no gradients flow into G)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch; they add to the all-real gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total D loss for reporting (gradients were already accumulated above)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Output:
Starting Training Loop...
[0/5][0/1583] Loss_D: 1.9847 Loss_G: 5.5914 D(x): 0.6004 D(G(z)): 0.6680 / 0.0062
[0/5][50/1583] Loss_D: 0.4017 Loss_G: 17.8778 D(x): 0.8368 D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 2.8508 Loss_G: 22.8236 D(x): 0.9634 D(G(z)): 0.8460 / 0.0000
[0/5][150/1583] Loss_D: 0.2360 Loss_G: 5.4596 D(x): 0.8440 D(G(z)): 0.0308 / 0.0090
[0/5][200/1583] Loss_D: 1.6425 Loss_G: 4.7064 D(x): 0.3414 D(G(z)): 0.0079 / 0.0176
[0/5][250/1583] Loss_D: 0.2731 Loss_G: 4.4791 D(x): 0.9431 D(G(z)): 0.1680 / 0.0225
[0/5][300/1583] Loss_D: 0.6051 Loss_G: 4.6251 D(x): 0.8278 D(G(z)): 0.2424 / 0.0230
[0/5][350/1583] Loss_D: 0.7070 Loss_G: 1.6842 D(x): 0.6204 D(G(z)): 0.0824 / 0.2560
[0/5][400/1583] Loss_D: 0.6758 Loss_G: 4.0679 D(x): 0.9354 D(G(z)): 0.3946 / 0.0288
[0/5][450/1583] Loss_D: 0.5348 Loss_G: 5.7453 D(x): 0.9625 D(G(z)): 0.3514 / 0.0083
[0/5][500/1583] Loss_D: 0.6896 Loss_G: 7.8784 D(x): 0.9364 D(G(z)): 0.4080 / 0.0012
[0/5][550/1583] Loss_D: 0.4377 Loss_G: 8.1336 D(x): 0.9425 D(G(z)): 0.2840 / 0.0007
[0/5][600/1583] Loss_D: 1.8797 Loss_G: 2.5577 D(x): 0.3201 D(G(z)): 0.0123 / 0.1258
[0/5][650/1583] Loss_D: 1.3832 Loss_G: 10.6947 D(x): 0.9770 D(G(z)): 0.7006 / 0.0001
[0/5][700/1583] Loss_D: 0.3195 Loss_G: 3.7833 D(x): 0.8474 D(G(z)): 0.0844 / 0.0789
[0/5][750/1583] Loss_D: 0.2142 Loss_G: 4.1755 D(x): 0.8942 D(G(z)): 0.0813 / 0.0232
[0/5][800/1583] Loss_D: 1.4535 Loss_G: 2.3077 D(x): 0.4024 D(G(z)): 0.0111 / 0.1806
[0/5][850/1583] Loss_D: 0.4109 Loss_G: 6.3312 D(x): 0.9002 D(G(z)): 0.2153 / 0.0048
[0/5][900/1583] Loss_D: 2.7930 Loss_G: 4.5548 D(x): 0.1428 D(G(z)): 0.0022 / 0.0240
[0/5][950/1583] Loss_D: 0.3493 Loss_G: 5.5976 D(x): 0.8767 D(G(z)): 0.1498 / 0.0080
[0/5][1000/1583] Loss_D: 0.6749 Loss_G: 5.0457 D(x): 0.6349 D(G(z)): 0.0215 / 0.0194
[0/5][1050/1583] Loss_D: 0.4009 Loss_G: 4.5791 D(x): 0.7669 D(G(z)): 0.0484 / 0.0260
[0/5][1100/1583] Loss_D: 0.3453 Loss_G: 2.7277 D(x): 0.8885 D(G(z)): 0.1408 / 0.1219
[0/5][1150/1583] Loss_D: 0.2484 Loss_G: 5.0396 D(x): 0.8727 D(G(z)): 0.0595 / 0.0174
[0/5][1200/1583] Loss_D: 0.6760 Loss_G: 3.2315 D(x): 0.7052 D(G(z)): 0.1756 / 0.0688
[0/5][1250/1583] Loss_D: 0.5845 Loss_G: 3.1392 D(x): 0.7576 D(G(z)): 0.2018 / 0.0673
[0/5][1300/1583] Loss_D: 0.2762 Loss_G: 4.9311 D(x): 0.8666 D(G(z)): 0.0933 / 0.0136
[0/5][1350/1583] Loss_D: 0.4753 Loss_G: 4.7346 D(x): 0.8595 D(G(z)): 0.2228 / 0.0170
[0/5][1400/1583] Loss_D: 0.3764 Loss_G: 5.9964 D(x): 0.7758 D(G(z)): 0.0109 / 0.0098
[0/5][1450/1583] Loss_D: 0.4025 Loss_G: 3.8804 D(x): 0.8158 D(G(z)): 0.1413 / 0.0320
[0/5][1500/1583] Loss_D: 0.6678 Loss_G: 2.7302 D(x): 0.6980 D(G(z)): 0.1486 / 0.1040
[0/5][1550/1583] Loss_D: 0.6062 Loss_G: 3.1664 D(x): 0.7235 D(G(z)): 0.1305 / 0.0783
[1/5][0/1583] Loss_D: 0.6615 Loss_G: 8.0512 D(x): 0.9412 D(G(z)): 0.3797 / 0.0007
[1/5][50/1583] Loss_D: 0.8057 Loss_G: 2.1089 D(x): 0.5929 D(G(z)): 0.0869 / 0.1893
[1/5][100/1583] Loss_D: 0.4206 Loss_G: 3.3245 D(x): 0.7409 D(G(z)): 0.0554 / 0.0640
[1/5][150/1583] Loss_D: 0.6361 Loss_G: 4.0774 D(x): 0.7830 D(G(z)): 0.2605 / 0.0256
[1/5][200/1583] Loss_D: 1.7394 Loss_G: 7.5861 D(x): 0.9685 D(G(z)): 0.7499 / 0.0014
[1/5][250/1583] Loss_D: 0.4597 Loss_G: 3.1064 D(x): 0.7053 D(G(z)): 0.0265 / 0.0844
[1/5][300/1583] Loss_D: 0.4190 Loss_G: 2.2869 D(x): 0.7942 D(G(z)): 0.1163 / 0.1660
[1/5][350/1583] Loss_D: 0.4724 Loss_G: 4.3673 D(x): 0.8292 D(G(z)): 0.2106 / 0.0213
[1/5][400/1583] Loss_D: 0.2877 Loss_G: 4.3217 D(x): 0.8823 D(G(z)): 0.1125 / 0.0225
[1/5][450/1583] Loss_D: 0.8508 Loss_G: 0.8635 D(x): 0.5397 D(G(z)): 0.0390 / 0.5324
[1/5][500/1583] Loss_D: 0.4317 Loss_G: 3.1585 D(x): 0.7646 D(G(z)): 0.0931 / 0.0767
[1/5][550/1583] Loss_D: 0.8256 Loss_G: 6.1484 D(x): 0.9395 D(G(z)): 0.4563 / 0.0051
[1/5][600/1583] Loss_D: 0.9765 Loss_G: 1.5017 D(x): 0.4807 D(G(z)): 0.0076 / 0.2843
[1/5][650/1583] Loss_D: 1.8020 Loss_G: 8.8270 D(x): 0.9480 D(G(z)): 0.7248 / 0.0003
[1/5][700/1583] Loss_D: 0.3680 Loss_G: 3.7401 D(x): 0.7991 D(G(z)): 0.0949 / 0.0404
[1/5][750/1583] Loss_D: 0.5763 Loss_G: 2.0559 D(x): 0.6739 D(G(z)): 0.0851 / 0.1882
[1/5][800/1583] Loss_D: 0.7773 Loss_G: 5.0999 D(x): 0.9399 D(G(z)): 0.4335 / 0.0142
[1/5][850/1583] Loss_D: 0.3901 Loss_G: 3.4356 D(x): 0.8537 D(G(z)): 0.1744 / 0.0491
[1/5][900/1583] Loss_D: 0.7268 Loss_G: 6.5356 D(x): 0.9635 D(G(z)): 0.4428 / 0.0027
[1/5][950/1583] Loss_D: 0.4570 Loss_G: 3.8893 D(x): 0.8707 D(G(z)): 0.2376 / 0.0304
[1/5][1000/1583] Loss_D: 1.3551 Loss_G: 7.2447 D(x): 0.9333 D(G(z)): 0.6422 / 0.0030
[1/5][1050/1583] Loss_D: 0.3905 Loss_G: 3.3360 D(x): 0.8183 D(G(z)): 0.1462 / 0.0537
[1/5][1100/1583] Loss_D: 1.3858 Loss_G: 0.9796 D(x): 0.3336 D(G(z)): 0.0259 / 0.4584
[1/5][1150/1583] Loss_D: 0.5776 Loss_G: 2.6197 D(x): 0.6443 D(G(z)): 0.0532 / 0.1051
[1/5][1200/1583] Loss_D: 0.5647 Loss_G: 3.5713 D(x): 0.8026 D(G(z)): 0.2450 / 0.0428
[1/5][1250/1583] Loss_D: 0.4568 Loss_G: 3.6666 D(x): 0.8934 D(G(z)): 0.2581 / 0.0403
[1/5][1300/1583] Loss_D: 0.7197 Loss_G: 1.8175 D(x): 0.6211 D(G(z)): 0.1035 / 0.2184
[1/5][1350/1583] Loss_D: 0.5255 Loss_G: 3.2736 D(x): 0.8141 D(G(z)): 0.2233 / 0.0574
[1/5][1400/1583] Loss_D: 0.8241 Loss_G: 3.0776 D(x): 0.7807 D(G(z)): 0.3659 / 0.0743
[1/5][1450/1583] Loss_D: 0.4302 Loss_G: 3.3777 D(x): 0.9058 D(G(z)): 0.2518 / 0.0519
[1/5][1500/1583] Loss_D: 0.4173 Loss_G: 2.5610 D(x): 0.7916 D(G(z)): 0.1358 / 0.1058
[1/5][1550/1583] Loss_D: 0.7993 Loss_G: 5.1228 D(x): 0.8527 D(G(z)): 0.4162 / 0.0104
[2/5][0/1583] Loss_D: 0.4844 Loss_G: 2.2263 D(x): 0.7645 D(G(z)): 0.1510 / 0.1426
[2/5][50/1583] Loss_D: 0.6756 Loss_G: 2.4608 D(x): 0.5915 D(G(z)): 0.0657 / 0.1248
[2/5][100/1583] Loss_D: 0.4391 Loss_G: 3.0181 D(x): 0.7901 D(G(z)): 0.1486 / 0.0744
[2/5][150/1583] Loss_D: 0.5683 Loss_G: 1.8918 D(x): 0.7083 D(G(z)): 0.1411 / 0.1858
[2/5][200/1583] Loss_D: 0.5932 Loss_G: 3.3342 D(x): 0.9111 D(G(z)): 0.3576 / 0.0522
[2/5][250/1583] Loss_D: 0.7331 Loss_G: 2.3817 D(x): 0.6635 D(G(z)): 0.1665 / 0.1397
[2/5][300/1583] Loss_D: 0.5493 Loss_G: 2.3824 D(x): 0.7491 D(G(z)): 0.1742 / 0.1196
[2/5][350/1583] Loss_D: 0.6197 Loss_G: 1.8560 D(x): 0.6443 D(G(z)): 0.1018 / 0.1972
[2/5][400/1583] Loss_D: 0.6172 Loss_G: 3.0777 D(x): 0.8482 D(G(z)): 0.3251 / 0.0621
[2/5][450/1583] Loss_D: 0.5047 Loss_G: 3.2941 D(x): 0.9174 D(G(z)): 0.3116 / 0.0566
[2/5][500/1583] Loss_D: 0.7335 Loss_G: 1.2796 D(x): 0.5676 D(G(z)): 0.0575 / 0.3470
[2/5][550/1583] Loss_D: 0.7716 Loss_G: 1.9450 D(x): 0.5513 D(G(z)): 0.0580 / 0.1922
[2/5][600/1583] Loss_D: 0.4425 Loss_G: 2.0531 D(x): 0.8015 D(G(z)): 0.1640 / 0.1686
[2/5][650/1583] Loss_D: 1.0964 Loss_G: 4.4602 D(x): 0.9096 D(G(z)): 0.5833 / 0.0163
[2/5][700/1583] Loss_D: 0.4745 Loss_G: 2.8636 D(x): 0.8492 D(G(z)): 0.2403 / 0.0770
[2/5][750/1583] Loss_D: 0.4947 Loss_G: 3.6931 D(x): 0.8803 D(G(z)): 0.2732 / 0.0364
[2/5][800/1583] Loss_D: 0.9355 Loss_G: 4.3906 D(x): 0.9120 D(G(z)): 0.5168 / 0.0195
[2/5][850/1583] Loss_D: 0.9213 Loss_G: 1.6006 D(x): 0.4645 D(G(z)): 0.0339 / 0.2467
[2/5][900/1583] Loss_D: 0.5337 Loss_G: 3.7601 D(x): 0.9101 D(G(z)): 0.3310 / 0.0314
[2/5][950/1583] Loss_D: 1.2562 Loss_G: 4.9530 D(x): 0.9432 D(G(z)): 0.6244 / 0.0144
[2/5][1000/1583] Loss_D: 0.4187 Loss_G: 2.4701 D(x): 0.8454 D(G(z)): 0.1945 / 0.1129
[2/5][1050/1583] Loss_D: 0.5796 Loss_G: 2.3732 D(x): 0.7714 D(G(z)): 0.2253 / 0.1216
[2/5][1100/1583] Loss_D: 0.6325 Loss_G: 2.5824 D(x): 0.8307 D(G(z)): 0.3235 / 0.0939
[2/5][1150/1583] Loss_D: 0.7639 Loss_G: 3.9487 D(x): 0.9031 D(G(z)): 0.4398 / 0.0291
[2/5][1200/1583] Loss_D: 0.7040 Loss_G: 3.3561 D(x): 0.8073 D(G(z)): 0.3403 / 0.0500
[2/5][1250/1583] Loss_D: 1.0567 Loss_G: 4.7122 D(x): 0.9292 D(G(z)): 0.5656 / 0.0155
[2/5][1300/1583] Loss_D: 0.5431 Loss_G: 2.4260 D(x): 0.7628 D(G(z)): 0.2028 / 0.1116
[2/5][1350/1583] Loss_D: 0.7633 Loss_G: 4.1670 D(x): 0.9257 D(G(z)): 0.4404 / 0.0237
[2/5][1400/1583] Loss_D: 2.1958 Loss_G: 0.5288 D(x): 0.1539 D(G(z)): 0.0147 / 0.6404
[2/5][1450/1583] Loss_D: 0.6991 Loss_G: 1.8573 D(x): 0.5818 D(G(z)): 0.0621 / 0.1980
[2/5][1500/1583] Loss_D: 0.8286 Loss_G: 3.6899 D(x): 0.8805 D(G(z)): 0.4440 / 0.0364
[2/5][1550/1583] Loss_D: 0.5100 Loss_G: 2.5931 D(x): 0.7721 D(G(z)): 0.1862 / 0.0989
[3/5][0/1583] Loss_D: 0.7136 Loss_G: 2.6315 D(x): 0.8178 D(G(z)): 0.3462 / 0.1034
[3/5][50/1583] Loss_D: 0.6472 Loss_G: 2.6359 D(x): 0.7572 D(G(z)): 0.2460 / 0.0962
[3/5][100/1583] Loss_D: 0.5211 Loss_G: 1.7793 D(x): 0.7275 D(G(z)): 0.1402 / 0.2050
[3/5][150/1583] Loss_D: 0.9620 Loss_G: 4.0717 D(x): 0.9423 D(G(z)): 0.5500 / 0.0243
[3/5][200/1583] Loss_D: 0.5469 Loss_G: 2.1994 D(x): 0.7581 D(G(z)): 0.1972 / 0.1359
[3/5][250/1583] Loss_D: 0.3941 Loss_G: 2.7071 D(x): 0.7281 D(G(z)): 0.0401 / 0.0902
[3/5][300/1583] Loss_D: 0.6482 Loss_G: 1.4858 D(x): 0.6275 D(G(z)): 0.1085 / 0.2802
[3/5][350/1583] Loss_D: 1.2781 Loss_G: 4.7393 D(x): 0.9594 D(G(z)): 0.6587 / 0.0120
[3/5][400/1583] Loss_D: 0.5942 Loss_G: 2.8406 D(x): 0.7861 D(G(z)): 0.2579 / 0.0784
[3/5][450/1583] Loss_D: 0.5395 Loss_G: 1.9849 D(x): 0.6755 D(G(z)): 0.0854 / 0.1764
[3/5][500/1583] Loss_D: 0.7941 Loss_G: 2.5871 D(x): 0.7891 D(G(z)): 0.3784 / 0.1006
[3/5][550/1583] Loss_D: 0.6556 Loss_G: 3.9228 D(x): 0.9328 D(G(z)): 0.4053 / 0.0254
[3/5][600/1583] Loss_D: 0.6489 Loss_G: 3.2773 D(x): 0.8385 D(G(z)): 0.3419 / 0.0490
[3/5][650/1583] Loss_D: 0.9217 Loss_G: 1.3858 D(x): 0.4992 D(G(z)): 0.0854 / 0.3095
[3/5][700/1583] Loss_D: 0.4947 Loss_G: 2.2791 D(x): 0.7948 D(G(z)): 0.2035 / 0.1332
[3/5][750/1583] Loss_D: 0.9676 Loss_G: 1.6087 D(x): 0.4641 D(G(z)): 0.0363 / 0.2599
[3/5][800/1583] Loss_D: 0.5918 Loss_G: 1.8852 D(x): 0.7019 D(G(z)): 0.1637 / 0.1948
[3/5][850/1583] Loss_D: 0.7856 Loss_G: 3.4243 D(x): 0.8672 D(G(z)): 0.4219 / 0.0512
[3/5][900/1583] Loss_D: 0.5023 Loss_G: 2.7348 D(x): 0.8372 D(G(z)): 0.2416 / 0.0851
[3/5][950/1583] Loss_D: 0.9028 Loss_G: 1.8348 D(x): 0.5362 D(G(z)): 0.1219 / 0.2110
[3/5][1000/1583] Loss_D: 0.8118 Loss_G: 3.9327 D(x): 0.9092 D(G(z)): 0.4586 / 0.0306
[3/5][1050/1583] Loss_D: 0.8709 Loss_G: 3.1103 D(x): 0.8752 D(G(z)): 0.4686 / 0.0639
[3/5][1100/1583] Loss_D: 0.4286 Loss_G: 2.9141 D(x): 0.8379 D(G(z)): 0.1912 / 0.0741
[3/5][1150/1583] Loss_D: 0.6005 Loss_G: 1.8091 D(x): 0.7044 D(G(z)): 0.1727 / 0.2042
[3/5][1200/1583] Loss_D: 0.7432 Loss_G: 3.8108 D(x): 0.9088 D(G(z)): 0.4344 / 0.0297
[3/5][1250/1583] Loss_D: 0.6872 Loss_G: 1.8717 D(x): 0.7355 D(G(z)): 0.2731 / 0.1789
[3/5][1300/1583] Loss_D: 0.5740 Loss_G: 3.4426 D(x): 0.8874 D(G(z)): 0.3380 / 0.0422
[3/5][1350/1583] Loss_D: 0.5689 Loss_G: 2.0738 D(x): 0.6823 D(G(z)): 0.0966 / 0.1621
[3/5][1400/1583] Loss_D: 0.5023 Loss_G: 3.1107 D(x): 0.9225 D(G(z)): 0.3231 / 0.0565
[3/5][1450/1583] Loss_D: 0.7466 Loss_G: 3.1208 D(x): 0.8441 D(G(z)): 0.3891 / 0.0634
[3/5][1500/1583] Loss_D: 0.7135 Loss_G: 2.8145 D(x): 0.8924 D(G(z)): 0.4117 / 0.0765
[3/5][1550/1583] Loss_D: 0.7881 Loss_G: 4.0945 D(x): 0.9332 D(G(z)): 0.4717 / 0.0258
[4/5][0/1583] Loss_D: 0.6309 Loss_G: 2.2672 D(x): 0.7764 D(G(z)): 0.2761 / 0.1311
[4/5][50/1583] Loss_D: 0.8068 Loss_G: 1.4844 D(x): 0.5595 D(G(z)): 0.1015 / 0.2795
[4/5][100/1583] Loss_D: 0.4912 Loss_G: 2.0030 D(x): 0.7526 D(G(z)): 0.1516 / 0.1674
[4/5][150/1583] Loss_D: 3.0392 Loss_G: 0.6172 D(x): 0.0896 D(G(z)): 0.0134 / 0.6503
[4/5][200/1583] Loss_D: 0.6768 Loss_G: 2.5170 D(x): 0.7543 D(G(z)): 0.2852 / 0.0986
[4/5][250/1583] Loss_D: 1.2451 Loss_G: 0.9252 D(x): 0.3817 D(G(z)): 0.0554 / 0.4569
[4/5][300/1583] Loss_D: 0.5916 Loss_G: 1.7704 D(x): 0.6588 D(G(z)): 0.1113 / 0.2144
[4/5][350/1583] Loss_D: 1.3058 Loss_G: 0.6935 D(x): 0.3416 D(G(z)): 0.0394 / 0.5486
[4/5][400/1583] Loss_D: 0.6206 Loss_G: 3.0787 D(x): 0.8405 D(G(z)): 0.3261 / 0.0609
[4/5][450/1583] Loss_D: 0.5866 Loss_G: 1.4752 D(x): 0.6981 D(G(z)): 0.1565 / 0.2718
[4/5][500/1583] Loss_D: 0.5616 Loss_G: 3.0459 D(x): 0.8869 D(G(z)): 0.3223 / 0.0650
[4/5][550/1583] Loss_D: 0.6073 Loss_G: 3.2580 D(x): 0.7503 D(G(z)): 0.2344 / 0.0500
[4/5][600/1583] Loss_D: 0.6905 Loss_G: 3.0939 D(x): 0.8591 D(G(z)): 0.3762 / 0.0589
[4/5][650/1583] Loss_D: 0.5836 Loss_G: 1.7048 D(x): 0.6781 D(G(z)): 0.1227 / 0.2282
[4/5][700/1583] Loss_D: 0.8543 Loss_G: 3.7586 D(x): 0.8876 D(G(z)): 0.4712 / 0.0337
[4/5][750/1583] Loss_D: 0.8484 Loss_G: 2.3787 D(x): 0.6606 D(G(z)): 0.2724 / 0.1192
[4/5][800/1583] Loss_D: 0.5562 Loss_G: 2.1677 D(x): 0.7446 D(G(z)): 0.1887 / 0.1533
[4/5][850/1583] Loss_D: 0.7600 Loss_G: 1.4960 D(x): 0.5447 D(G(z)): 0.0559 / 0.2722
[4/5][900/1583] Loss_D: 0.5677 Loss_G: 3.0179 D(x): 0.8308 D(G(z)): 0.2804 / 0.0664
[4/5][950/1583] Loss_D: 0.5381 Loss_G: 2.9582 D(x): 0.7989 D(G(z)): 0.2345 / 0.0711
[4/5][1000/1583] Loss_D: 0.8333 Loss_G: 2.8499 D(x): 0.7720 D(G(z)): 0.3700 / 0.0786
[4/5][1050/1583] Loss_D: 0.5125 Loss_G: 1.8930 D(x): 0.7287 D(G(z)): 0.1387 / 0.1848
[4/5][1100/1583] Loss_D: 0.4527 Loss_G: 3.0039 D(x): 0.8639 D(G(z)): 0.2413 / 0.0614
[4/5][1150/1583] Loss_D: 0.7072 Loss_G: 0.8361 D(x): 0.5589 D(G(z)): 0.0563 / 0.4846
[4/5][1200/1583] Loss_D: 0.8619 Loss_G: 4.9323 D(x): 0.9385 D(G(z)): 0.4880 / 0.0112
[4/5][1250/1583] Loss_D: 0.6864 Loss_G: 2.4925 D(x): 0.7232 D(G(z)): 0.2431 / 0.1152
[4/5][1300/1583] Loss_D: 0.5835 Loss_G: 3.1599 D(x): 0.8430 D(G(z)): 0.3018 / 0.0644
[4/5][1350/1583] Loss_D: 0.9119 Loss_G: 4.7225 D(x): 0.9409 D(G(z)): 0.5082 / 0.0154
[4/5][1400/1583] Loss_D: 0.3856 Loss_G: 3.1007 D(x): 0.8980 D(G(z)): 0.2238 / 0.0584
[4/5][1450/1583] Loss_D: 1.3314 Loss_G: 5.1061 D(x): 0.9395 D(G(z)): 0.6621 / 0.0094
[4/5][1500/1583] Loss_D: 0.5882 Loss_G: 1.7242 D(x): 0.6443 D(G(z)): 0.0785 / 0.2306
[4/5][1550/1583] Loss_D: 0.5792 Loss_G: 2.0347 D(x): 0.7582 D(G(z)): 0.2143 / 0.1594
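One detail of the loop above is easy to miss: $D$'s fake batch is classified as fake.detach(), so errD_fake.backward() stops at the generator's output and computes no gradients for $G$'s parameters, while $G$'s update later reuses the same fake tensor without detach(). A toy sketch, with two linear layers standing in for $G$ and $D$ purely for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
g = nn.Linear(4, 3)  # toy stand-in for the generator
d = nn.Linear(3, 1)  # toy stand-in for the discriminator

fake = g(torch.randn(2, 4))

# D step: classify the detached batch; the graph is cut at `fake`
d(fake.detach()).sum().backward()
grad_after_d_step = g.weight.grad  # still None: no gradient reached G

# G step: classify the same batch without detach(); gradients flow back into G
d(fake).sum().backward()
grad_after_g_step = g.weight.grad

print(grad_after_d_step is None, grad_after_g_step is not None)
```

Detaching also means the graph through $G$ is still intact for the generator's backward pass later in the same iteration.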
6. Results
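Finally, let's check how we did. We will look at three different results. First, we will see how the losses of $D$ and $G$ changed during training. Second, we will visualize $G$'s output on the fixed_noise batch at each saved snapshot. Third, we will compare a batch of real data with a batch of fake data from $G$.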
6.1. Loss versus Training Iteration
Below is a plot of the losses of $D$ and $G$ versus training iterations.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
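The per-iteration losses are noisy. One optional way to read the trend, not part of the original tutorial, is to overlay a trailing moving average; the helper name and window size below are illustrative assumptions.

```python
def moving_average(values, window=50):
    """Trailing moving average; the first window-1 points average over fewer values."""
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]
        out.append(running / min(i + 1, window))
    return out

print(moving_average([1.0, 3.0, 5.0, 7.0], window=2))  # [1.0, 2.0, 4.0, 6.0]

# With the lists collected during training, for example:
# plt.plot(moving_average(G_losses), label="G (smoothed)")
# plt.plot(moving_average(D_losses), label="D (smoothed)")
```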
6.2. Visualizing G's Progression
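Remember how we saved the generator's output on the fixed_noise batch every 500 iterations during training? Now we can visualize $G$'s training progression with an animation.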
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
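How many frames does the animation contain? The loop saves a grid whenever iters % 500 == 0, plus once more at the very last iteration. The sketch below replays that condition in plain Python (the helper `snapshot_iters` is hypothetical), assuming the 1583 batches per epoch shown in the training log and num_epochs = 5:

```python
def snapshot_iters(num_epochs, batches_per_epoch, every=500):
    """Replay the training loop's condition for saving a fixed_noise grid."""
    saved = []
    iters = 0
    for epoch in range(num_epochs):
        for i in range(batches_per_epoch):
            is_last = (epoch == num_epochs - 1) and (i == batches_per_epoch - 1)
            if iters % every == 0 or is_last:
                saved.append(iters)
            iters += 1
    return saved

frames = snapshot_iters(5, 1583)
print(len(frames), frames[-2:])  # 17 [7500, 7914]
```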
6.3. Real Images vs. Fake Images
Finally, let's look at some real images and fake images side by side.
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
7. Next Steps
This is the end of the tutorial, but if you want to dig deeper into GANs, there is much more to explore.
Total running time of the script: (28 minutes 38.953 seconds)
8. Original Article
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Unless otherwise stated, articles on this site are original; if you repost, please credit the source: 【pytorch】DCGAN实战教程(官方教程) - Python技术站