GAN(Generative Adversarial Networks)是一种深度学习模型,用于生成与训练数据相似的新数据。在PyTorch中,我们可以使用GAN来生成图像、音频等数据。以下是使用PyTorch实现GAN的完整攻略,包括两个示例说明。
1. 实现简单的GAN
以下是使用PyTorch实现简单的GAN的步骤:
- 导入必要的库
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
```
- 定义生成器和判别器
```python
# 定义生成器
# 定义生成器
class Generator(nn.Module):
    """MLP generator: maps a 100-dim noise vector to a flat 28x28 (784) image.

    Output passes through sigmoid, so pixel values lie in [0, 1],
    matching MNIST images loaded with ToTensor().
    """

    def __init__(self):
        # Original used `def init` / `self.init()`, which never registers
        # the submodules; fix to the real dunder and super().__init__().
        super().__init__()
        self.fc1 = nn.Linear(100, 128)
        self.fc2 = nn.Linear(128, 784)

    def forward(self, x):
        """x: (batch, 100) noise -> (batch, 784) image in [0, 1]."""
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
# 定义判别器
# 定义判别器
class Discriminator(nn.Module):
    """MLP discriminator: scores a flat 784-pixel image as real (1) / fake (0).

    Sigmoid output in [0, 1] is consumed by nn.BCELoss in the training loop.
    """

    def __init__(self):
        # Original used `def init` / `self.init()`, which never registers
        # the submodules; fix to the real dunder and super().__init__().
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """x: (batch, 784) image -> (batch, 1) real-probability."""
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
```
- 定义训练函数
```python
def train(num_epochs, batch_size, learning_rate):
    """Train the simple fully-connected GAN on MNIST.

    Args:
        num_epochs: number of passes over the training set.
        batch_size: DataLoader mini-batch size.
        learning_rate: Adam learning rate for both networks.

    Side effects: downloads MNIST into ./data, prints per-epoch losses,
    and saves weights to 'G.ckpt' / 'D.ckpt'.
    """
    # Local import: the article's import section never brings torchvision
    # into scope, so the original code raised NameError here.
    import torchvision

    # 加载数据
    train_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root='./data', train=True,
            transform=torchvision.transforms.ToTensor(), download=True),
        batch_size=batch_size, shuffle=True)

    # 初始化生成器和判别器
    G = Generator()
    D = Discriminator()

    # 定义损失函数和优化器
    criterion = nn.BCELoss()
    optimizer_G = torch.optim.Adam(G.parameters(), lr=learning_rate)
    optimizer_D = torch.optim.Adam(D.parameters(), lr=learning_rate)

    # 训练模型
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_loader):
            # The last batch can be smaller than batch_size; size the
            # labels from the actual batch so BCELoss shapes always match.
            n = images.size(0)
            real_labels = torch.ones(n, 1)
            fake_labels = torch.zeros(n, 1)

            # --- 训练判别器 ---
            # Real-data loss.
            outputs = D(images.view(n, -1))
            d_loss_real = criterion(outputs, real_labels)

            # Fake-data loss. detach() stops the discriminator loss from
            # (wastefully) backpropagating through the generator.
            z = torch.randn(n, 100)
            fake_images = G(z)
            outputs = D(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- 训练生成器 ---
            # Train G to make D output "real" on generated samples.
            z = torch.randn(n, 100)
            fake_images = G(z)
            outputs = D(fake_images)
            g_loss = criterion(outputs, real_labels)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # 打印损失 (values from the final batch of the epoch)
        print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
              .format(epoch + 1, num_epochs, d_loss.item(), g_loss.item()))

    # 保存模型
    torch.save(G.state_dict(), 'G.ckpt')
    torch.save(D.state_dict(), 'D.ckpt')
```
- 训练模型并生成图像
```python
# 训练模型
# 训练模型: run the full training loop once.
train(num_epochs=200, batch_size=100, learning_rate=0.0002)

# 加载模型: restore the trained generator weights from its checkpoint.
G = Generator()
G.load_state_dict(torch.load('G.ckpt'))

# 生成图像: sample 10 noise vectors and show each generated digit.
z = torch.randn(10, 100)
fake_images = G(z)
for idx in range(10):
    image = fake_images[idx].detach().numpy().reshape(28, 28)
    plt.imshow(image, cmap='gray')
    plt.show()
