cycleGAN
模型属于 GAN
模型的一个变形,在很多情况下,我们无法获得或者很难获得成对的训练数据,cycleGAN
要解决的问题是: seek an algorithm that can learn to translate between domains without paired input-output examples
,如下图:
其中,cycleGAN
的网络结构如下图所示:
它包括两个生成器 G、F
和两个判别器 Dx、Dy
。生成器 G
对输入的图片x
进行变换生成图片 $\hat y$ ,生成器 F
对输入的照片 y
进行变换生成照片 $\hat x$ 。判别器 Dx
对生成器 G
生成的图片 $\hat y$ 和真实的图片 y
进行判别,分辨真假;判别器 Dy
对生成器 F
生成的图片 $\hat x$ 和真实的图片 x
进行判别,分辨真伪。
本文所复现的论文地址:https://arxiv.org/abs/1703.10593
1、数据集处理
所用数据集地址:https://people.eecs.berkeley.edu/%7Etaesung_park/CycleGAN/datasets/,(horse<-->zebra)-->
其中 trainA
中是 horse
照片,trainB
中是 zebra
照片,图片大小为 256*256*3
,其处理过程和前几篇博文中的一样,包括放大、裁剪、随机镜像、归一化等操作,具体代码如下:
1 | def load_image(image_path): |
2、定义生成器
这里生成器和 pix2pix
一样,同样采用的是 u-net
结构,论文中说是用 instance normalization
代替 batch normalization
,因为我们这里设置的 batch size = 1
,所以,就一点都没有改动,直接搬过来,具体代码如下:
1 | def Generator(): |
3、定义判别器
判别器的结构也是和 pix2pix
中的结构一样(patchGAN
),只有一点不同,就是这里的判别器的输入是一张图片,不再是之前的一张图片+条件
。 代码如下:
1 | def Discriminator(): |
4、定义损失函数
对于生成器来说,要满足以下几个要求:
生成的图片不能够被判别器认出来,即生成图片经过判别器输出的
30*30*1
的矩阵和全1
矩阵的差距。对生成器
G
来说,输入horse
要输出zebra
,但是,输入zebra
还要输出zebra
,即same_loss
;对F
来说也一样。循环一致性损失,即
x
–F(G(x))
– $\hat x$ ,两者的差距要尽可能的小,即cycle_loss
。同样的,y
–G(F(y))
– $\hat y$ 。
具体代码定义如下:
1 | def generator_loss(disc_generated_output): |
对判别器来说,满足能判别真假图就可以了,真实图片的判别输出与全 1
比较,生成图片的判别输出与全 0
比较。代码定义如下:
1 | def discriminator_loss(disc_real_output, disc_generated_output): |
5、定义优化器及训练过程
训练过程就是: 生成图片—计算损失—计算梯度—更新参数 的过程。
1 | Gx = Generator() |
6、开始训练
训练过程无需成对的训练数据,每次随机从两个数据集 trainA
和 trainB
中挑选两张图片,执行 train_step
即可。
1 | def train(): |
7、效果展示
在训练过程中我们 save
了7个阶段性模型参数,我们使用test
数据集依次查看一下不同阶段的模型效果。
1 | def show_image(step): |