
01
引言
人工智能(AI)已经在无数方面改变了世界,但我们是否已经达到了它的顶峰?人工智能是否已经到了极致,或者还有改进的余地?虽然人工智能的能力有时被过分夸大,但不可否认的是,它具有学习、进化、变得更高效、使用更安全的潜力。
在本文中,我们将深入探讨参数高效微调技术,这是一种突破性技术,它让人工智能变得更智能、更快速、更易用。到最后,大家将了解 PEFT 是什么、它为什么重要,以及它如何为每个人而不仅仅是大公司带来人工智能革命。
02
回顾:从RNN到Transformer
要了解 PEFT,让我们回到自然语言处理(NLP)的早期。早在 2016 年,循环神经网络(RNN)是翻译语言或预测句子中下一个单词等任务的黄金标准。RNN 按顺序处理单词,这使得在现代 GPU 上的训练速度慢得令人痛苦,而且效率低下。
进入Transformer神经网络时代。这一革命性架构于 2017 年推出,通过同时处理所有输入单词取代了 RNN。Transformer 使用注意力机制,将注意力集中在句子中最重要的部分,使其能够更快、更有效地进行训练。
如果把 RNN 想象成图书管理员逐字逐句地阅读一本书,找出关键信息。而Transformer就像图书管理员一次翻阅整本书,并立即找出重要章节。
Transformer在提高性能的同时,也带来了新的挑战:它们需要大量的数据和计算资源来进行训练。
03
为了解决这个问题,研究人员转向了迁移学习。其原理如下:
Pre-training阶段: 首先使用海量数据集对模型进行广泛、通用任务的训练(如预测句子中的下一个单词)。这样就建立了一个能够理解语言模式的基础模型。
-
Fine-tuning阶段: 使用较小的数据集,在特定任务(如QA任务)上进一步训练预训练模型。
这种由 BERT 和 GPT 等模型推广的方法大大减少了微调所需的数据。但有一个问题--微调这些大型模型中的每个参数需要消耗大量的时间、金钱和存储空间。
我们来看个例子,假设我们需要微调像 BERT 这样的模型。这个过程需要调整模型的全部 3.45 亿个参数。现在,如果要微调用于情感分析的同一模型,则需要再存储 3.45 亿个参数--每个任务一个参数集。再乘以几十个任务,存储需求就会变成天文数字。
04
05
-
预训练模型:从预先训练好的Transformer模型(如 BERT)开始。
适配器层:在每个Transformer层中插入小型神经网络模块(适配器)。
微调:只训练适配器层,同时冻结原始模型参数。
重复使用和替换:对于每项新任务,在不改变基础模型的情况下更换适配器层。
假设大家正在对 BERT 进行微调,我们来对比下:
传统方法将更新所有 3.45 亿个参数,需要大量存储空间。
使用 PEFT 时,只需训练和存储适配器层(仅占模型总参数的 1.8%)。
存储效率:PEFT 无需为每个任务存储数以亿计的参数,而是将其减少到几百万个。
成本效益:对小型机构和研究人员而言,微调大型语言模型变得可行。
性能保持:尽管可训练参数大幅减少,但 PEFT 的性能仍可媲美传统的微调技术。
具体而言,考虑到 BERT-Large 有 3.45 亿个参数。使用 PEFT,每个任务只需训练约 634 万个参数。在不影响准确性的前提下,存储需求减少了 98.2%。
06
使用PEFT和LORA微调LLM
我们将探索使用 LoRA(Low-Rank Adaptation)对 bigscience/mt0-large 模型进行微调,LoRA 是 PEFT 库中的一种参数高效微调技术。
首先导入必要的库:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizerfrom peft import get_peft_model, LoraConfig, TaskType# Define model and tokenizer pathsmodel_name_or_path = "bigscience/mt0-large"tokenizer_name_or_path = "bigscience/mt0-large"
设置 PEFT 方法的配置,指定任务类型和超参数。
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, # Task type: Sequence-to-sequence language modelinginference_mode=False, # Ensure we're in training moder=8, # Rank parameter for LoRAlora_alpha=32, # LoRA scaling factorlora_dropout=0.1 # Dropout rate for LoRA layers)
# Load the base modelmodel = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)# Apply PEFT using the LoRA configurationmodel = get_peft_model(model, peft_config)# Print the number of trainable parametersmodel.print_trainable_parameters()# Output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282
训练过程与任何标准训练循环相同。有关完整示例,请参阅以下链接:
https://github.com/shobhitag11/llm-finetune/blob/main/fine-tuning-with-LoRa.ipynb
训练完成后,保存模型。PEFT 框架确保只保存增量 PEFT 权重(适配器的权重)
# Save the trained adapter weightsmodel.save_pretrained("output_dir")# Optionally, push the model to the Hugging Face Hubmodel.push_to_hub("my_awesome_peft_model")
保存的权重只包括两个文件:
-
adapter_config.json -
adapter_model.bin(通常是一个小文件,如经过 LoRA 调整的 T0_3B,19MB)。
07
模型推理
要加载微调模型进行推理,请按照以下步骤操作:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizerfrom peft import PeftModel, PeftConfig# Define the adapter model IDpeft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"# Load the PEFT configurationconfig = PeftConfig.from_pretrained(peft_model_id)# Load the base model and wrap it with the PEFT adaptermodel = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)model = PeftModel.from_pretrained(model, peft_model_id)# Load the tokenizertokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)# Move the model to the desired device (e.g., GPU)model = model.to("cuda")model.eval()
现在,大家可以使用微调后的模型进行推理。下面是一个推理示例:
import torch# Prepare input textinputs = tokenizer("Tweet text: @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label:",return_tensors="pt")# Generate the outputwith torch.no_grad():outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)result = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]print(result) # Output: 'complaint'
08
参数高效微调不仅是一项技术进步,也是人工智能领域的一股先进力量。通过降低微调的成本和复杂性,PEFT 为各行各业的创新开辟了新的可能性,并为更广泛深入的研究提供了能力。
点击上方小卡片关注我
添加个人微信,进专属粉丝群!