```
运行上述代码,即可训练GAN并生成图像。
2. 实现DCGAN
DCGAN(Deep Convolutional Generative Adversarial Networks)是一种使用卷积神经网络的GAN模型,用于生成更高质量的图像。以下是使用PyTorch实现DCGAN的步骤:
- 导入必要的库
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
```
- 定义生成器和判别器
```python
# 定义生成器
# 定义生成器
class Generator(nn.Module):
    """DCGAN generator: 100-dim noise -> (1, 28, 28) image in [0, 1].

    Projects the noise to a 7x7x256 feature map, then upsamples
    7 -> 14 -> 14 -> 28 with transposed convolutions.
    """

    def __init__(self):
        # Original used `def init` / `self.init()`; fix to __init__ /
        # super().__init__() so the submodules actually register.
        super().__init__()
        # Original had the garbled `nn.Linear(100, 77256)`; the view()
        # in forward() requires 7*7*256 features.
        self.fc1 = nn.Linear(100, 7 * 7 * 256)
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                                        padding=1, output_padding=1)  # 7 -> 14
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1,
                                        padding=1)                    # 14 -> 14
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2,
                                        padding=1, output_padding=1)  # 14 -> 28

    def forward(self, x):
        """x: (batch, 100) noise -> (batch, 1, 28, 28) image."""
        x = torch.relu(self.fc1(x))
        x = x.view(-1, 256, 7, 7)
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.sigmoid(self.conv3(x))
        return x
# 定义判别器
# 定义判别器
class Discriminator(nn.Module):
    """DCGAN discriminator: (1, 28, 28) image -> (batch, 1) real-probability.

    Three stride-2 convolutions shrink the map 28 -> 14 -> 7 -> 4,
    giving a 256*4*4 feature vector for the final linear layer.
    """

    def __init__(self):
        # Original used `def init` / `self.init()`; fix to __init__ /
        # super().__init__() so the submodules actually register.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)    # 28 -> 14
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # 14 -> 7
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) # 7 -> 4
        self.bn3 = nn.BatchNorm2d(256)
        # Original had the garbled `nn.Linear(25644, 1)`; the flatten in
        # forward() produces 256*4*4 features.
        self.fc1 = nn.Linear(256 * 4 * 4, 1)

    def forward(self, x):
        """x: (batch, 1, 28, 28) image -> (batch, 1) score in [0, 1]."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.relu(self.bn3(self.conv3(x)))
        x = x.view(-1, 256 * 4 * 4)
        x = torch.sigmoid(self.fc1(x))
        return x
```
- 定义训练函数
```python
def train(num_epochs, batch_size, learning_rate):
    """Train the DCGAN on MNIST.

    Args:
        num_epochs: number of passes over the training set.
        batch_size: DataLoader mini-batch size.
        learning_rate: Adam learning rate for both networks.

    Side effects: downloads MNIST into ./data, prints per-epoch losses,
    and saves weights to 'G.ckpt' / 'D.ckpt'.
    """
    # Local import: the article's import section never brings torchvision
    # into scope, so the original code raised NameError here.
    import torchvision

    # 加载数据
    # Keep images at the native 28x28: the discriminator flattens a
    # 256*4*4 feature map, which only a 28x28 input produces
    # (the original Resize(64) would crash that view()). No Normalize
    # either: the generator ends in sigmoid, so real samples must also
    # live in [0, 1] for the discriminator to see matching ranges.
    transform = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    # 初始化生成器和判别器
    G = Generator()
    D = Discriminator()

    # 定义损失函数和优化器 (DCGAN-style Adam betas)
    criterion = nn.BCELoss()
    optimizer_G = torch.optim.Adam(G.parameters(), lr=learning_rate,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=learning_rate,
                                   betas=(0.5, 0.999))

    # 训练模型
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_loader):
            # The last batch can be smaller than batch_size; size the
            # labels from the actual batch so BCELoss shapes always match.
            n = images.size(0)
            real_labels = torch.ones(n, 1)
            fake_labels = torch.zeros(n, 1)

            # --- 训练判别器 ---
            outputs = D(images)
            d_loss_real = criterion(outputs, real_labels)

            # detach() keeps the D step from backpropagating through G.
            z = torch.randn(n, 100)
            fake_images = G(z)
            outputs = D(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- 训练生成器 ---
            z = torch.randn(n, 100)
            fake_images = G(z)
            outputs = D(fake_images)
            g_loss = criterion(outputs, real_labels)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # 打印损失 (values from the final batch of the epoch)
        print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
              .format(epoch + 1, num_epochs, d_loss.item(), g_loss.item()))

    # 保存模型
    torch.save(G.state_dict(), 'G.ckpt')
    torch.save(D.state_dict(), 'D.ckpt')
```
- 训练模型并生成图像
```python
# 训练模型
# 训练模型
train(num_epochs=200, batch_size=64, learning_rate=0.0002)
# 加载模型
G = Generator()
G.load_state_dict(torch.load('G.ckpt'))
# 生成图像
z = torch.randn(10, 100)
fake_images = G(z)
for i in range(10):
    # The DCGAN generator emits (1, 28, 28) maps; reshape to 28x28.
    # (The original reshape(64, 64) would raise a size-mismatch error.)
    plt.imshow(fake_images[i].detach().numpy().reshape(28, 28), cmap='gray')
    plt.show()
```
运行上述代码,即可训练DCGAN并生成图像。
以上就是使用PyTorch实现GAN和DCGAN的完整攻略,包括两个示例说明。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch GAN生成对抗网络实例 - Python技术站