大数跨境
0
0

使用预训练模型:在小数据集上实现高效图像分类

使用预训练模型:在小数据集上实现高效图像分类 知识代码AI
2025-11-26
2
导读:使用预训练模型:在小数据集上实现高效图像分类在计算机视觉任务中,当面对小型图像数据集时,使用预训练模型是一种非

使用预训练模型:在小数据集上实现高效图像分类

在计算机视觉任务中,当面对小型图像数据集时,使用预训练模型是一种非常有效的方法。今天我们将深入探讨如何利用预训练模型来解决猫狗分类问题。

什么是预训练模型?

预训练模型是指之前在大型数据集(如ImageNet)上训练好的模型。ImageNet包含140万张标记图像和1000个类别,其中许多是动物类别。这些模型学到的特征层次结构可以作为视觉世界的通用模型,适用于各种计算机视觉问题。

两种使用方法

1. 特征提取(Feature Extraction)

特征提取是指使用预训练模型从新样本中提取特征,然后将这些特征输入新的分类器进行训练。

核心思想:只使用预训练模型的卷积基(卷积层和池化层),在其上添加新的分类器。

# 加载VGG16卷积基
conv_base = keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,  # 不包含原始分类器
    input_shape=(1801803))

2. 微调模型(Fine-tuning)

微调是指解冻预训练模型的顶部几层,并与新添加的分类器共同训练。

实践步骤

第一步:数据准备

from tensorflow.keras.utils import image_dataset_from_directory

# 加载数据集
train_dataset = image_dataset_from_directory(
    train_dir,
    image_size=(180180),
    batch_size=32)

validation_dataset = image_dataset_from_directory(
    validation_dir,
    image_size=(180180),
    batch_size=32)

第二步:构建模型

# 使用预训练的VGG16作为基础模型
base_model = keras.applications.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(1801803))

# 冻结基础模型的权重
base_model.trainable = False

# 添加自定义分类层
inputs = keras.Input(shape=(1801803))
x = keras.applications.vgg16.preprocess_input(inputs)
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

第三步:训练模型

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss="binary_crossentropy",
    metrics=["accuracy"])

# 回调函数
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="cats_dogs_classifier.keras",
        save_best_only=True,
        monitor="val_loss")
]

# 训练模型
history = model.fit(
    train_dataset,
    epochs=20,
    validation_data=validation_dataset,
    callbacks=callbacks)

第四步:微调模型

# 解冻基础模型的顶层
base_model.trainable = True

# 冻结底层数层,只训练顶层
for layer in base_model.layers[:-4]:
    layer.trainable = False

# 使用较低的学习率重新编译
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy"])

# 继续训练
history_fine = model.fit(
    train_dataset,
    epochs=30,
    validation_data=validation_dataset)

为什么这种方法有效?

  1. 通用特征提取:预训练模型的底层学习的是通用特征(边缘、纹理等),这些特征对大多数视觉任务都有用

  2. 位置信息保留:卷积特征图保留了物体的位置信息,而全连接层会丢失这些信息

  3. 高效利用数据:即使只有几千张图像,也能获得很好的效果

实验结果

通过这种方法,我们可以在小型猫狗数据集上达到约98.5%的测试精度,这在该任务的原始Kaggle竞赛中属于最佳结果之一。

实用技巧

  1. 选择解冻层数:通常只微调最后2-3层,因为:

    • 底层特征通用性更强
    • 顶层特征更专业化
    • 减少过拟合风险
  2. 学习率设置:微调时使用较小的学习率(如1e-5)

  3. 数据增强:使用数据增强来防止过拟合

完整代码示例

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

# 数据准备
train_dataset = image_dataset_from_directory(
"../data/train",
    image_size=(180180),
    batch_size=32)

# 构建模型
base_model = keras.applications.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(1801803))
base_model.trainable = False

inputs = keras.Input(shape=(1801803))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

# 训练和微调
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(train_dataset, epochs=20)

# 微调
base_model.trainable = True
for layer in base_model.layers[:-4]:
    layer.trainable = False

model.compile(optimizer=keras.optimizers.Adam(1e-5), 
              loss="binary_crossentropy", metrics=["accuracy"])
model.fit(train_dataset, epochs=10)

总结

使用预训练模型是处理小型图像数据集的强大技术。通过特征提取和微调,我们可以在有限的数据上获得优秀的性能,这在现实世界的应用中非常有价值。

无论你是初学者还是有经验的开发者,掌握这项技术都将为你的计算机视觉项目带来显著提升!


注意:实际应用中请确保有合适的数据集,并根据具体任务调整模型参数。


【声明】内容源于网络
0
0
知识代码AI
技术基底 机器视觉全栈 × 光学成像 × 图像处理算法 编程栈 C++/C#工业开发 | Python智能建模 工具链 Halcon/VisionPro工业部署 | PyTorch/TensorFlow模型炼金术 | 模型压缩&嵌入式移植
内容 366
粉丝 0
知识代码AI 技术基底 机器视觉全栈 × 光学成像 × 图像处理算法 编程栈 C++/C#工业开发 | Python智能建模 工具链 Halcon/VisionPro工业部署 | PyTorch/TensorFlow模型炼金术 | 模型压缩&嵌入式移植
总阅读83
粉丝0
内容366