大数跨境

Keras Hub,您的一站式预训练模型库

Keras Hub,您的一站式预训练模型库 谷歌开发者
2024-12-18
13
导读:欢迎您即刻体验 Keras,释放统一、便利且高效的深度学习模型的力量。AI 的未来是多模态的,也期待您借助 KerasHub 加速解锁多模态 AI 的潜力!

KerasHub发布:统一NLP与CV的预训练模型库

作者 / 软件工程师 Divyashree Sreepathihalli 和 Google AI 开发技术推广工程师 Luciano Martins[k]

深度学习快速发展,预训练模型在各类任务中日益关键。Keras凭借其用户友好的API和易用性,始终走在前沿,并拥有KerasNLP(文本)和KerasCV(视觉)等专用库[k]

然而,随着多模态模型兴起——如支持图像输入的LLM或融合文本编码器的视觉系统——维持独立的NLP与CV库已不切实际,易导致开发冗余与生态碎片化[k]

统一的开发者体验

为此,Keras生态系统迎来重大升级:正式推出KerasHub,一个统一的预训练模型中心库,旨在简化对前沿NLP与CV架构的访问。KerasHub集成BERT、EfficientNet等模型,用户可在稳定统一的Keras框架内无缝探索与部署[k]

该平台支持模型发布共享、LoRA微调、量化优化及多主机大规模训练,全面覆盖文本、图像等多模态需求,显著降低开发门槛,加速多模态AI应用落地[k]

KerasHub入门步骤

安装KerasHub

安装KerasHub最新版(支持Keras 3):

$ pip install --upgrade keras-hub

配置Keras 3运行环境(以JAX为例):

import os
# 定义Keras 3后端:"jax"、"tensorflow"或"torch"
os.environ["KERAS_BACKEND"] = "jax"

# 导入Keras 3和KerasHub模块
import keras
import keras_hub

通过KerasHub使用计算机视觉和自然语言模型

Gemma

Gemma是Google推出的先进开源模型系列,基于与Gemini相同的技术,在问答、摘要、推理等文本任务中表现优异,支持定制化微调[k]

使用KerasHub加载并生成内容(以Gemma 2B为例):

# 加载Gemma 2B预训练模型
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

# 生成文本
gemma_lm.generate("Keras is a", max_length=32)

PaliGemma

PaliGemma是一款紧凑型多模态开源模型,融合SigLIP视觉模型与Gemma语言模型,可理解图像并回答相关问题,适用于图像描述生成、目标识别及文本理解等场景[k]

使用示例(加载模型并提问图像内容):

import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array

# 加载PaliGemma 3B模型(224x224微调版本)
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset("pali_gemma_3b_mix_224")

# 加载测试图像
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(img_path))

# 构建提示并生成回答
prompt = 'answer where is the cow standing?'
output = pali_gemma_lm.generate(
    inputs={
        "images": img,
        "prompts": prompt,
    }
)

Stability.ai Stable Diffusion 3

KerasHub同样支持主流视觉生成模型,如Stability.ai的Stable Diffusion 3,可用于高质量文生图任务[k]

使用示例:

from PIL import Image
from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage

text_to_image = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=1024,
    width=1024,
    dtype="float16",
)

# 生成图像
image = text_to_image.generate("photograph of an astronaut riding a horse, detailed, 8k")

# 显示结果
img = array_to_img(image)
img

对KerasNLP开发者的影响

从KerasNLP迁移至KerasHub极为简便,仅需将导入语句中的keras_nlp替换为keras_hub[k]

示例:原使用KerasNLP加载BERT分类器

import keras_nlp

# 加载BERT模型
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

现仅需更改导入模块:

import keras_hub

classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

KerasHub 更新:为 KerasCV 开发者带来统一与便利

KerasHub 正在整合 KerasCV 模型,提供更灵活、高效的深度学习体验

对于 KerasCV 开发者而言,有哪些变化?

如果您是 KerasCV 用户,迁移到 KerasHub 将带来以下优势:

  • 简化模型加载:KerasHub 提供统一 API,便于同时使用 KerasCV 与 KerasNLP 的开发者整合模型。
  • 框架灵活性:支持 JAX、PyTorch 等不同框架的集成,提升跨框架开发便利性。
  • 集中式存储库:所有模型集中管理,便于查找、使用,并支持未来扩展新架构。

如何适配 KerasHub?

模型迁移说明

目前 KerasCV 模型正逐步迁移至 KerasHub。大多数视觉模型已可在 KerasHub 中使用,但 Centerpillar 模型不会被迁移。调用方式如下:

import keras_hub
# Load a model using presetModel = keras_hub.models.<model_name>.from_preset('preset_name')
# or load a custom model by specifying the backbone and preprocessorModel = keras_hub.models.<model_name>(backbone=backbone, preprocessor=preprocessor)

内置预处理功能

KerasHub 的任务模型现已集成预处理器,自动完成图像调整大小、重新缩放等操作,简化输入预处理流程。预处理器作为模型内在组件,也可被自定义替换。

迁移前需手动预处理输入:

# Preprocess inputs for exampledef preprocess_inputs(image, label):    # Resize rescale or do more preprocessing on inputs    return preprocessed_inputsbackbone = keras_cv.models.ResNet50V2Backbone.from_preset(    "resnet50_v2_imagenet",)model = keras_cv.models.ImageClassifier(    backbone=backbone,    num_classes=4,)output = model(preprocessed_input)

迁移后可直接使用预设模型,无需额外预处理:

classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')classifier.predict(inputs)

损失函数调整

原 KerasCV 中的损失函数(如 FocalLoss)现已移至 Keras 核心库。开发者需将导入路径从 keras_cv.losses 更新为 keras.losses

旧写法:

import kerasimport keras_cv
keras_cv.losses.FocalLoss( alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs)

新写法:

import keras
keras.losses.FocalLoss( alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs)

开始使用 KerasHub

立即体验 KerasHub 的统一模型生态:

  • 查看官方文档:https://keras.io/keras_hub/
  • 获取入门指南:https://keras.io/guides/keras_hub/
  • 试用预训练模型:https://keras.io/api/keras_hub/models/
  • 探索源码并参与贡献:https://github.com/keras-team/keras-hub/
  • 在 Kaggle 上实践 Keras 示例:https://www.kaggle.com/organizations/keras

KerasHub 为开发者提供更高效、统一的深度学习工具链,助力多模态 AI 时代的技术创新[k]

【声明】内容源于网络
0
0
谷歌开发者
谷歌开发
内容 3529
粉丝 0
谷歌开发者 谷歌开发
总阅读15.9k
粉丝0
内容3.5k