tf.keras.Model子类的注意项
自定义网络层
自定义网络层需要继承 tf.keras.layers.Layer,自定义的层的 trainable 参数默认为 True,即默认可以训练, 如以下定义了两个层:
1 | class MaskMean(tf.keras.layers.Layer): |
1 | class Predict(tf.keras.layers.Layer): |
上面定义的两个层,第一个层是作为 Batch MaskMean 层,Predict层是回归模型的输出层
这里注意:
1 | mask_mean_layer = MaskMean() |
这里实例化了两个层的对象,但是此时 mask_mean_layer.trainable_variables 和 output_layer.trainable_variables 为空,这是 tf.keras 的一个特点,因为此时不知道 input_shape, 所以此时还没有分配层的变量,只有当输入一次数据后才会产生变量,当输入一次数据后, trainable_variables 参数就已经产生了。
自定义模型
另外的注意点:
1 | class Regulation(tf.keras.Model): |
这里定义了一个 tf.keras.Model 的一个子类,包含两个自定义层。这时查看 model.layers 可以发现只有一个层, 而且即使输入一次数据,它的 trainable_variables 仍为空,后来发现 transformers 库里的每个 Model 都只包含一个层,所以经过尝试,重新定义如下:
1 | class Regulation(tf.keras.layers.Layer): |
这时可以发现 model.layers 包含一个层, 而且model.layers[0]._layers 包含两个自定义层,此时输入一次数据并输出后,发现 model.trainable_variable 已经分配变量了。
结论: 通过这次试验发现,tf.keras.Model 的子类只能包含一个层,而且只有自定义的层可以包含多个网络层,而且只有在输入一次数据后才会分配变量。
- 本文作者: 程序猪-渔枫
- 本文链接: https://over-shine.github.io/2020/08/21/tf-keras-Model子类注意项/
- 版权声明: 本博客所有文章除特别声明外,均采用 MIT 许可协议。转载请注明出处!