想试试用GAN生成动漫头像,于是先用简单GAN试一试。不过有一说一,GAN真的难训练。
封面《Butterfly Seeker》
前言
DCGAN(Deep Convolutional Generative Adversarial Networks)是GAN的一种变体,顾名思义就是在原本的GAN中引入了卷积神经网络。本次使用DCGAN来生成手写数字。
Step 1 导入包
1 | import tensorflow as tf |
2.4.0
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
step 2 检查数据集
本次使用mnist作为手写数字识别
1 | %matplotlib inline |
step 3 构建模型
权重初始化
权重初始化,根据DCGAN的论文,所有权重初始化为均值为0,方差为0.02。同时参考其他人的实现方法将BN层的初始化改了一下
1 | def init_weight(model:tf.keras.models.Model): |
生成器模型
生成器的模型我看到有多种实现方法,论文中的project and reshape有全连接和卷积两种方法,此处使用全连接来实现。反卷积有Conv2DTranspose
和Upsample2D+Conv2D
两种方法,此处按照论文使用前者。按照论文应该使用ReLU函数,此处替换为Leaky ReLU,alpha按照论文设置为0.2。
1 |
|
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 12544) 1266944
_________________________________________________________________
reshape (Reshape) (None, 7, 7, 256) 0
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 256) 1024
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 7, 7, 256) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 7, 7, 128) 819200
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128) 512
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64) 204800
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64) 256
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 32) 51200
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 32) 128
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 28, 28, 32) 0
_________________________________________________________________
conv2d (Conv2D) (None, 28, 28, 1) 800
_________________________________________________________________
activation (Activation) (None, 28, 28, 1) 0
=================================================================
Total params: 2,344,864
Trainable params: 2,343,904
Non-trainable params: 960
_________________________________________________________________
构造判决器
此处最后使用sigmoid激活函数,同时参考其他人代码加入了Dropout,不然后面训练的时候都没有对抗,两个模型直接摆烂(哭)。还需要注意的是D不能太强,不然G训练不起来
1 | def build_discriminator() -> tf.keras.Model: |
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 14, 14, 64) 1600
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
dropout (Dropout) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 7, 7, 128) 204800
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 7, 7, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 6272) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 6273
=================================================================
Total params: 212,673
Trainable params: 212,673
Non-trainable params: 0
_________________________________________________________________
生成联合模型
根据G和D生成联合模型,采用BCE LOSS。优化器根据论文Adam优化器,学习率按照论文选择0.0002,beta_1选择0.5。需要注意的是这里设置了D的trainable为False。而D编译后就算设置了trainable为False也可以训练,而联合模型编译时D设置了不能训练,因此联合模型只会训练G。
1 | def build_DCGAN(discriminator: tf.keras.Model, generator: tf.keras.Model) -> tf.keras.Model: |
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 100)] 0
_________________________________________________________________
sequential (Sequential) (None, 28, 28, 1) 2344864
_________________________________________________________________
sequential_1 (Sequential) (None, 1) 212673
=================================================================
Total params: 2,557,537
Trainable params: 2,343,904
Non-trainable params: 213,633
_________________________________________________________________
保存和显示图片
1 | if not os.path.exists('./images'): |
step 3 训练模型
根据论文设置batch_size为128,输入范围转换为[-1,1]
1 | d_loss_list = [] |
469it [01:57, 4.01it/s, step=468, d_loss=0.4450, acc=0.9115, g_loss=0.1112]
469it [01:53, 4.12it/s, step=468, d_loss=0.5188, acc=0.8750, g_loss=0.8756]
469it [01:53, 4.13it/s, step=468, d_loss=0.6582, acc=0.6354, g_loss=1.1004]
469it [01:53, 4.12it/s, step=468, d_loss=0.6392, acc=0.6667, g_loss=1.0247]
469it [01:53, 4.15it/s, step=468, d_loss=0.6302, acc=0.6510, g_loss=0.9042]
469it [01:53, 4.14it/s, step=468, d_loss=0.4286, acc=0.8490, g_loss=1.3098]
469it [01:53, 4.13it/s, step=468, d_loss=0.5711, acc=0.7552, g_loss=1.0918]
469it [01:54, 4.11it/s, step=468, d_loss=0.1656, acc=0.9531, g_loss=0.0233]
469it [01:53, 4.13it/s, step=468, d_loss=0.4908, acc=0.8177, g_loss=1.5465]
469it [01:53, 4.13it/s, step=468, d_loss=0.4977, acc=0.7656, g_loss=1.1850]
469it [01:54, 4.11it/s, step=468, d_loss=0.4030, acc=0.8646, g_loss=1.3088]
469it [01:53, 4.12it/s, step=468, d_loss=0.4622, acc=0.8177, g_loss=0.6755]
469it [01:53, 4.13it/s, step=468, d_loss=0.2610, acc=0.9323, g_loss=0.5214]
469it [01:53, 4.14it/s, step=468, d_loss=0.3714, acc=0.8646, g_loss=1.7093]
469it [01:52, 4.17it/s, step=468, d_loss=0.6514, acc=0.6458, g_loss=1.2826]
469it [01:52, 4.17it/s, step=468, d_loss=0.5218, acc=0.7448, g_loss=1.6180]
469it [01:52, 4.16it/s, step=468, d_loss=0.3786, acc=0.8594, g_loss=1.4780]
469it [01:52, 4.16it/s, step=468, d_loss=0.2106, acc=0.9427, g_loss=2.5880]
469it [01:52, 4.15it/s, step=468, d_loss=0.6040, acc=0.7240, g_loss=1.7180]
469it [01:52, 4.16it/s, step=468, d_loss=0.3441, acc=0.8490, g_loss=1.5291]
469it [01:52, 4.16it/s, step=468, d_loss=0.3020, acc=0.8958, g_loss=2.0768]
469it [01:52, 4.15it/s, step=468, d_loss=0.4162, acc=0.8333, g_loss=1.6850]
469it [01:52, 4.17it/s, step=468, d_loss=0.3279, acc=0.8698, g_loss=1.3302]
469it [01:53, 4.14it/s, step=468, d_loss=0.2163, acc=0.9271, g_loss=2.9825]
469it [01:53, 4.13it/s, step=468, d_loss=0.2229, acc=0.9167, g_loss=2.2349]
469it [01:52, 4.15it/s, step=468, d_loss=0.2062, acc=0.9375, g_loss=1.6529]
469it [01:53, 4.14it/s, step=468, d_loss=0.1285, acc=0.9635, g_loss=2.8505]
469it [01:53, 4.14it/s, step=468, d_loss=0.3163, acc=0.8750, g_loss=2.8405]
469it [01:53, 4.15it/s, step=468, d_loss=0.3621, acc=0.8698, g_loss=2.1161]
469it [01:53, 4.12it/s, step=468, d_loss=0.2909, acc=0.9062, g_loss=2.6993]
469it [01:53, 4.15it/s, step=468, d_loss=0.2011, acc=0.9271, g_loss=2.5064]
469it [01:53, 4.12it/s, step=468, d_loss=0.2564, acc=0.8854, g_loss=2.4029]
469it [01:53, 4.15it/s, step=468, d_loss=0.1781, acc=0.9479, g_loss=1.5723]
469it [01:53, 4.13it/s, step=468, d_loss=0.1547, acc=0.9583, g_loss=1.3841]
469it [01:53, 4.14it/s, step=468, d_loss=0.1384, acc=0.9583, g_loss=1.9002]
469it [01:53, 4.15it/s, step=468, d_loss=0.4499, acc=0.7865, g_loss=2.3525]
469it [01:54, 4.08it/s, step=468, d_loss=0.0773, acc=0.9844, g_loss=0.5579]
469it [01:54, 4.11it/s, step=468, d_loss=0.2629, acc=0.8958, g_loss=0.5309]
469it [02:00, 3.88it/s, step=468, d_loss=0.2257, acc=0.9167, g_loss=2.6041]
469it [01:59, 3.91it/s, step=468, d_loss=0.2283, acc=0.9271, g_loss=1.5944]
469it [02:08, 3.65it/s, step=468, d_loss=0.2495, acc=0.9115, g_loss=2.7131]
469it [02:07, 3.67it/s, step=468, d_loss=0.1437, acc=0.9479, g_loss=0.9581]
469it [02:08, 3.64it/s, step=468, d_loss=0.5403, acc=0.7865, g_loss=2.7837]
469it [02:08, 3.66it/s, step=468, d_loss=0.1985, acc=0.9323, g_loss=2.7737]
469it [02:09, 3.62it/s, step=468, d_loss=0.1488, acc=0.9427, g_loss=0.7168]
469it [02:07, 3.68it/s, step=468, d_loss=0.1167, acc=0.9583, g_loss=0.8895]
469it [02:09, 3.63it/s, step=468, d_loss=0.1243, acc=0.9323, g_loss=2.8780]
469it [02:10, 3.59it/s, step=468, d_loss=0.2008, acc=0.9115, g_loss=2.6001]
469it [02:10, 3.59it/s, step=468, d_loss=0.1474, acc=0.9479, g_loss=1.5862]
469it [02:08, 3.64it/s, step=468, d_loss=0.2088, acc=0.9062, g_loss=3.0232]
step 4 查看结果
查看d_loss、g_loss。可以看到两个loss都在震荡,感觉是还没训练完
1 | %matplotlib inline |
生成动图
将训练图中各个阶段的图片组合起来生成gif图片。虽然效果不是很好,但仍然可以看到数字的形状,没有晶体图
1 | anim_file = 'dcgan.gif' |
保存模型权重
1 | generator.save_weights('generator',True) |
参考文献
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
DCGANs (Deep Convolutional Generative Adversarial Networks)