Pix2Pix Model Implementation in TensorFlow 2.0

This post uses a cGAN to implement image-to-image translation. Because a cGAN can guide image generation by conditioning on extra information, this translation technique works well for tasks such as image colorization. For details of the model, see this paper.

The pix2pix model uses the input image (x) as the condition (there is no separate noise input z here; the generator's input x can be regarded as playing the role of z) and learns the mapping from the input image to the output image, so as to produce the desired output. The process is illustrated below:

In traditional image-to-image conversion, each specific problem is tackled with a purpose-built algorithm, yet at heart all of these methods predict pixels from pixels. The goal of pix2pix is a general-purpose architecture for image translation, so that we no longer have to design a new loss function for every task. The model does have drawbacks: pix2pix needs paired images (x and y) for training, and what it learns is a one-to-one mapping from x to y. In other words, pix2pix reconstructs the ground truth: an input contour is encoded by the U-Net into a latent vector and then decoded into the real image. Such a one-to-one mapping has a rather limited range of application; when an input differs greatly from the training data, the generated result may well be meaningless, so the training set should cover as many input types as possible. Take contour-to-clothing translation as an example: after training the model on our own dataset, feeding in a contour similar to those in the training set gives the following result:

When we feed in a contour that does not appear in the training set, we get the following:

We can see that the shape of the garment is still preserved, but the colors of the generated image are unsatisfactory.

As for the requirement of paired training data, the next post will cover how CycleGAN removes it. For now, let's get straight to the code and reproduce the pix2pix model described in the paper.

1. Preprocessing the Dataset Images

For simplicity, we use the CMP Facade Database from the paper: given the label/contour map of a building, generate a photo of the actual facade. Each sample in the dataset is a single 256 × 512 image formed by placing x and y side by side: the left half is the real facade photo and the right half is the corresponding contour map, as shown below:

The first step is to split each image into its left and right halves; the function is implemented as follows:

def load_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image)
    width = tf.shape(image)[1] // 2
    real_image = tf.cast(image[:, :width, :], tf.float32)
    input_image = tf.cast(image[:, width:, :], tf.float32)
    return input_image, real_image

After splitting, each file yields an input_image and a real_image. Next, to make the trained model more robust, every sample fed in during training is randomly jittered (the 256×256 images are first enlarged to 286×286 and then randomly cropped back to 256×256) and randomly mirrored. The paper describes this as:

1. resize an image to bigger height and width
2. randomly crop to the target size
3. randomly flip the image horizontally

The corresponding code is as follows:

def resize_image(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return input_image, real_image

# Stack input_image and real_image before random_crop so that both are cropped
# from the same region and stay aligned.
def random_crop(input_image, real_image):
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    return cropped_image[0], cropped_image[1]

def random_jitter(input_image, real_image):
    input_image, real_image = resize_image(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)
    if tf.random.uniform(()) > 0.5:  # randomly mirror the image pair horizontally
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)
    return input_image, real_image

Finally, normalize the pixel values to the range [-1, 1]:

def normalize_image(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1
    return input_image, real_image

The overall loading functions are shown below: one for training samples (with random cropping and mirroring) and one for test samples (with no extra processing).

def load_train_image(image_file):
    input_image, real_image = load_image(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize_image(input_image, real_image)
    return input_image, real_image

def load_test_image(image_file):
    input_image, real_image = load_image(image_file)
    input_image, real_image = resize_image(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize_image(input_image, real_image)
    return input_image, real_image
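
For reference, the same loaders can also be wrapped in a tf.data input pipeline instead of looping over files by hand as done in the training script later in this post. A minimal sketch (the *.jpg glob pattern, BATCH_SIZE and BUFFER_SIZE are assumptions, not part of the original code):

# Optional tf.data pipeline built on load_train_image / load_test_image (sketch).
BUFFER_SIZE = 400   # assumed shuffle buffer size
BATCH_SIZE = 1      # a batch size of 1 is a common choice for pix2pix

train_dataset = tf.data.Dataset.list_files("dataset/facades/train/*.jpg")
train_dataset = train_dataset.map(load_train_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

test_dataset = tf.data.Dataset.list_files("dataset/facades/test/*.jpg")
test_dataset = test_dataset.map(load_test_image).batch(BATCH_SIZE)
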
2. Building the Generator

The generator is a U-Net; its architecture is shown in the figure below:

It consists of a downsampling encoder and an upsampling decoder, with skip connections that concatenate encoder and decoder feature maps along the channel dimension. This differs slightly from an FCN (fully convolutional network): although both can be used for semantic segmentation, an FCN merges skip features by element-wise pixel addition.
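
The difference between the two kinds of skip connection is easy to see on a pair of feature maps: concatenation stacks the channels, while element-wise addition keeps the channel count unchanged. A small illustrative sketch:

import tensorflow as tf

decoder_feat = tf.random.normal([1, 64, 64, 128])  # feature map from the decoder path
encoder_feat = tf.random.normal([1, 64, 64, 128])  # feature map from the matching encoder layer

concat_skip = tf.concat([decoder_feat, encoder_feat], axis=3)  # U-Net style: (1, 64, 64, 256)
add_skip = decoder_feat + encoder_feat                         # FCN style:   (1, 64, 64, 128)
print(concat_skip.shape, add_skip.shape)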

Each block of the encoder (downsampling path) has the structure Conv --> Batchnorm --> Leaky ReLU, implemented as follows:

def downsample(filters, kernel_size, apply_batchnorm=True):
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, kernel_size, strides=2, padding='same',
                                      kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

Each block of the decoder (upsampling path) has the structure Transposed Conv --> Batchnorm --> Dropout (applied to the first 3 blocks) --> ReLU, implemented as follows:

def upsample(filters, kernel_size, apply_dropout=False):
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, kernel_size, strides=2, padding='same',
                                               kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(rate=0.5))
    result.add(tf.keras.layers.ReLU())
    return result
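
A quick shape check on these blocks (a sanity-check sketch; the batch of random data is of course arbitrary):

# One downsample block halves the spatial resolution, one upsample block doubles it.
dummy = tf.random.normal([1, 256, 256, 3])
down = downsample(64, (4, 4))(dummy)  # -> (1, 128, 128, 64)
up = upsample(3, (4, 4))(down)        # -> (1, 256, 256, 3)
print(down.shape, up.shape)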

The encoder and decoder blocks above are built with the Keras Sequential API, while the generator that follows uses the Keras functional API. With the functional API the input shape is declared explicitly via an Input layer; the Sequential blocks here simply infer their input shape the first time they are called.

With the Encoder and Decoder blocks defined, we can now assemble the full U-Net Generator.

def Generator():
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))
    down_stack = [
        downsample(64, (4, 4), apply_batchnorm=False),  # [batch, 128, 128, 64]
        downsample(128, (4, 4)),   # [batch, 64, 64, 128]
        downsample(256, (4, 4)),   # [batch, 32, 32, 256]
        downsample(512, (4, 4)),   # [batch, 16, 16, 512]
        downsample(512, (4, 4)),   # [batch, 8, 8, 512]
        downsample(512, (4, 4)),   # [batch, 4, 4, 512]
        downsample(512, (4, 4)),   # [batch, 2, 2, 512]
        downsample(512, (4, 4)),   # [batch, 1, 1, 512]
    ]
    up_stack = [
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 2, 2, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 4, 4, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 8, 8, 512]
        upsample(512, (4, 4)),   # [batch, 16, 16, 512]
        upsample(256, (4, 4)),   # [batch, 32, 32, 256]
        upsample(128, (4, 4)),   # [batch, 64, 64, 128]
        upsample(64, (4, 4)),    # [batch, 128, 128, 64]
    ]
    last_layer = tf.keras.layers.Conv2DTranspose(filters=OUTPUT_CHANNELS, kernel_size=(4, 4), strides=2, padding='same',
                                                 kernel_initializer=tf.random_normal_initializer(0., 0.02), activation='tanh')  # [batch, 256, 256, 3]
    x = inputs  # inputs is a tensor with shape [256, 256, 3]
    down_outputs = []
    for down_layer in down_stack:
        x = down_layer(x)  # invokes the block's call method
        down_outputs.append(x)
    down_outputs = reversed(down_outputs[:-1])  # 2 --> 4 --> 8 --> 16 --> 32 --> 64 --> 128, 7 skip connections in total
    for up_layer, down_output in zip(up_stack, down_outputs):
        x = up_layer(x)
        x = tf.concat([x, down_output], axis=3)
    x = last_layer(x)  # [batch, 256, 256, 3]
    return tf.keras.Model(inputs=inputs, outputs=x)

The x = tf.concat([x, down_output], axis=3) call inside the upsampling loop implements the skip connections, concatenating the channels of the two feature maps. As you can see, the generator's input size is 256×256×3 and its output size is also 256×256×3. The full generator architecture is shown in the figure below:
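
A quick way to verify these shapes is to instantiate the generator and run a dummy batch through it (a sanity-check sketch using the Generator defined above):

gen = Generator()
dummy_input = tf.random.normal([1, 256, 256, 3])
print(gen(dummy_input, training=False).shape)  # (1, 256, 256, 3)
# gen.summary() prints the full layer-by-layer structure.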

3. Building the Discriminator

The discriminator takes two inputs: the condition (x), i.e. the input_image, and either a generated_image or the real target real_image. For the input_image + generated_image pair the discriminator should output all zeros; conversely, for the input_image + real_image pair it should output all ones.

Here the discriminator uses the PatchGAN architecture: instead of outputting a single scalar meaning real or fake, it outputs a 30×30 matrix in which each element judges the authenticity of its own receptive field in the input, so the discriminator produces a much finer-grained verdict.

For a deeper understanding of PatchGAN, see the author's answer, which explains in detail what PatchGAN is and how the receptive field of each element of the output matrix is computed.
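
For reference, each element of the 30×30 output corresponds to a 70×70 patch of the input. This follows from the standard receptive-field recursion RF_in = (RF_out - 1) * stride + kernel applied over the discriminator's conv layers; a small sketch:

# (kernel_size, stride) of each conv layer of the discriminator, from input to output
def receptive_field(layers):
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

disc_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(disc_layers))  # 70 -> each output element sees a 70x70 input patch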

def Discriminator():
    input_image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    target_image = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')
    x = tf.concat([input_image, target_image], axis=3)  # [batch, 256, 256, 6]

    down_sample1 = downsample(filters=64, kernel_size=(4, 4), apply_batchnorm=False)(x)  # [batch, 128, 128, 64]
    down_sample2 = downsample(128, (4, 4))(down_sample1)  # [batch, 64, 64, 128]
    down_sample3 = downsample(256, (4, 4))(down_sample2)  # [batch, 32, 32, 256]

    zero_pad1 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(down_sample3)  # [batch, 34, 34, 256]
    conv = tf.keras.layers.Conv2D(filters=512, kernel_size=(4, 4), strides=1, padding='valid',
                                  kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False)(zero_pad1)  # [batch, 31, 31, 512]
    batchnorm = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm)

    zero_pad2 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(leaky_relu)  # [batch, 33, 33, 512]
    output = tf.keras.layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1,
                                    kernel_initializer=tf.random_normal_initializer(0., 0.02))(zero_pad2)  # [batch, 30, 30, 1]
    return tf.keras.Model(inputs=[input_image, target_image], outputs=output)

The overall discriminator architecture is shown in the figure below:

4. Defining the Generator and Discriminator Loss Functions

At its core, pix2pix is still a cGAN, so its overall objective closely follows the cGAN objective, defined as follows:
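
In the notation of the paper (x is the input image, y the target image, and z the noise, realized here via dropout), the combined objective is:

$$G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)$$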

cGAN 损失函数的基础上,加上了生成器 GL1 损失,两部分损失定义如下:

The implementation of both losses is straightforward:

def generator_loss(disc_generated_output, generated_image, real_image):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(generated_image - real_image))
    return gan_loss + 100 * l1_loss  # L1 term weighted by lambda = 100

def discriminator_loss(disc_real_output, disc_generated_output):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_loss + generated_loss
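
As a quick sanity check of these two functions (a sketch with made-up logits; the 30×30 shape matches the PatchGAN output above):

# A discriminator that is completely fooled outputs large positive logits for the generated pair.
fake_logits = tf.ones([1, 30, 30, 1]) * 10.0
real_logits = tf.ones([1, 30, 30, 1]) * 10.0
gen_img = tf.zeros([1, 256, 256, 3])
tgt_img = tf.zeros([1, 256, 256, 3])
print(generator_loss(fake_logits, gen_img, tgt_img))   # ~0: GAN term ~0, L1 term 0
print(discriminator_loss(real_logits, fake_logits))    # ~10: real loss ~0, generated loss large
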
5. Defining the Optimizers and the Training Step

In each training step we feed in one pair: an input_image and its corresponding real_image. The generator produces a generated_image from the input_image; the discriminator then evaluates both input_image + generated_image and input_image + real_image, producing disc_generate_output and disc_real_output. From these we compute the losses, then the gradients, and finally update the model parameters.

generator = Generator()
discriminator = Discriminator()
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

@tf.function
def train_step(input_image, real_image):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_image = generator(inputs=input_image, training=True)

        disc_generate_output = discriminator(inputs=[input_image, generated_image], training=True)
        disc_real_output = discriminator(inputs=[input_image, real_image], training=True)

        gen_loss = generator_loss(disc_generate_output, generated_image, real_image)
        disc_loss = discriminator_loss(disc_real_output, disc_generate_output)

    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
6. Training

Define a checkpoint, save the model every 20 epochs during training, and after every epoch display the images produced by the generator.

def train(train_dataset_dir):
    files = os.listdir(train_dataset_dir)
    for file in files:
        image_path = os.path.join(train_dataset_dir, file)
        input_image, real_image = load_train_image(image_path)
        train_step(input_image[tf.newaxis, ...], real_image[tf.newaxis, ...])

checkpoint = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                                 generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer)

if __name__ == "__main__":
    test_images_dir = "dataset/facades/test/"
    test_image_names = os.listdir(test_images_dir)  # returns a list of file names
    for epoch in range(EPOCHs):
        start = time.time()
        train("dataset/facades/train")
        print("Epoch", epoch, ": ", time.time() - start)
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix='training_checkpoints/pix2pix')
        # pick a random test image to visualize the current training result
        test_image_path = os.path.join(test_images_dir, random.choice(test_image_names))
        test_input_image, test_real_image = load_test_image(test_image_path)
        show_image(generator, test_input_image, test_real_image, epoch)
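
If training is interrupted, the saved checkpoint can be restored before resuming or before running inference. A minimal sketch using the checkpoint defined above:

# restore the most recent checkpoint from the training_checkpoints directory
checkpoint.restore(tf.train.latest_checkpoint('training_checkpoints'))
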
7. Results

Each time, the input image, the real image, and the generated image are drawn side by side for comparison; the code is as follows:

def show_image(generator_model, input_image, real_image, epoch):
    generated_image = generator_model(input_image[tf.newaxis, ...], training=False)[0]
    plt.figure(figsize=(15, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(input_image * 0.5 + 0.5)  # rescale pixel values to [0, 1]
    plt.title("Input Image")
    plt.axis("off")
    plt.subplot(1, 3, 2)
    plt.imshow(real_image * 0.5 + 0.5)
    plt.title("Real Image")
    plt.axis("off")
    plt.subplot(1, 3, 3)
    plt.imshow(generated_image * 0.5 + 0.5)
    plt.title("Generated Image")
    plt.axis("off")
    plt.savefig("pix2pix_image_save/" + str(epoch) + ".png")
    plt.show()

As the figures above show, the training results are quite good.

The complete code is as follows:

# _*_ coding: utf-8 _*_
"""
@author: Jibao Wang
@time: 2019/12/27 19:32
"""

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import imageio
import os, time, random


IMG_WIDTH = 256
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3
EPOCHs = 150

def load_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image)
    width = tf.shape(image)[1] // 2
    real_image = tf.cast(image[:, :width, :], tf.float32)
    input_image = tf.cast(image[:, width:, :], tf.float32)
    return input_image, real_image

def resize_image(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return input_image, real_image

def random_crop(input_image, real_image):
    # stack the pair so that random_crop cuts the same region from both images
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    return cropped_image[0], cropped_image[1]

def random_jitter(input_image, real_image):
    input_image, real_image = resize_image(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)
    if tf.random.uniform(()) > 0.5:  # randomly mirror the image pair horizontally
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)
    return input_image, real_image

# normalize pixel values to the range [-1, 1]
def normalize_image(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1
    return input_image, real_image

def load_train_image(image_file):
    input_image, real_image = load_image(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize_image(input_image, real_image)
    return input_image, real_image

def load_test_image(image_file):
    input_image, real_image = load_image(image_file)
    input_image, real_image = resize_image(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize_image(input_image, real_image)
    return input_image, real_image

# building blocks for the generator and discriminator
def downsample(filters, kernel_size, apply_batchnorm=True):
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, kernel_size, strides=2, padding='same',
                                      kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

def upsample(filters, kernel_size, apply_dropout=False):
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, kernel_size, strides=2, padding='same',
                                               kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(rate=0.5))
    result.add(tf.keras.layers.ReLU())
    return result

def Generator():
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))
    down_stack = [
        downsample(64, (4, 4), apply_batchnorm=False),  # [batch, 128, 128, 64]
        downsample(128, (4, 4)),   # [batch, 64, 64, 128]
        downsample(256, (4, 4)),   # [batch, 32, 32, 256]
        downsample(512, (4, 4)),   # [batch, 16, 16, 512]
        downsample(512, (4, 4)),   # [batch, 8, 8, 512]
        downsample(512, (4, 4)),   # [batch, 4, 4, 512]
        downsample(512, (4, 4)),   # [batch, 2, 2, 512]
        downsample(512, (4, 4)),   # [batch, 1, 1, 512]
    ]
    up_stack = [
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 2, 2, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 4, 4, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 8, 8, 512]
        upsample(512, (4, 4)),   # [batch, 16, 16, 512]
        upsample(256, (4, 4)),   # [batch, 32, 32, 256]
        upsample(128, (4, 4)),   # [batch, 64, 64, 128]
        upsample(64, (4, 4)),    # [batch, 128, 128, 64]
    ]
    last_layer = tf.keras.layers.Conv2DTranspose(filters=OUTPUT_CHANNELS, kernel_size=(4, 4), strides=2, padding='same',
                                                 kernel_initializer=tf.random_normal_initializer(0., 0.02), activation='tanh')  # [batch, 256, 256, 3]
    x = inputs  # inputs is a tensor with shape [256, 256, 3]
    down_outputs = []
    for down_layer in down_stack:
        x = down_layer(x)
        down_outputs.append(x)
    down_outputs = reversed(down_outputs[:-1])  # 2 --> 4 --> 8 --> 16 --> 32 --> 64 --> 128, 7 skip connections
    for up_layer, down_output in zip(up_stack, down_outputs):
        x = up_layer(x)
        x = tf.concat([x, down_output], axis=3)
    x = last_layer(x)  # [batch, 256, 256, 3]
    return tf.keras.Model(inputs=inputs, outputs=x)

def Discriminator():
    input_image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    target_image = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')
    x = tf.concat([input_image, target_image], axis=3)  # [batch, 256, 256, 6]

    down_sample1 = downsample(filters=64, kernel_size=(4, 4), apply_batchnorm=False)(x)  # [batch, 128, 128, 64]
    down_sample2 = downsample(128, (4, 4))(down_sample1)  # [batch, 64, 64, 128]
    down_sample3 = downsample(256, (4, 4))(down_sample2)  # [batch, 32, 32, 256]

    zero_pad1 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(down_sample3)  # [batch, 34, 34, 256]
    conv = tf.keras.layers.Conv2D(filters=512, kernel_size=(4, 4), strides=1, padding='valid',
                                  kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False)(zero_pad1)  # [batch, 31, 31, 512]
    batchnorm = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm)

    zero_pad2 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(leaky_relu)  # [batch, 33, 33, 512]
    output = tf.keras.layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1,
                                    kernel_initializer=tf.random_normal_initializer(0., 0.02))(zero_pad2)  # [batch, 30, 30, 1]
    return tf.keras.Model(inputs=[input_image, target_image], outputs=output)

def generator_loss(disc_generated_output, generated_image, real_image):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(generated_image - real_image))
    return gan_loss + 100 * l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_loss + generated_loss

generator = Generator()
discriminator = Discriminator()
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
checkpoint = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                                 generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer)

@tf.function
def train_step(input_image, real_image):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_image = generator(inputs=input_image, training=True)

        disc_generate_output = discriminator(inputs=[input_image, generated_image], training=True)
        disc_real_output = discriminator(inputs=[input_image, real_image], training=True)

        gen_loss = generator_loss(disc_generate_output, generated_image, real_image)
        disc_loss = discriminator_loss(disc_real_output, disc_generate_output)

    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))


def train(train_dataset_dir):
    files = os.listdir(train_dataset_dir)
    for file in files:
        image_path = os.path.join(train_dataset_dir, file)
        input_image, real_image = load_train_image(image_path)
        train_step(input_image[tf.newaxis, ...], real_image[tf.newaxis, ...])

def show_image(generator_model, input_image, real_image, epoch):
    generated_image = generator_model(input_image[tf.newaxis, ...], training=False)[0]
    plt.figure(figsize=(15, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(input_image * 0.5 + 0.5)  # rescale pixel values to [0, 1]
    plt.title("Input Image")
    plt.axis("off")
    plt.subplot(1, 3, 2)
    plt.imshow(real_image * 0.5 + 0.5)
    plt.title("Real Image")
    plt.axis("off")
    plt.subplot(1, 3, 3)
    plt.imshow(generated_image * 0.5 + 0.5)
    plt.title("Generated Image")
    plt.axis("off")
    plt.savefig("pix2pix_image_save/" + str(epoch) + ".png")
    plt.show()

# build a GIF of the training progress with imageio
def gif_animation_generate():
    gif_name = "pix2pix.gif"
    filenames = []
    for i in range(1, 151):
        filenames.append("pix2pix_image_save/" + str(i) + ".png")
    frames = []
    for filename in filenames:
        im = imageio.imread(filename)
        frames.append(im)
    imageio.mimsave(gif_name, frames, "GIF", duration=0.1)


if __name__ == "__main__":
    test_images_dir = "dataset/facades/test/"
    test_image_names = os.listdir(test_images_dir)  # returns a list of file names
    for epoch in range(EPOCHs):
        start = time.time()
        train("dataset/facades/train")  # train on both the training and validation sets
        train("dataset/facades/val")
        print("Epoch", epoch, ": ", time.time() - start)

        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix='training_checkpoints/pix2pix')

        # pick a random test image to visualize the current training result
        test_image_path = os.path.join(test_images_dir, random.choice(test_image_names))
        test_input_image, test_real_image = load_test_image(test_image_path)
        show_image(generator, test_input_image, test_real_image, epoch)
        # also track one fixed test image so progress is comparable across epochs
        processing_show_input_image, processing_show_real_image = load_test_image("dataset/facades/test/1.jpg")
        show_image(generator, processing_show_input_image, processing_show_real_image, epoch + 150)
    gif_animation_generate()