大数跨境
0
0

Keras模型训练全攻略:从内置循环到自定义回调

Keras模型训练全攻略:从内置循环到自定义回调 知识代码AI
2025-11-21
0
导读:Keras模型训练全攻略:从内置循环到自定义回调在深度学习模型开发中,Keras提供了多种训练工作流程,从简单

Keras模型训练全攻略:从内置循环到自定义回调

在深度学习模型开发中,Keras提供了多种训练工作流程,从简单到复杂,满足不同场景的需求。本文将带你全面了解Keras的训练和评估机制,包括标准工作流程、自定义指标、回调函数以及TensorBoard可视化监控。

标准工作流程:compile()、fit()、evaluate()、predict()

Keras最基本的训练流程非常直观:

# 构建模型
defget_mnist_model():
    inputs = keras.Input(shape=(28281))
    x = keras.layers.Conv2D(323, activation="relu")(inputs)
    x = keras.layers.Conv2D(643, activation="relu")(x)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
return model

# 标准训练流程
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# 训练模型
model.fit(train_images, train_labels, epochs=5, batch_size=64)

# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)

# 进行预测
predictions = model.predict(test_images)

这个简单的工作流程可以通过自定义指标和回调函数来增强灵活性。

编写自定义指标

虽然Keras内置了许多常用指标,但有时我们需要根据特定任务创建自定义指标。自定义指标需要继承keras.metrics.Metric类:

classRootMeanSquaredError(keras.metrics.Metric):
def__init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(name="total_samples"
                                           initializer="zeros", dtype="int32")

defupdate_state(self, y_true, y_pred, sample_weight=None):
# 将标签转换为one-hot编码
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

defresult(self):
return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

defreset_state(self):
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)

使用自定义指标的方法与内置指标完全相同:

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])

使用回调函数

回调函数是Keras训练过程中的"无人机",它们能够监控训练状态并采取相应行动。常用的内置回调函数包括:

  • ModelCheckpoint:保存模型检查点
  • EarlyStopping:提前终止训练
  • LearningRateScheduler:动态调整学习率
  • ReduceLROnPlateau:在指标停滞时降低学习率
  • CSVLogger:将训练记录保存到CSV文件

EarlyStopping和ModelCheckpoint示例

callbacks_list = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=2,  # 如果2轮验证准确率没有改善则停止训练
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path.keras",
        monitor="val_loss",
        save_best_only=True,  # 只保存最佳模型
    )
]

model.fit(train_images, train_labels,
          epochs=10,
          callbacks=callbacks_list,
          validation_data=(val_images, val_labels))

编写自定义回调函数

当内置回调函数无法满足需求时,我们可以创建自定义回调函数:

from matplotlib import pyplot as plt

classLossHistory(keras.callbacks.Callback):
defon_train_begin(self, logs):
        self.per_batch_losses = []

defon_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

defon_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}.png")
        self.per_batch_losses = []

这个回调函数会在每个epoch结束时绘制该epoch内每个batch的损失曲线。

利用TensorBoard进行监控和可视化

TensorBoard是TensorFlow的可视化工具包,可以实时监控训练过程:

# 创建TensorBoard回调
tensorboard = keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_log_dir",
)

model.fit(train_images, train_labels,
          epochs=10,
          validation_data=(val_images, val_labels),
          callbacks=[tensorboard])

启动TensorBoard服务:

tensorboard --logdir /full_path_to_your_log_dir

TensorBoard可以提供:

  • 训练指标和验证指标的实时可视化
  • 模型架构可视化
  • 激活函数和梯度的直方图
  • 嵌入的三维可视化

完整示例代码

以下是结合了所有上述技术的完整示例:

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import datetime

# 1. 准备数据
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images.reshape((6000028281)).astype("float32") / 255
test_images = test_images.reshape((1000028281)).astype("float32") / 255

val_images = train_images[:10000]
val_labels = train_labels[:10000]
train_images = train_images[10000:]
train_labels = train_labels[10000:]

# 2. 构建模型
defget_mnist_model():
    inputs = keras.Input(shape=(28281))
    x = keras.layers.Conv2D(323, activation="relu")(inputs)
    x = keras.layers.Conv2D(643, activation="relu")(x)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
return model

# 3. 使用自定义指标和回调函数训练
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])

# 设置回调函数
callbacks_list = [
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=2),
    keras.callbacks.ModelCheckpoint(
        filepath="mnist_model.keras",
        monitor="val_loss",
        save_best_only=True,
    ),
    LossHistory()
]

# 训练模型
history = model.fit(train_images, train_labels,
                    epochs=10,
                    callbacks=callbacks_list,
                    validation_data=(val_images, val_labels))

# 4. 评估最佳模型
best_model = keras.models.load_model("mnist_model.keras")
test_loss, test_acc, test_rmse = best_model.evaluate(test_images, test_labels)
print(f"测试集结果 - 损失: {test_loss:.4f}, 准确率: {test_acc:.4f}, RMSE: {test_rmse:.4f}")

总结

Keras提供了从简单到复杂的多种训练工作流程:

  1. 标准工作流程:使用compile()fit()evaluate()predict()快速上手
  2. 自定义指标:通过继承keras.metrics.Metric类创建针对特定任务的评估指标
  3. 回调函数:使用内置或自定义回调函数增强训练过程的控制能力
  4. TensorBoard:利用可视化工具实时监控训练过程和模型性能

掌握这些技术可以帮助你构建更加灵活和强大的深度学习模型,提高开发效率和模型性能。


【声明】内容源于网络
0
0
知识代码AI
技术基底 机器视觉全栈 × 光学成像 × 图像处理算法 编程栈 C++/C#工业开发 | Python智能建模 工具链 Halcon/VisionPro工业部署 | PyTorch/TensorFlow模型炼金术 | 模型压缩&嵌入式移植
内容 366
粉丝 0
知识代码AI 技术基底 机器视觉全栈 × 光学成像 × 图像处理算法 编程栈 C++/C#工业开发 | Python智能建模 工具链 Halcon/VisionPro工业部署 | PyTorch/TensorFlow模型炼金术 | 模型压缩&嵌入式移植
总阅读94
粉丝0
内容366