赞
踩
tf.keras是TensorFlow 2.0的高阶API接口,为TensorFlow的代码提供了新的风格和设计模式,大大提升了TF代码的简洁性和复用性,官方也推荐使用tf.keras来进行模型设计和开发。
tf.keras中常用模块如下表所示:
| 模块 | 概述 |
|---|---|
| activations | 激活函数 |
| applications | 预训练网络模块 |
| Callbacks | 在模型训练期间被调用 |
| datasets | tf.keras数据集模块,包括boston_housing,cifar10,fashion_mnist,imdb ,mnist |
| layers | Keras层API |
| losses | 各种损失函数 |
| metircs | 各种评价指标 |
| models | 模型创建模块,以及与模型相关的API |
| optimizers | 优化方法 |
| preprocessing | Keras数据的预处理模块 |
| regularizers | 正则化,L1,L2等 |
| utils | 辅助功能实现 |
深度学习实现的主要流程:1.数据获取,2,数据处理,3.模型创建与训练,4 模型测试与评估,5.模型预测
1.导入tf.keras
使用 tf.keras,首先需要在代码开始时导入tf.keras
- import tensorflow as tf
- from tensorflow import keras
2.数据输入
对于小的数据集,可以直接使用numpy格式的数据进行训练、评估模型,对于大型数据集或者要进行跨设备训练时使用tf.data.datasets来进行数据输入。
3.模型构建
4.训练与评估
- # 配置优化方法,损失函数和评价指标
- model.compile(optimizer=tf.train.AdamOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
- # 指明训练数据集,训练epoch,批次大小和验证集数据
- model.fit/fit_generator(dataset, epochs=10,
- batch_size=3,
- validation_data=val_dataset,
- )
- # 指明评估数据集和批次大小
- model.evaluate(x, y, batch_size=32)
- # 对新的样本进行预测
- model.predict(x, batch_size=32)
5.回调函数(callbacks)
回调函数用在模型训练过程中,来控制模型训练行为,可以自定义回调函数,也可使用tf.keras.callbacks 内置的 callback :
ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:动态改变学习速率。 EarlyStopping:当验证集上的性能不再提高时,终止训练。 TensorBoard:使用 TensorBoard 监测模型的状态。
6.模型的保存和恢复
- # 只保存模型的权重
- model.save_weights('./my_model')
- # 加载模型的权重
- model.load_weights('my_model')
- # 保存模型架构与权重在h5文件中
- model.save('my_model.h5')
- # 加载模型:包括架构和对应的权重
- model = keras.models.load_model('my_model.h5')
总结
了解Tensorflow2.0框架的用途及流程
1.使用tf.data加载数据
2、模型的建立与调试
3、模型的训练
4、预训练模型调用
5、模型的部署
知道tf2.0的张量及其操作
张量是多维数组。
1、创建方法:tf.constant()
2、转换为numpy: np.array()或tensor.asnumpy()
3、常用函数:加法,乘法,及各种聚合运算
4、变量:tf.Variable()
知道tf.keras中的相关模块及常用方法
常用模块:models,losses,application等
常用方法:
- 1、导入tf.keras
-
- 2、数据输入
-
- 3、模型构建
-
- 4、训练与评估
-
- 5、回调函数
-
- 6、模型的保存与恢复
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。