The Complete Guide to Training Keras Models: From Built-in Loops to Custom Callbacks
When developing deep learning models, Keras offers a range of training workflows, from simple to fully customized, to suit different scenarios. This article walks through Keras's training and evaluation machinery, including the standard workflow, custom metrics, callbacks, and monitoring with TensorBoard.
The Standard Workflow: compile(), fit(), evaluate(), predict()
The most basic Keras training workflow is very straightforward:
# Build the model
from tensorflow import keras

def get_mnist_model():
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Conv2D(32, 3, activation="relu")(inputs)
    x = keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

# The standard training workflow
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# Train the model
model.fit(train_images, train_labels, epochs=5, batch_size=64)
# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels)
# Generate predictions
predictions = model.predict(test_images)
This simple workflow can be made more flexible with custom metrics and callbacks.
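Note that predict() returns one probability vector per input sample (here, a 10-way softmax); taking the argmax of each row recovers the predicted class index. A minimal sketch, using a hypothetical prediction array in place of real model output:

```python
import numpy as np

# Hypothetical (3, 10) softmax output, one row per sample
predictions = np.array([[0.05] * 9 + [0.55],
                        [0.90] + [0.10 / 9] * 9,
                        [0.05] * 8 + [0.30, 0.30]])
# argmax along axis 1 gives the predicted class per sample
# (ties resolve to the first maximum, as in row 3)
predicted_classes = predictions.argmax(axis=1)
print(predicted_classes)  # -> [9 0 8]
```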
Writing Custom Metrics
Although Keras ships with many common metrics, sometimes we need a custom metric tailored to a specific task. A custom metric subclasses keras.metrics.Metric:
import tensorflow as tf

class RootMeanSquaredError(keras.metrics.Metric):
    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(name="total_samples",
                                             initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert the integer labels to one-hot encoding
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

    def reset_state(self):
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)
A custom metric is used exactly like a built-in one:
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])
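The stateful API the custom class implements (update_state() to accumulate, result() to read, reset_state() to clear between epochs) can also be driven by hand, which is handy for unit-testing a metric. A small sketch using a built-in metric:

```python
import numpy as np
from tensorflow import keras

# Metrics accumulate state across update_state() calls
m = keras.metrics.SparseCategoricalAccuracy()
m.update_state([1, 2], np.array([[0.1, 0.8, 0.1],
                                 [0.2, 0.2, 0.6]]))
print(float(m.result()))  # both samples classified correctly -> 1.0
m.reset_state()           # clears the accumulators, e.g. between epochs
```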
Using Callbacks
Callbacks are the "drones" of the Keras training process: they monitor the training state and take action in response. Commonly used built-in callbacks include:
- ModelCheckpoint: saves model checkpoints
- EarlyStopping: interrupts training early
- LearningRateScheduler: dynamically adjusts the learning rate
- ReduceLROnPlateau: lowers the learning rate when a metric plateaus
- CSVLogger: streams training logs to a CSV file
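As an illustration of one of these, LearningRateScheduler wraps a plain Python function mapping (epoch, current learning rate) to a new learning rate. The particular schedule below (hold for 5 epochs, then decay 10% per epoch) is just an example choice:

```python
from tensorflow import keras

def schedule(epoch, lr):
    # Keep the initial rate for the first 5 epochs,
    # then decay it by 10% every subsequent epoch
    if epoch < 5:
        return lr
    return lr * 0.9

lr_callback = keras.callbacks.LearningRateScheduler(schedule)
# Pass it to fit() via callbacks=[lr_callback]
```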
An EarlyStopping and ModelCheckpoint Example
callbacks_list = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=2,  # stop if validation accuracy fails to improve for 2 epochs
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path.keras",
        monitor="val_loss",
        save_best_only=True,  # keep only the best model seen so far
    )
]
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=callbacks_list,
          validation_data=(val_images, val_labels))
Writing a Custom Callback
When the built-in callbacks don't cover your needs, you can write your own:
from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}.png")
        self.per_batch_losses = []
At the end of each epoch, this callback saves a plot of the per-batch loss curve for that epoch.
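Custom callbacks can also act on the model itself through self.model. For example, a manual stopping condition can be implemented by setting the stop_training flag that fit() checks after every epoch (the threshold value below is illustrative):

```python
from tensorflow import keras

class StopAtLoss(keras.callbacks.Callback):
    """Stops training once the training loss falls below a threshold."""
    def __init__(self, threshold=0.05):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get("loss", float("inf")) < self.threshold:
            # fit() inspects this flag at the end of each epoch
            self.model.stop_training = True
```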
Monitoring and Visualization with TensorBoard
TensorBoard is TensorFlow's visualization toolkit, and it lets you monitor training in real time:
# Create the TensorBoard callback
tensorboard = keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_log_dir",
)
model.fit(train_images, train_labels,
          epochs=10,
          validation_data=(val_images, val_labels),
          callbacks=[tensorboard])
Then launch the TensorBoard server:
tensorboard --logdir /full_path_to_your_log_dir
TensorBoard provides:
- Live visualization of training and validation metrics
- Visualization of the model architecture
- Histograms of activations and gradients
- 3D visualization of embeddings
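When running several experiments, a common convention (a convention, not an API requirement) is to give each run its own timestamped subdirectory, so the runs appear side by side in the TensorBoard UI:

```python
import datetime
from tensorflow import keras

# One subdirectory per run, e.g. logs/fit/20240101-120000
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard = keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,  # record weight/activation histograms every epoch
)
```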
A Complete Example
Here is a complete example combining all of the techniques above:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import datetime

# 1. Prepare the data
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1)).astype("float32") / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype("float32") / 255
val_images = train_images[:10000]
val_labels = train_labels[:10000]
train_images = train_images[10000:]
train_labels = train_labels[10000:]

# 2. Build the model
def get_mnist_model():
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Conv2D(32, 3, activation="relu")(inputs)
    x = keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

# 3. Train with the custom metric and callbacks
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])

# Set up the callbacks
callbacks_list = [
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=2),
    keras.callbacks.ModelCheckpoint(
        filepath="mnist_model.keras",
        monitor="val_loss",
        save_best_only=True,
    ),
    LossHistory()
]

# Train the model
history = model.fit(train_images, train_labels,
                    epochs=10,
                    callbacks=callbacks_list,
                    validation_data=(val_images, val_labels))

# 4. Evaluate the best saved model
# (the custom metric class must be passed back in when loading)
best_model = keras.models.load_model(
    "mnist_model.keras",
    custom_objects={"RootMeanSquaredError": RootMeanSquaredError})
test_loss, test_acc, test_rmse = best_model.evaluate(test_images, test_labels)
print(f"Test set results - loss: {test_loss:.4f}, accuracy: {test_acc:.4f}, RMSE: {test_rmse:.4f}")
Summary
Keras offers training workflows ranging from simple to fully customized:
- Standard workflow: get started quickly with compile(), fit(), evaluate(), and predict()
- Custom metrics: subclass keras.metrics.Metric to create task-specific evaluation metrics
- Callbacks: use built-in or custom callbacks for finer control over the training process
- TensorBoard: monitor training and model performance in real time with visualization tools
Mastering these techniques will help you build more flexible and powerful deep learning models, improving both development efficiency and model performance.

