cycleGAN 模型属于 GAN 模型的一个变形,在很多情况下,我们无法获得或者很难获得成对的训练数据,cycleGAN 要解决的问题是: seek an algorithm that can learn to translate between domains without paired input-output examples,如下图:

其中,cycleGAN 的网络结构如下图所示:

它包括两个生成器 G、F和两个判别器 Dx、Dy。生成器 G 对输入的图片x 进行变换生成图片 $\hat y$ ,生成器 F 对输入的照片 y 进行变换生成照片 $\hat x$ 。判别器 Dx 对生成器 G生成的图片 $\hat y$ 和真实的图片 y 进行判别,分辨真假;判别器 Dy 对生成器 F 生成的图片 $\hat x$ 和真实的图片 x 进行判别,分辨真伪。
本文所复现的论文地址:https://arxiv.org/abs/1703.10593
1、数据集处理
所用数据集地址:https://people.eecs.berkeley.edu/%7Etaesung_park/CycleGAN/datasets/,(horse<-->zebra)-->
其中 trainA 中是 horse 照片,trainB 中是 zebra 照片,图片大小为 256*256*3,其处理过程和前几篇博文中的一样,包括放大、裁剪、随机镜像、归一化等操作,具体代码如下:
1 | def load_image(image_path): |
2、定义生成器
这里生成器和 pix2pix 一样,同样采用的是 u-net 结构,论文中说是用 instance normalization 代替 batch normalization,因为我们这里设置的 batch size = 1 ,所以,就一点都没有改动,直接搬过来,具体代码如下:
1 | def Generator(): |
3、定义判别器
判别器的结构也是和 pix2pix 中的结构一样(patchGAN),只有一点不同,就是这里的判别器的输入是一张图片,不再是之前的一张图片+条件。 代码如下:
1 | def Discriminator(): |
4、定义损失函数
对于生成器来说,要满足以下几个要求:
生成的图片不能够被判别器认出来,即生成图片经过判别器输出的
30*30*1的矩阵和全1矩阵的差距。对生成器
G来说,输入horse要输出zebra,但是,输入zebra还要输出zebra,即same_loss;对F来说也一样。循环一致性损失,即
x–F(G(x))– $\hat x$ ,两者的差距要尽可能的小,即cycle_loss。同样的,y–G(F(y))– $\hat y$ 。
具体代码定义如下:
1 | def generator_loss(disc_generated_output): |
对判别器来说,满足能判别真假图就可以了,真实图片的判别输出与全 1 比较,生成图片的判别输出与全 0 比较。代码定义如下:
1 | def discriminator_loss(disc_real_output, disc_generated_output): |
5、定义优化器及训练过程
训练过程就是: 生成图片—计算损失—计算梯度—更新参数 的过程。
1 | Gx = Generator() |
6、开始训练
训练过程无需成对的训练数据,每次随机从两个数据集 trainA 和 trainB 中挑选两张图片,执行 train_step 即可。
1 | def train(): |
7、效果展示
在训练过程中我们 save了7个阶段性模型参数,我们使用test 数据集依次查看一下不同阶段的模型效果。
1 | def show_image(step): |