Here is example code for a DCGAN Generator and Discriminator that handle images of size (3, 512, 512):
import torch
from torch import nn
# Batch size during training
batch_size = 128
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Base size of feature maps in generator
ngf = 16
# Base size of feature maps in discriminator
ndf = 16
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 64, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 64),
            nn.ReLU(True),
            # state size. (ngf*64) x 4 x 4
            nn.ConvTranspose2d(ngf * 64, ngf * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 32),
            nn.ReLU(True),
            # state size. (ngf*32) x 8 x 8
            nn.ConvTranspose2d(ngf * 32, ngf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True),
            # state size. (ngf*16) x 16 x 16
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 32 x 32
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 64 x 64
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 128 x 128
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 256 x 256
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 512 x 512
        )

    def forward(self, x):
        return self.main(x)
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 512 x 512
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 256 x 256
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 128 x 128
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 64 x 64
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 32 x 32
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*16) x 16 x 16
            nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*32) x 8 x 8
            nn.Conv2d(ndf * 32, ndf * 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 64),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*64) x 4 x 4
            nn.Conv2d(ndf * 64, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
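# Instantiate the networks and run a dummy forward pass to check the output shapes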
netG = Generator(1)
netD = Discriminator(1)
inputG = torch.randn(batch_size, nz, 1, 1)
inputD = torch.randn(batch_size, nc, 512, 512)
print(netG)
print(netD)
print("
Outputs size:")
print(netG(inputG).size())
print(netD(inputD).size())
Outputs:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU(inplace=True)
    (15): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (16): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU(inplace=True)
    (18): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (19): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU(inplace=True)
    (21): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (22): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.2, inplace=True)
    (14): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (15): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): LeakyReLU(negative_slope=0.2, inplace=True)
    (17): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (18): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): LeakyReLU(negative_slope=0.2, inplace=True)
    (20): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (21): Sigmoid()
  )
)
Outputs size:
torch.Size([128, 3, 512, 512])
torch.Size([128, 1, 1, 1])
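As a sanity check on these sizes: for ConvTranspose2d with dilation 1 and no output padding, the output size is (H_in - 1) * stride - 2 * padding + kernel_size, so the first 4/1/0 layer maps 1x1 to 4x4 and each following 4/2/1 layer doubles the resolution until it reaches 512. A minimal sketch of that arithmetic (the helper name is just for illustration):

def convtranspose2d_out(h_in, kernel=4, stride=2, padding=1):
    # output size of ConvTranspose2d with dilation=1 and output_padding=0
    return (h_in - 1) * stride - 2 * padding + kernel

h = convtranspose2d_out(1, kernel=4, stride=1, padding=0)  # 1 -> 4
sizes = [h]
for _ in range(7):  # seven 4/2/1 upsampling layers
    h = convtranspose2d_out(h)
    sizes.append(h)
print(sizes)  # [4, 8, 16, 32, 64, 128, 256, 512]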
Note that the number of output channels in each Conv2d and ConvTranspose2d above is just an example; they are hyperparameters you can tune to get better performance.
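For example, doubling ngf and ndf gives both networks more capacity at the cost of memory and compute. A quick, minimal way to see that cost (assuming netG and netD are already defined as in the script above) is to compare parameter counts:

def count_params(model):
    # total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Generator parameters:    ", count_params(netG))
print("Discriminator parameters:", count_params(netD))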