Dogs vs. Cats with AlexNet
Preface
I had long heard of TensorFlow but never had the time to study it properly, and without a dataset at hand I never got around to it; the convoluted static graphs of TensorFlow 1 also kept putting me off. Now that I have the time, and TensorFlow 2 is far easier to understand, I finally dug in for my first attempt at training a model. The very first hurdle was downloading the dataset from Kaggle: even with a VPN I got stuck at phone-number verification, so in the end I fell back on the copy of the dataset hosted by Microsoft. Training on my own CPU was painfully slow; my machine had not lagged like that since I added an SSD and more RAM.
I then switched to Google Colab, and only after using its free GPU did I understand how much a GPU helps.
Now for the main content.
1. Data Preparation
On Colab we first need to set up the environment and download the data, which I have already uploaded to GitHub:
```
!pip install tensorflow
!wget https://github.com/qxdn/DogsVsCats/releases/download/v1.0.0/data.zip
!unzip data.zip
!rm -rf data.zip
```
AlexNet takes inputs of shape [227, 227, 3], so the images need to be preprocessed into that shape, and each one must be labeled (dog = 1, cat = 0):
```python
import tensorflow as tf
from tensorflow.keras import layers, models

BATCH_SIZE = 125

def load_image(img_path, size=(227, 227)):
    # Label from the filename: cat -> 0, dog -> 1
    label = tf.constant(0, tf.int8) if tf.strings.regex_full_match(img_path, ".*cat.*") \
        else tf.constant(1, tf.int8)
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img)
    img = tf.image.resize(img, size) / 255.0  # resize and scale to [0, 1]
    return (img, label)
```
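To make sure `load_image` behaves as expected, it can be called eagerly on a single file (a minimal spot check; the exact filename is hypothetical and depends on how data.zip is laid out, with Kaggle-style names such as cat.0.jpg assumed):

```python
# Hypothetical spot check: a cat image should come back as (227, 227, 3) with label 0
img, label = load_image("./data/train/cat.0.jpg")
print(img.shape, img.dtype, label.numpy())  # (227, 227, 3) float32 0
```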
```python
ds_train = tf.data.Dataset.list_files("./data/train/*.jpg") \
    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE)

# The test pipeline skips shuffling; order does not matter for evaluation
ds_test = tf.data.Dataset.list_files("./data/test/*.jpg") \
    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE)
```
```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

from matplotlib import pyplot as plt

plt.figure(figsize=(8, 8))
for i, (img, label) in enumerate(ds_train.unbatch().take(9)):
    ax = plt.subplot(3, 3, i + 1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d" % label)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
```
```python
for x, y in ds_train.take(1):
    print(x.shape, y.shape)
```
```
(125, 227, 227, 3) (125,)
```
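With BATCH_SIZE = 125, the 184 training steps and 16 validation steps per epoch in the training logs below work out to 184 × 125 = 23,000 training images and 16 × 125 = 2,000 test images, consistent with a split of the 25,000-image Kaggle set.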
2. Defining the Model
Compared with earlier networks, AlexNet's main contributions are the ReLU activation function, Dropout layers, and LRN (Local Response Normalization).
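LRN normalizes each activation against its neighbors across channels at the same spatial position. For reference, the AlexNet paper defines it (with constants $k = 2$, $n = 5$, $\alpha = 10^{-4}$, $\beta = 0.75$) as

$$
b_{x,y}^{i} = a_{x,y}^{i} \Big/ \left( k + \alpha \sum_{j=\max(0,\, i - n/2)}^{\min(N-1,\, i + n/2)} \left( a_{x,y}^{j} \right)^{2} \right)^{\beta}
$$

where $N$ is the number of channels. In `tf.nn.lrn`, `depth_radius` plays the role of $n/2$ and `bias` the role of $k$ (note that the layer below keeps TensorFlow's default `bias=1` rather than the paper's $k = 2$). Since Keras has no built-in LRN layer, a thin wrapper around `tf.nn.lrn` does the job: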
```python
class LRN(layers.Layer):
    def __init__(self):
        super(LRN, self).__init__()
        self.depth_radius = 2   # n/2 in the paper's notation
        self.bias = 1           # k in the paper's notation
        self.alpha = 1e-4
        self.beta = 0.75

    def call(self, x):
        return tf.nn.lrn(x, depth_radius=self.depth_radius,
                         bias=self.bias, alpha=self.alpha,
                         beta=self.beta)
```
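A quick shape check (a minimal sanity test, not from the original post): LRN only rescales activations, so the output shape must match the input.

```python
# LRN rescales values channel-wise; shapes are unchanged
x = tf.random.normal([1, 27, 27, 96])
print(LRN()(x).shape)  # expected: (1, 27, 27, 96)
```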
```python
tf.keras.backend.clear_session()

model = models.Sequential()
# Conv block 1: 11x11 stride-4 convolution, overlapping 3x3 max pooling, LRN
model.add(layers.Conv2D(96, kernel_size=(11, 11), activation='relu', strides=(4, 4), input_shape=(227, 227, 3)))
model.add(layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2)))
model.add(LRN())
# Conv block 2
model.add(layers.Conv2D(256, kernel_size=(5, 5), activation='relu', strides=(1, 1), padding='same'))
model.add(layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2)))
model.add(LRN())
# Conv blocks 3-5
model.add(layers.Conv2D(384, kernel_size=(3, 3), activation='relu', strides=(1, 1), padding='same'))
model.add(layers.Conv2D(384, kernel_size=(3, 3), activation='relu', strides=(1, 1), padding='same'))
model.add(layers.Conv2D(256, kernel_size=(3, 3), activation='relu', strides=(1, 1), padding='same'))
model.add(layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2)))
# Classifier head: flatten, two 4096-unit dense layers with dropout, softmax over 2 classes
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.25))
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.25))
model.add(layers.Dense(2, activation='softmax'))
model.summary()
```
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 55, 55, 96) 34944
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 27, 27, 96) 0
_________________________________________________________________
lrn (LRN) (None, 27, 27, 96) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 27, 27, 256) 614656
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 13, 13, 256) 0
_________________________________________________________________
lrn_1 (LRN) (None, 13, 13, 256) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 13, 13, 384) 885120
_________________________________________________________________
conv2d_3 (Conv2D) (None, 13, 13, 384) 1327488
_________________________________________________________________
conv2d_4 (Conv2D) (None, 13, 13, 256) 884992
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 256) 0
_________________________________________________________________
flatten (Flatten) (None, 9216) 0
_________________________________________________________________
dense (Dense) (None, 4096) 37752832
_________________________________________________________________
dropout (Dropout) (None, 4096) 0
_________________________________________________________________
dense_1 (Dense) (None, 4096) 16781312
_________________________________________________________________
dropout_1 (Dropout) (None, 4096) 0
_________________________________________________________________
dense_2 (Dense) (None, 2) 8194
=================================================================
Total params: 58,289,538
Trainable params: 58,289,538
Non-trainable params: 0
_________________________________________________________________
```
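As a cross-check on the summary, the first convolution has 11 × 11 × 3 × 96 + 96 = 34,944 parameters (weights plus biases), and the first fully connected layer alone accounts for 9216 × 4096 + 4096 = 37,752,832, so the classifier head holds the bulk of the 58M total.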
3. Training the Model
```python
import os

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
              loss='sparse_categorical_crossentropy',
              metrics=["accuracy"])

checkpoint_save_path = './checkpoint/AlexNet/AlexNet-{epoch:02d}.ckpt'
# Resume from the most recent checkpoint in the directory, if one exists
latest = tf.train.latest_checkpoint(os.path.dirname(checkpoint_save_path))
if latest:
    print('----------------load the model--------------')
    model.load_weights(latest)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(ds_train, epochs=20, validation_data=ds_test,
                    callbacks=[cp_callback], workers=4)
```
```
Epoch 1/20
184/184 [==============================] - 40s 219ms/step - loss: 0.6630 - accuracy: 0.5944 - val_loss: 0.6141 - val_accuracy: 0.6615
Epoch 2/20
184/184 [==============================] - 40s 220ms/step - loss: 0.5970 - accuracy: 0.6754 - val_loss: 0.5586 - val_accuracy: 0.7040
Epoch 3/20
184/184 [==============================] - 40s 219ms/step - loss: 0.5426 - accuracy: 0.7244 - val_loss: 0.5088 - val_accuracy: 0.7505
Epoch 4/20
184/184 [==============================] - 40s 219ms/step - loss: 0.4946 - accuracy: 0.7598 - val_loss: 0.4840 - val_accuracy: 0.7735
Epoch 5/20
184/184 [==============================] - 40s 220ms/step - loss: 0.4545 - accuracy: 0.7852 - val_loss: 0.4469 - val_accuracy: 0.8055
Epoch 6/20
184/184 [==============================] - 40s 220ms/step - loss: 0.4207 - accuracy: 0.8064 - val_loss: 0.4180 - val_accuracy: 0.8100
Epoch 7/20
184/184 [==============================] - 40s 219ms/step - loss: 0.3991 - accuracy: 0.8193 - val_loss: 0.4016 - val_accuracy: 0.8250
Epoch 8/20
184/184 [==============================] - 41s 220ms/step - loss: 0.3788 - accuracy: 0.8309 - val_loss: 0.3908 - val_accuracy: 0.8315
Epoch 9/20
184/184 [==============================] - 40s 219ms/step - loss: 0.3587 - accuracy: 0.8427 - val_loss: 0.3733 - val_accuracy: 0.8375
Epoch 10/20
184/184 [==============================] - 38s 207ms/step - loss: 0.3387 - accuracy: 0.8539 - val_loss: 0.3860 - val_accuracy: 0.8345
Epoch 11/20
184/184 [==============================] - 40s 220ms/step - loss: 0.3273 - accuracy: 0.8576 - val_loss: 0.3498 - val_accuracy: 0.8475
Epoch 12/20
184/184 [==============================] - 38s 207ms/step - loss: 0.3064 - accuracy: 0.8676 - val_loss: 0.3671 - val_accuracy: 0.8380
Epoch 13/20
184/184 [==============================] - 38s 208ms/step - loss: 0.2931 - accuracy: 0.8762 - val_loss: 0.3501 - val_accuracy: 0.8495
Epoch 14/20
184/184 [==============================] - 40s 220ms/step - loss: 0.2737 - accuracy: 0.8864 - val_loss: 0.3361 - val_accuracy: 0.8540
Epoch 15/20
184/184 [==============================] - 40s 219ms/step - loss: 0.2694 - accuracy: 0.8890 - val_loss: 0.3311 - val_accuracy: 0.8585
Epoch 16/20
184/184 [==============================] - 38s 208ms/step - loss: 0.2557 - accuracy: 0.8938 - val_loss: 0.3388 - val_accuracy: 0.8565
Epoch 17/20
184/184 [==============================] - 40s 219ms/step - loss: 0.2362 - accuracy: 0.9030 - val_loss: 0.3160 - val_accuracy: 0.8680
Epoch 18/20
184/184 [==============================] - 38s 207ms/step - loss: 0.2184 - accuracy: 0.9106 - val_loss: 0.3692 - val_accuracy: 0.8515
Epoch 19/20
184/184 [==============================] - 38s 207ms/step - loss: 0.2085 - accuracy: 0.9160 - val_loss: 0.3198 - val_accuracy: 0.8710
Epoch 20/20
184/184 [==============================] - 41s 222ms/step - loss: 0.1909 - accuracy: 0.9239 - val_loss: 0.3019 - val_accuracy: 0.8705
```
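By the final epoch, training accuracy (0.9239) has pulled noticeably ahead of validation accuracy (0.8705), so training much longer would likely start to overfit.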
4. Evaluating the Model
```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_' + metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric, 'val_' + metric])
    plt.show()
```
1 plot_metric(history,"loss" )
1 plot_metric(history,"accuracy" )
```python
val_loss, val_accuracy = model.evaluate(ds_test, workers=4)
print(val_loss, val_accuracy)
```
```
16/16 [==============================] - 2s 144ms/step - loss: 0.3019 - accuracy: 0.8705
0.3018629252910614 0.8705000281333923
```
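These match the final epoch's validation numbers, as expected: `evaluate` runs over the same `ds_test` that served as validation data, and the in-memory model holds the last epoch's weights.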
5. Using the Model
```python
def perpareImage(img, sizes=(-1, 227, 227, 3)):
    img = tf.reshape(img, shape=sizes)
    return img
```
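Since `sizes` defaults to `(-1, 227, 227, 3)`, this just prepends a batch dimension to a single image; an equivalent one-liner (an alternative sketch, not the original code) would be:

```python
# Equivalent: add a leading batch axis to one (227, 227, 3) image
img_batch = tf.expand_dims(img, axis=0)  # shape (1, 227, 227, 3)
```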
```python
for i, (img, label) in enumerate(ds_test.unbatch().take(1)):
    img = tf.reshape(img, shape=(-1, 227, 227, 3))  # add a batch dimension
    p = model.predict(img)
    print(p)            # softmax probabilities
    p = tf.argmax(p, 1)
    print(p)            # predicted class index
    print(p.shape)
```
```
[[0.91607434 0.08392565]]
tf.Tensor([0], shape=(1,), dtype=int64)
(1,)
```
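Here `[[0.91607434 0.08392565]]` are the softmax probabilities for the two classes (cat = 0, dog = 1), so `argmax` along axis 1 picks class 0: a cat.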
```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

from matplotlib import pyplot as plt

plt.figure(figsize=(8, 8))
for i, (img, label) in enumerate(ds_test.unbatch().take(9)):
    ax = plt.subplot(3, 3, i + 1)
    ax.imshow(img.numpy())
    predict = model.predict(perpareImage(img))
    ax.set_title("predict=%d,label=%d" % (tf.argmax(predict, 1), label))
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
```
6. Saving the Model
Saving in TensorFlow's native SavedModel format:
```python
model.save('./model/tf_AlexNet_Model', save_format="tf")
print('export saved model.')

model_loaded = tf.keras.models.load_model('./model/tf_AlexNet_Model')
model_loaded.evaluate(ds_test)
```
```
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./model/tf_AlexNet_Model/assets
export saved model.
16/16 [==============================] - 2s 147ms/step - loss: 0.3019 - accuracy: 0.5155
[0.301862895488739, 0.515500009059906]
```
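Note that the reloaded model reproduces the loss exactly but reports a much lower accuracy (0.5155 vs. 0.8705). This looks like a quirk in how the string metric 'accuracy' is restored from the SavedModel; one workaround worth trying (an assumption, not verified here) is to recompile the loaded model with an explicit metric:

```python
# Recompile so the metric is rebuilt explicitly for sparse integer labels
model_loaded.compile(loss='sparse_categorical_crossentropy',
                     metrics=['sparse_categorical_accuracy'])
model_loaded.evaluate(ds_test)
```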
Repo
DogsVsCats: https://github.com/qxdn/DogsVsCats
Postscript
I hit quite a few pitfalls along the way: forgetting the Flatten layer when building the network, getting stuck for ages at prediction time because of a shape mismatch, and watching the loss and accuracy refuse to budge when training on CPU, which almost convinced me the whole thing had failed. In the end it worked out reasonably well. Praise be to Google Colab.