Image-to-Image Translation with Conditional Adversarial Networks

  • Category: Article
  • Created: February 12, 2022 2:39 PM
  • Status: Open
  • URL: https://arxiv.org/pdf/1611.07004.pdf
  • Updated: February 15, 2022 5:15 PM

Highlights

  1. We investigate conditional adversarial networks as a general-purpose solution to image-to-image translation problems.
  2. As a community, we no longer hand-engineer our mapping functions, and this work suggests we can achieve reasonable results without hand-engineering our loss functions either.

Intuition

  1. If we take a naive approach and ask the CNN to minimize the Euclidean distance between predicted and ground truth pixels, it will tend to produce blurry results.
  2. It would be highly desirable if we could instead specify only a high-level goal, like “make the output indistinguishable from reality”, and then automatically learn a loss function appropriate for satisfying this goal.

Methods

For our generator, we use a U-Net-based architecture, and for our discriminator we use a convolutional PatchGAN classifier, which only penalizes structure at the scale of image patches.

Loss function

The objective of a conditional GAN can be expressed as

\[ \mathcal{L}_{cGAN}(G, D)=\mathbb{E}_{x, y}[\log D(x, y)]+\mathbb{E}_{x, z}[\log (1-D(x, G(x, z)))] \]

Previous approaches have found it beneficial to mix the GAN objective with a more traditional loss, such as L2 distance. The discriminator’s job remains unchanged, but the generator is tasked to not only fool the discriminator but also to be near the ground truth output in an L2 sense. We also explore this option, using L1 distance rather than L2 as L1 encourages less blurring:

\[ \mathcal{L}_{L 1}(G)=\mathbb{E}_{x, y, z}\left[\|y-G(x, z)\|_{1}\right] \]

\[ G^{*}=\arg \min _{G} \max _{D} \mathcal{L}_{c G A N}(G, D)+\lambda \mathcal{L}_{L 1}(G) \]
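
To make the mapping from this objective to code concrete, here is a minimal sketch of the per-batch generator loss, assuming the discriminator outputs logits and using standard PyTorch losses. The names combined_gen_loss and lambda_l1 are illustrative; the paper's experiments use λ = 100, while the course code further below uses 200.

import torch
import torch.nn as nn

adv_criterion = nn.BCEWithLogitsLoss()  # adversarial term: push D's logits on fakes toward "real"
l1_criterion = nn.L1Loss()              # reconstruction term: E[||y - G(x, z)||_1]
lambda_l1 = 100.0                       # weighting reported in the paper's experiments

def combined_gen_loss(disc_logits_on_fake, fake, real):
    # The generator is rewarded when D labels its output as real, and penalized
    # by the L1 distance between its output and the ground truth.
    adv = adv_criterion(disc_logits_on_fake, torch.ones_like(disc_logits_on_fake))
    l1 = l1_criterion(fake, real)
    return adv + lambda_l1 * l1
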

Generator

To give the generator a means to circumvent the bottleneck for information like this, we add skip connections, following the general shape of a U-Net.

[Figure: encoder-decoder vs. U-Net generator with skip connections]
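
The generator code is not included in these notes, so the following is a heavily simplified sketch of the skip-connection idea rather than the paper's exact architecture: a small encoder-decoder in which the decoder stage is concatenated with the mirrored encoder feature map. TinyUNetGenerator and its layer sizes are illustrative only.

import torch
import torch.nn as nn

class TinyUNetGenerator(nn.Module):
    '''Toy U-Net-style generator: two downsampling stages, two upsampling stages,
    with a skip connection concatenating encoder features into the decoder.'''
    def __init__(self, in_channels=3, out_channels=3, hidden=64):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(in_channels, hidden, 4, stride=2, padding=1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(hidden, hidden * 2, 4, stride=2, padding=1), nn.LeakyReLU(0.2))
        self.up1 = nn.Sequential(nn.ConvTranspose2d(hidden * 2, hidden, 4, stride=2, padding=1), nn.ReLU())
        # Input channels are doubled because the skip connection concatenates d1 with u1
        self.up2 = nn.Sequential(nn.ConvTranspose2d(hidden * 2, out_channels, 4, stride=2, padding=1), nn.Tanh())

    def forward(self, x):
        d1 = self.down1(x)               # (B, hidden, H/2, W/2)
        d2 = self.down2(d1)              # (B, 2*hidden, H/4, W/4) -- the "bottleneck" here
        u1 = self.up1(d2)                # (B, hidden, H/2, W/2)
        u1 = torch.cat([u1, d1], dim=1)  # skip connection around the bottleneck
        return self.up2(u1)              # (B, out_channels, H, W)

# Shape check
gen = TinyUNetGenerator()
assert gen(torch.randn(1, 3, 256, 256)).shape == (1, 3, 256, 256)
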

Discriminator (PatchGAN)

This motivates restricting the GAN discriminator to only model high-frequency structure, relying on an L1 term to force low-frequency correctness.

In order to model high-frequencies, it is sufficient to restrict our attention to the structure in local image patches. Therefore, we design a discriminator architecture – which we term a PatchGAN – that only penalizes structure at the scale of patches.

This discriminator tries to classify if each N × N patch in an image is real or fake. We run this discriminator convolutionally across the image, averaging all responses to provide the ultimate output of D.

[Figure: PatchGAN discriminator]
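
As a small illustration of that last point, reusing the Discriminator defined in the Code section below (so the patch grid size and helper blocks are assumptions carried over from there): the network is fully convolutional, so it already emits a grid of per-patch logits, and the image-level decision is simply the average of those responses.

import torch

# Assumes the Discriminator class from the Code section below is in scope.
disc = Discriminator(input_channels=6)          # e.g. a 3-channel condition concatenated with a 3-channel image
x = torch.randn(1, 3, 256, 256)                 # condition image
y = torch.randn(1, 3, 256, 256)                 # image to judge (real or generated)

patch_logits = disc(x, y)                       # (1, 1, N, N): one logit per N x N image patch
patch_probs = torch.sigmoid(patch_logits)       # per-patch "real" probabilities
image_score = patch_probs.mean(dim=(1, 2, 3))   # averaging all responses gives D's image-level output
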

Code

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    '''
    Discriminator Class
    Structured like the contracting path of the U-Net, the discriminator will
    output a matrix of values classifying corresponding portions of the image as real or fake.
    Parameters:
        input_channels: the number of image input channels
        hidden_channels: the initial number of discriminator convolutional filters
    '''
    def __init__(self, input_channels, hidden_channels=8):
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        #### START CODE HERE ####
        # 1x1 convolution maps the final feature map to a one-channel grid of per-patch predictions
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)
        #### END CODE HERE ####

    def forward(self, x, y):
        # Condition the discriminator by concatenating the input and target images channel-wise
        x = torch.cat([x, y], dim=1)
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        xn = self.final(x4)
        return xn

# UNIT TEST
test_discriminator = Discriminator(10, 1)
assert tuple(test_discriminator(
    torch.randn(1, 5, 256, 256),
    torch.randn(1, 5, 256, 256)
).shape) == (1, 1, 16, 16)
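
FeatureMapBlock and ContractingBlock come from earlier in the same notebook and are not reproduced in these notes. To make the class above runnable, here is a minimal sketch consistent with the shapes the unit test expects (FeatureMapBlock preserves spatial size; each ContractingBlock doubles the channels and halves the resolution). The kernel sizes, pooling, and activations are assumptions rather than the notebook's exact definitions, and these would need to appear before the Discriminator is instantiated.

import torch.nn as nn

class FeatureMapBlock(nn.Module):
    '''Hypothetical helper: maps input_channels to output_channels, preserving spatial size.'''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class ContractingBlock(nn.Module):
    '''Hypothetical helper: doubles the channel count and halves the spatial resolution.'''
    def __init__(self, input_channels, use_bn=True):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, input_channels * 2, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(input_channels * 2) if use_bn else nn.Identity()
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        return self.maxpool(x)
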
def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator; takes the condition and returns potential images
        disc: the discriminator; takes images and the condition and
              returns real/fake prediction matrices
        real: the real images (e.g. maps) to be used to evaluate the reconstruction
        condition: the source images (e.g. satellite imagery) which are used to produce the real images
        adv_criterion: the adversarial loss function; takes the discriminator
                       predictions and the true labels and returns an adversarial
                       loss (which you aim to minimize)
        recon_criterion: the reconstruction loss function; takes the generator
                         outputs and the real images and returns a reconstruction
                         loss (which you aim to minimize)
        lambda_recon: the degree to which the reconstruction loss should be weighted in the sum
    '''
    #### START CODE HERE ####
    gen_img = gen(condition)                             # generate a fake from the condition
    out = disc(gen_img, condition)                       # discriminator's per-patch predictions on the fake
    adv_loss = adv_criterion(out, torch.ones_like(out))  # generator wants these labeled real
    recon_loss = recon_criterion(gen_img, real)          # L1 distance to the ground truth
    gen_loss = adv_loss + lambda_recon * recon_loss
    #### END CODE HERE ####
    return gen_loss
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200

### Update discriminator ###
disc_opt.zero_grad()                           # Zero out the discriminator gradients
disc_fake_hat = disc(fake.detach(), condition) # Detach the generator so only D is updated
disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
disc_real_hat = disc(real, condition)
disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
disc_loss.backward(retain_graph=True)          # Compute gradients
disc_opt.step()                                # Update discriminator weights

### Update generator ###
gen_opt.zero_grad()
gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
gen_loss.backward()                            # Compute gradients
gen_opt.step()                                 # Update generator weights

Conclusion

The results in this paper suggest that conditional adversarial networks are a promising approach for many image- to-image translation tasks, especially those involving highly structured graphical outputs. These networks learn a loss adapted to the task and data at hand, which makes them applicable in a wide variety of settings.