I want to try generating anime avatars with a GAN, so I'm starting with a simple GAN first. Honestly though, GANs really are hard to train.
Cover: "Butterfly Seeker"

Preface

DCGAN (Deep Convolutional Generative Adversarial Networks) is a variant of the GAN that, as the name suggests, introduces convolutional neural networks into the original GAN. This post uses a DCGAN to generate handwritten digits.
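As a one-line refresher, DCGAN keeps the original GAN training objective from the paper cited below: the discriminator D learns to separate real data from generated samples, while the generator G learns to fool it:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]$$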

Step 1: Import Packages

import tensorflow as tf
import tensorflow.keras.layers as layers
import numpy as np
from tqdm import tqdm
import os
print(tf.__version__)
print(tf.config.list_physical_devices())
2.4.0
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Step 2: Inspect the Dataset

We use the MNIST handwritten-digit dataset.

%matplotlib inline
%config InlineBackend.figure_format = 'svg'


def load_dataset(batch_size=128):
    mnist = tf.keras.datasets.mnist
    # load dataset
    (x_train, _), _ = mnist.load_data()
    # [0,255] -> [-1,1]
    x_train = x_train/127.5 - 1
    # (batch_size,w,h) -> (batch,w,h,channel)
    x_train = np.expand_dims(x_train, axis=3)
    dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(batch_size)
    return dataset

# view some training samples
from matplotlib import pyplot as plt

dataset = load_dataset()
plt.figure(figsize=(8,8))  # 8*8 inch
for i, img in enumerate(dataset.unbatch().take(9)):
    ax = plt.subplot(3, 3, i+1)
    ax.imshow(img[:,:,0])
    #ax.set_title("label = %d" % label)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

(Figure: a 3x3 grid of MNIST training samples)

Step 3: Build the Model

Weight Initialization

Following the DCGAN paper, all weights are initialized from a zero-centered normal distribution with a standard deviation of 0.02. I also tweaked the BatchNormalization initialization after looking at other people's implementations.

def init_weight(model: tf.keras.models.Model):
    # DCGAN paper: weights ~ N(0, 0.02); BN gamma ~ N(1, 0.02), beta = 0
    weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)
    for layer in model.layers:
        if isinstance(layer, (layers.Conv2D, layers.Conv2DTranspose, layers.Dense)):
            # assign actual values; storing an initializer object into
            # layer.weights[0] would not change the underlying variable
            layer.kernel.assign(weight_init(shape=layer.kernel.shape))
        elif isinstance(layer, layers.BatchNormalization):
            layer.gamma.assign(gamma_init(shape=layer.gamma.shape))
            layer.beta.assign(tf.zeros_like(layer.beta))
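A more idiomatic alternative (a sketch of my own, not from the original notebook) is to pass the initializers when constructing the layers, so no post-hoc reassignment is needed at all:

# Hypothetical alternative: set initializers at construction time.
w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
g_init = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

conv = layers.Conv2D(64, 5, strides=2, padding='same',
                     use_bias=False, kernel_initializer=w_init)
bn = layers.BatchNormalization(gamma_initializer=g_init)  # beta defaults to zeros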

Generator Model

I've seen several ways to implement the generator. The paper's "project and reshape" step can be done with either a fully connected layer or a convolution; here I use the fully connected version. For the upsampling there are two options, Conv2DTranspose and UpSampling2D+Conv2D; following the paper I use the former. The paper calls for ReLU in the generator, but here I substitute LeakyReLU, with alpha set to 0.2 as in the paper.
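For reference, here is a sketch of the alternative UpSampling2D+Conv2D block mentioned above (not used in this notebook; the helper name is hypothetical). This variant is often cited as reducing the checkerboard artifacts that transposed convolutions can produce:

# Hypothetical drop-in for one Conv2DTranspose stage, e.g. 7*7*128 -> 14*14*64
def upsample_block(model, filters):
    model.add(layers.UpSampling2D(size=2))                  # nearest-neighbor 2x upsample
    model.add(layers.Conv2D(filters, kernel_size=5, strides=1,
                            padding='same', use_bias=False))  # plain conv after upsampling
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))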


def build_generator() -> tf.keras.Model:
    model = tf.keras.models.Sequential()
    # z: latent_dim = 100
    model.add(layers.InputLayer(input_shape=100))
    # project and reshape
    model.add(layers.Dense(7*7*256))
    model.add(layers.Reshape((7, 7, 256)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    # 7*7*256 -> 7*7*128
    model.add(layers.Conv2DTranspose(128, kernel_size=5, strides=1, padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    # 7*7*128 -> 14*14*64
    model.add(layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    # 14*14*64 -> 28*28*32
    model.add(layers.Conv2DTranspose(32, kernel_size=5, strides=2, padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))

    # 28*28*32 -> 28*28*1, tanh output to match the [-1,1] input range
    model.add(layers.Conv2D(1, kernel_size=5, strides=1, padding='same', use_bias=False))
    model.add(layers.Activation("tanh"))

    model.summary()
    return model

generator = build_generator()
if os.path.exists('generator.index'):
    generator.load_weights('generator')
else:
    init_weight(generator)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 12544)             1266944   
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 256)         0         
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 256)         1024      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 7, 7, 128)         819200    
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64)        204800    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 32)        51200     
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 32)        128       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 1)         800       
_________________________________________________________________
activation (Activation)      (None, 28, 28, 1)         0         
=================================================================
Total params: 2,344,864
Trainable params: 2,343,904
Non-trainable params: 960
_________________________________________________________________

Build the Discriminator

The final layer uses a sigmoid activation. Following other people's code I also added Dropout; without it there was no real adversarial back-and-forth during training and both models simply stalled (sob). Also note that D must not be too strong, or G never gets off the ground.
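A related trick from the "How to Train a GAN?" tips cited below is one-sided label smoothing, which keeps D from becoming overconfident. A minimal sketch (not used in this notebook):

# Hypothetical one-sided label smoothing for the D training step:
# real targets become 0.9 instead of 1.0 so D never sees a perfectly
# confident signal; fake targets stay at 0.0.
bs = 128                           # current batch size
real = np.full((bs, 1), 0.9)       # instead of np.ones((bs, 1))
fake = np.zeros((bs, 1))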

def build_discriminator() -> tf.keras.Model:
    model = tf.keras.models.Sequential()
    model.add(layers.InputLayer(input_shape=(28,28,1)))
    # 28*28*1 -> 14*14*64
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))

    # 14*14*64 -> 7*7*128
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))

    model.summary()
    return model

discriminator = build_discriminator()
if os.path.exists('discriminator.index'):
    discriminator.load_weights('discriminator')
else:
    init_weight(discriminator)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 14, 14, 64)        1600      
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 128)         204800    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 6272)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 6273      
=================================================================
Total params: 212,673
Trainable params: 212,673
Non-trainable params: 0
_________________________________________________________________

Build the Combined Model

Combine G and D into a joint model trained with BCE loss. Following the paper, the optimizer is Adam with a learning rate of 0.0002 and beta_1 of 0.5. Note that D's trainable flag is set to False here. Because D was compiled before the flag was flipped, calling train_on_batch on D directly still updates its weights; the combined model, however, is compiled with D frozen, so it only trains G. (A quick sanity check of this behavior follows the model summary below.)

def build_DCGAN(discriminator: tf.keras.Model, generator: tf.keras.Model) -> tf.keras.Model:
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
                          metrics=['accuracy'])
    discriminator.trainable = False
    noise = layers.Input(shape=100)
    img = generator(noise)
    valid = discriminator(img)
    model = tf.keras.Model(noise, valid)
    model.summary()
    model.compile(loss='binary_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5))
    return model

dcgan = build_DCGAN(discriminator, generator)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         2344864   
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 212673    
=================================================================
Total params: 2,557,537
Trainable params: 2,343,904
Non-trainable params: 213,633
_________________________________________________________________
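To convince yourself of the trainable behavior described above, here is a hypothetical check (not part of the original notebook; it nudges D's weights, so it is illustrative only): training D directly still changes its weights, while training the combined model does not.

# Hypothetical sanity check of the compile-time trainable snapshot.
x = np.random.normal(0, 1, (4, 28, 28, 1)).astype('float32')
y = np.ones((4, 1))

w_before = discriminator.get_weights()[0].copy()
discriminator.train_on_batch(x, y)                 # compiled before freezing: D learns
assert not np.allclose(w_before, discriminator.get_weights()[0])

z = np.random.normal(0, 1, (4, 100))
w_before = discriminator.get_weights()[0].copy()
dcgan.train_on_batch(z, y)                         # D frozen inside the combined model
assert np.allclose(w_before, discriminator.get_weights()[0])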

Saving and Displaying Images

if not os.path.exists('./images'):
    os.mkdir('./images')

def plot_images(save2file=False, samples=16, step=0):
    filename = "./images/mnist_%d.png" % step
    noise = np.random.normal(0, 1, (samples, 100))

    images = generator.predict(noise)

    plt.figure(figsize=(10, 10))

    for i in range(images.shape[0]):
        plt.subplot(4, 4, i+1)
        image = images[i, :, :, :]
        plt.imshow(image[:,:,0])
        plt.axis('off')
    plt.tight_layout()

    if save2file:
        plt.savefig(filename)
        plt.close('all')
    else:
        plt.show()

Step 4: Train the Model

Following the paper, batch_size is set to 128 and the inputs are scaled to [-1, 1].
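Spelled out, the scaling already applied in load_dataset maps each pixel value into the range of the generator's tanh output:

$$x' = \frac{x}{127.5} - 1, \qquad 0 \mapsto -1, \quad 255 \mapsto 1$$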

d_loss_list = []
g_loss_list = []

def train(batch_size=128, epochs=50, save_interval=10):

    dataset = load_dataset(batch_size)

    for epoch in range(epochs):

        pbar = tqdm(enumerate(dataset))
        for step, real_img in pbar:
            bs = len(real_img)

            # labels
            real = np.ones((bs, 1))
            fake = np.zeros((bs, 1))

            # generate fake images
            noise = np.random.normal(0, 1, (bs, 100))
            fake_img = generator.predict(noise)

            # train D
            d_loss_real = discriminator.train_on_batch(real_img, real)
            d_loss_fake = discriminator.train_on_batch(fake_img, fake)
            d_loss, accuracy = 0.5*np.add(d_loss_real, d_loss_fake)

            # train G through the combined model, with fresh noise
            noise = np.random.normal(0, 1, (bs, 100))
            g_loss = dcgan.train_on_batch(noise, real)

            pbar.set_postfix({'step': step,
                              'd_loss': '{:.4f}'.format(d_loss),
                              'acc': '{:.4f}'.format(accuracy),
                              'g_loss': '{:.4f}'.format(g_loss)})

        d_loss_list.append(d_loss)
        g_loss_list.append(g_loss)

        if epoch % save_interval == 0 or epoch == (epochs-1):
            plot_images(save2file=True, step=epoch)
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)

train(epochs=50, save_interval=1)
469it [01:57,  4.01it/s, step=468, d_loss=0.4450, acc=0.9115, g_loss=0.1112]
469it [01:53,  4.12it/s, step=468, d_loss=0.5188, acc=0.8750, g_loss=0.8756]
469it [01:53,  4.13it/s, step=468, d_loss=0.6582, acc=0.6354, g_loss=1.1004]
469it [01:53,  4.12it/s, step=468, d_loss=0.6392, acc=0.6667, g_loss=1.0247]
469it [01:53,  4.15it/s, step=468, d_loss=0.6302, acc=0.6510, g_loss=0.9042]
469it [01:53,  4.14it/s, step=468, d_loss=0.4286, acc=0.8490, g_loss=1.3098]
469it [01:53,  4.13it/s, step=468, d_loss=0.5711, acc=0.7552, g_loss=1.0918]
469it [01:54,  4.11it/s, step=468, d_loss=0.1656, acc=0.9531, g_loss=0.0233]
469it [01:53,  4.13it/s, step=468, d_loss=0.4908, acc=0.8177, g_loss=1.5465]
469it [01:53,  4.13it/s, step=468, d_loss=0.4977, acc=0.7656, g_loss=1.1850]
469it [01:54,  4.11it/s, step=468, d_loss=0.4030, acc=0.8646, g_loss=1.3088]
469it [01:53,  4.12it/s, step=468, d_loss=0.4622, acc=0.8177, g_loss=0.6755]
469it [01:53,  4.13it/s, step=468, d_loss=0.2610, acc=0.9323, g_loss=0.5214]
469it [01:53,  4.14it/s, step=468, d_loss=0.3714, acc=0.8646, g_loss=1.7093]
469it [01:52,  4.17it/s, step=468, d_loss=0.6514, acc=0.6458, g_loss=1.2826]
469it [01:52,  4.17it/s, step=468, d_loss=0.5218, acc=0.7448, g_loss=1.6180]
469it [01:52,  4.16it/s, step=468, d_loss=0.3786, acc=0.8594, g_loss=1.4780]
469it [01:52,  4.16it/s, step=468, d_loss=0.2106, acc=0.9427, g_loss=2.5880]
469it [01:52,  4.15it/s, step=468, d_loss=0.6040, acc=0.7240, g_loss=1.7180]
469it [01:52,  4.16it/s, step=468, d_loss=0.3441, acc=0.8490, g_loss=1.5291]
469it [01:52,  4.16it/s, step=468, d_loss=0.3020, acc=0.8958, g_loss=2.0768]
469it [01:52,  4.15it/s, step=468, d_loss=0.4162, acc=0.8333, g_loss=1.6850]
469it [01:52,  4.17it/s, step=468, d_loss=0.3279, acc=0.8698, g_loss=1.3302]
469it [01:53,  4.14it/s, step=468, d_loss=0.2163, acc=0.9271, g_loss=2.9825]
469it [01:53,  4.13it/s, step=468, d_loss=0.2229, acc=0.9167, g_loss=2.2349]
469it [01:52,  4.15it/s, step=468, d_loss=0.2062, acc=0.9375, g_loss=1.6529]
469it [01:53,  4.14it/s, step=468, d_loss=0.1285, acc=0.9635, g_loss=2.8505]
469it [01:53,  4.14it/s, step=468, d_loss=0.3163, acc=0.8750, g_loss=2.8405]
469it [01:53,  4.15it/s, step=468, d_loss=0.3621, acc=0.8698, g_loss=2.1161]
469it [01:53,  4.12it/s, step=468, d_loss=0.2909, acc=0.9062, g_loss=2.6993]
469it [01:53,  4.15it/s, step=468, d_loss=0.2011, acc=0.9271, g_loss=2.5064]
469it [01:53,  4.12it/s, step=468, d_loss=0.2564, acc=0.8854, g_loss=2.4029]
469it [01:53,  4.15it/s, step=468, d_loss=0.1781, acc=0.9479, g_loss=1.5723]
469it [01:53,  4.13it/s, step=468, d_loss=0.1547, acc=0.9583, g_loss=1.3841]
469it [01:53,  4.14it/s, step=468, d_loss=0.1384, acc=0.9583, g_loss=1.9002]
469it [01:53,  4.15it/s, step=468, d_loss=0.4499, acc=0.7865, g_loss=2.3525]
469it [01:54,  4.08it/s, step=468, d_loss=0.0773, acc=0.9844, g_loss=0.5579]
469it [01:54,  4.11it/s, step=468, d_loss=0.2629, acc=0.8958, g_loss=0.5309]
469it [02:00,  3.88it/s, step=468, d_loss=0.2257, acc=0.9167, g_loss=2.6041]
469it [01:59,  3.91it/s, step=468, d_loss=0.2283, acc=0.9271, g_loss=1.5944]
469it [02:08,  3.65it/s, step=468, d_loss=0.2495, acc=0.9115, g_loss=2.7131]
469it [02:07,  3.67it/s, step=468, d_loss=0.1437, acc=0.9479, g_loss=0.9581]
469it [02:08,  3.64it/s, step=468, d_loss=0.5403, acc=0.7865, g_loss=2.7837]
469it [02:08,  3.66it/s, step=468, d_loss=0.1985, acc=0.9323, g_loss=2.7737]
469it [02:09,  3.62it/s, step=468, d_loss=0.1488, acc=0.9427, g_loss=0.7168]
469it [02:07,  3.68it/s, step=468, d_loss=0.1167, acc=0.9583, g_loss=0.8895]
469it [02:09,  3.63it/s, step=468, d_loss=0.1243, acc=0.9323, g_loss=2.8780]
469it [02:10,  3.59it/s, step=468, d_loss=0.2008, acc=0.9115, g_loss=2.6001]
469it [02:10,  3.59it/s, step=468, d_loss=0.1474, acc=0.9479, g_loss=1.5862]
469it [02:08,  3.64it/s, step=468, d_loss=0.2088, acc=0.9062, g_loss=3.0232]

Step 5: View the Results

Plot d_loss and g_loss. Both losses are still oscillating, which suggests training hasn't fully converged yet.

%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(g_loss_list,label="G")
plt.plot(d_loss_list,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

(Figure: generator and discriminator loss during training)

Generate an Animated GIF

Combine the images saved at each stage of training into a GIF. The results aren't great, but the shapes of the digits are still recognizable, with no crystalline artifacts.

anim_file = 'dcgan.gif'

import imageio
import glob

filenames = glob.glob('./images/mnist_*.png')
# sort numerically by epoch so e.g. mnist_10.png comes after mnist_2.png
filenames = sorted(filenames, key=lambda f: int(f.split('_')[-1].split('.')[0]))
frames = []
for filename in filenames:
    image = imageio.imread(filename)
    frames.append(image)

kwargs = {'duration': 0.5}
imageio.mimsave(anim_file, frames, 'GIF', **kwargs)


from IPython import display
display.Image(filename=anim_file)

Save the Model Weights

generator.save_weights('generator',True)
discriminator.save_weights('discriminator',True)

References

Generative Adversarial Nets

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

Keras-GAN

DCGANs (Deep Convolutional Generative Adversarial Networks)

PyTorch-GAN

How to Train a GAN? Tips and tricks to make GANs work

DCGAN