
Style Transfer

!lspci | grep -i nvidia
02:00.0 3D controller: NVIDIA Corporation GP100GL [Tesla P100 PCIe 16GB] (rev a1)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from keras import backend as K
from keras.preprocessing.image import load_img, save_img, img_to_array
from keras.applications import vgg19
from keras.applications.vgg19 import preprocess_input

from matplotlib import pyplot as plt

from PIL import Image
Using TensorFlow backend.
base_image_path = "dora.png"
style_reference_image_path = "star_sky.jpg"
plt.figure(figsize=(20,20))

plt.subplot(1, 3, 1)
plt.imshow(Image.open(base_image_path))

plt.subplot(1, 3, 2)
plt.imshow(Image.open(style_reference_image_path))


plt.subplot(1, 3, 3)
plt.imshow(Image.open("dora_at_iteration_9.png"))
[Figure: the content image (dora.png), the style image (star_sky.jpg), and the stylized result at iteration 9, side by side.]

def preprocess_image(image_path):
    img = load_img(image_path, target_size=(img_nrows, img_ncols))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)  # RGB -> BGR, subtract the ImageNet mean pixel
    return img

def deprocess_image(x):
    x = x.reshape((img_nrows, img_ncols, 3))
    # Remove zero-center by mean pixel
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')
    return x


# Standalone feature-extraction helpers; they are not used by the
# L-BFGS loop below, which reads activations through outputs_dict instead.
def extract_features(x, content_layers, style_layers):
    contents = []
    styles = []
    for layer in model.layers[1:]:  # skip the InputLayer
        x = layer(x)
        if layer.name in content_layers:
            contents.append(x)
        if layer.name in style_layers:
            styles.append(x)
    return contents, styles

def get_contents(image_path):
    """Preprocess an image and extract its content and style features."""
    x = preprocess_image(image_path)
    y = extract_features(x, content_layers, style_layers)
    return x, y
total_variation_weight = 1
style_weight = 1000
content_weight = 1

width, height = load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)

base_image = K.constant(preprocess_image(base_image_path))
style_reference_image = K.constant(preprocess_image(style_reference_image_path))
combination_image = K.placeholder((1, img_nrows, img_ncols, 3))

input_tensor = K.concatenate([base_image,
                              style_reference_image,
                              combination_image], axis=0)

Matching scheme

  • Use the output of the first convolutional layer in each convolutional block to match style (the "style layers"), and the last convolutional layer of the fifth block (block5_conv4) to match content (the "content layer").
  • A content layer this close to the output is chosen so that the synthesized image does not retain too much fine detail from the content image.
model = vgg19.VGG19(input_tensor=input_tensor,
                    weights='imagenet', include_top=False)
print('Model loaded.')
model.summary()
Model loaded.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, None, None, 3)     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv4 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv4 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0         
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
content_layers = "block5_conv4"
style_layers = ["block1_conv1","block2_conv1","block3_conv1","block4_conv1","block5_conv1"]
outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])

Loss

  • Content loss: the squared error between the content-layer activations of the base image and of the generated image.
  • Style loss
    • Style can be viewed as the statistical distribution of pixel values in each channel. For example, to match the styles of two images, we could match their histograms over the three RGB channels. More generally, suppose a convolutional layer's output has shape c × h × w (channels, height, width). We can reshape it into a c × hw matrix and treat it as hw samples of a c-dimensional random variable; matching styles then means making the two c-dimensional distributions agree. A common way to match distributions is moment matching: make the means, covariances, and higher moments equal. For simplicity we only match second-order information, i.e. the (uncentered) covariance, which the code below computes via the Gram matrix; see the NumPy sketch after this list.
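As a quick illustration of this second-moment matching, here is a minimal NumPy sketch, separate from the Keras pipeline in this notebook (act_a, act_b, and gram are made-up names for the example):

import numpy as np

c, h, w = 64, 10, 10
act_a = np.random.rand(c, h, w)   # activations of one image at some layer
act_b = np.random.rand(c, h, w)   # activations of another image

def gram(act):
    f = act.reshape(c, h * w)     # hw samples of a c-dimensional variable
    return f @ f.T                # c × c matrix of uncentered second moments

# "style distance": squared Frobenius distance between the two Gram matrices
style_distance = np.sum((gram(act_a) - gram(act_b)) ** 2)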

Noise reduction

When we match using layers close to the network's output, the synthesized image often contains a lot of high-frequency noise, i.e. isolated pixels that are much brighter or darker than their neighbors. A common remedy is total variation denoising, which penalizes differences between adjacent pixels.
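
Concretely, the variant implemented below (see total_variation_loss) sums the squared differences with the horizontal and vertical neighbors and raises the result to the power 1.25, matching the code's K.pow(a + b, 1.25):

TV(x) = Σ_{i,j} ((x[i,j] − x[i+1,j])² + (x[i,j] − x[i,j+1])²)^1.25

summed over all pixel positions and channels.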

def gram_matrix(x):
    assert K.ndim(x) == 3
    if K.image_data_format() == 'channels_first':
        features = K.batch_flatten(x)
    else:
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
    gram = K.dot(features, K.transpose(features))
    return gram

# the "style loss" is designed to maintain
# the style of the reference image in the generated image.
# It is based on the Gram matrices (which capture style) of
# feature maps from the style reference image
# and from the generated image


def style_loss(style, combination):
    assert K.ndim(style) == 3
    assert K.ndim(combination) == 3
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return K.sum(K.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))

# an auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image


def content_loss(base, combination):
    return K.sum(K.square(combination - base))

# the 3rd loss function, total variation loss,
# designed to keep the generated image locally coherent


def total_variation_loss(x):
    assert K.ndim(x) == 4
    if K.image_data_format() == 'channels_first':
        a = K.square(
            x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1])
        b = K.square(
            x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:])
    else:
        a = K.square(
            x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        b = K.square(
            x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
    return K.sum(K.pow(a + b, 1.25))
loss = K.variable(0.0)
content_layers_features = outputs_dict[content_layers]
base_image_features = content_layers_features[0, :, :, :]
combination_features = content_layers_features[2, :, :, :]
loss += content_weight * content_loss(base_image_features,
                                      combination_features)

for layer_name in style_layers:
    layer_features = outputs_dict[layer_name]
    style_reference_features = layer_features[1, :, :, :]
    combination_features = layer_features[2, :, :, :]
    sl = style_loss(style_reference_features, combination_features)
    loss += (style_weight / len(style_layers)) * sl
loss += total_variation_weight * total_variation_loss(combination_image)
grads = K.gradients(loss, combination_image)
outputs = [loss]
if isinstance(grads, (list, tuple)):
    outputs += grads
else:
    outputs.append(grads)

f_outputs = K.function([combination_image], outputs)

def eval_loss_and_grads(x):
    # Evaluate the total loss and its gradient with respect to the
    # combination image in a single pass through the graph.
    if K.image_data_format() == 'channels_first':
        x = x.reshape((1, 3, img_nrows, img_ncols))
    else:
        x = x.reshape((1, img_nrows, img_ncols, 3))
    outs = f_outputs([x])
    loss_value = outs[0]
    if len(outs[1:]) == 1:
        grad_values = outs[1].flatten().astype('float64')
    else:
        grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values
class Evaluator(object):
    """Compute loss and gradients in one pass, then serve them to scipy.

    fmin_l_bfgs_b requests the loss and the gradient through two separate
    callbacks, but both come out of the same forward/backward pass, so we
    compute them together in loss() and cache the gradient for grads().
    """

    def __init__(self):
        self.loss_value = None
        self.grad_values = None

    def loss(self, x):
        assert self.loss_value is None
        loss_value, grad_values = eval_loss_and_grads(x)
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value

    def grads(self, x):
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values

evaluator = Evaluator()
iterations = 100
result_prefix = "chenzai"
x = preprocess_image(base_image_path)

for i in range(iterations):
    print('Start of iteration', i)
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(),
                                     fprime=evaluator.grads, maxfun=20)
    print('Current loss value:', min_val)
    # save current generated image
    img = deprocess_image(x.copy())
    fname = result_prefix + '_at_iteration_%d.png' % i
    save_img(fname, img)
    print('Image saved as', fname)
Start of iteration 0
Current loss value: 1674930200000.0
Image saved as chenzai_at_iteration_0.png
Start of iteration 1
Current loss value: 717802100000.0
Image saved as chenzai_at_iteration_1.png
Start of iteration 2
Current loss value: 454865100000.0
Image saved as chenzai_at_iteration_2.png
Start of iteration 3
Current loss value: 328861320000.0
Image saved as chenzai_at_iteration_3.png
Start of iteration 4
Current loss value: 268055990000.0
Image saved as chenzai_at_iteration_4.png
Start of iteration 5
Current loss value: 222694330000.0
Image saved as chenzai_at_iteration_5.png
Start of iteration 6
Current loss value: 191597500000.0
Image saved as chenzai_at_iteration_6.png
Start of iteration 7
Current loss value: 170507350000.0
Image saved as chenzai_at_iteration_7.png
Start of iteration 8
Current loss value: 151088100000.0
Image saved as chenzai_at_iteration_8.png
Start of iteration 9
Current loss value: 138113840000.0
Image saved as chenzai_at_iteration_9.png
Start of iteration 10
Current loss value: 129697645000.0
Image saved as chenzai_at_iteration_10.png
Start of iteration 11
Current loss value: 122296410000.0
Image saved as chenzai_at_iteration_11.png
Start of iteration 12
Current loss value: 113507475000.0
Image saved as chenzai_at_iteration_12.png
Start of iteration 13
Current loss value: 107962425000.0
Image saved as chenzai_at_iteration_13.png
Start of iteration 14
Current loss value: 97767780000.0
Image saved as chenzai_at_iteration_14.png
Start of iteration 15
Current loss value: 91393800000.0
Image saved as chenzai_at_iteration_15.png
Start of iteration 16
Current loss value: 87411384000.0
Image saved as chenzai_at_iteration_16.png
Start of iteration 17
Current loss value: 83964305000.0
Image saved as chenzai_at_iteration_17.png
Start of iteration 18
Current loss value: 80351320000.0
Image saved as chenzai_at_iteration_18.png
Start of iteration 19
plt.imshow(img)
[Figure: the stylized image from the last completed iteration.]