DeepResearch爆火,各家都推荐出自己DeepResearch方案,有基于Workflow的方案,有基于Agent的方案, 但是基础的DeepResearch的一般都包含这4个关键阶段:
-
查询生成(Generate Query) -
网络搜索(Web search) -
信息校验反思(Reflection) -
结果生成(Answer Generation)
其中查询生成阶段作为第一步非常重要,既需要合理地生成多样性的查询,又需要扩展的查询不要产生语义偏离,生成一些干扰信息。而这些要求一般都写在prompt里面,但是模型指令遵循比较弱的时候,经常出现query生成的结果非常不好,比如:
❝“Qwen3是如何实现混合推理的” 被改写成 “混合推理技术在知识图谱中的应用实例”
今天笔者就利用Qwen0.6B,采用强化学习的方案去优化一下查询生成(Generate Query)这个阶段的效果,主要目标是引导模型生成 在语义上与原始查询相似但在字面表达上不同的查询。
GRPO算法与奖励设计:
GRPO原理
如下图所示, 在传统的近端策略优化算法(PPO)中,通常需要同时训练策略模型和价值模型,后者用于估计每个状态的期望回报,并以此作为优势函数的基线。对于大型语言模型来说,训练与策略模型规模相当的价值网络不仅增加了计算量,还会带来显著的内存开销。为了解决这一问题,GRPO 提出了利用“组内”生成数据的思路:
-
多样本生成: 对于每个输入(例如一个问题),模型根据旧策略生成多个候选输出。 -
奖励评估: 对每个候选输出采用特定的奖励函数进行评估,奖励可以包括答案正确性、格式符合要求、推理过程合理等指标(例如 DeepSeek 系列中常用的准确性奖励和格式奖励) -
组内优势计算: 将这组输出的奖励视为一个样本集,直接计算其均值和标准差,并将每个输出的奖励进行标准化(即减去均值、除以标准差),从而获得组内相对优势。这种方式能够反映出同一问题下各个候选答案的“相对好坏”,而不需要单独训练一个价值模型。
其优势省去价值网络,占用资源少,同时训练稳定性较PPO要高。GRPO有个好处可以采用Rule Based的奖励函数,只要设计好的奖励函数,这样模型就能够被奖励引导,能力变强。笔着就Query生成这个任务设计了下方3个奖励函数。
Query生成奖励函数设计
为了生成语义上与原始查询相似但在字面表达上不同query, 设计了三个奖励函数,分别从语义相似性、文本多样性和输出格式有效性多个维度对模型输出进行评估。
1️.语义相似&表达多样性奖励函数
衡量生成的 rewritten_query 是否在语义上贴近原始查询(query),但又在字面表达上具有差异性。
reward = cosine_similarity - jaccard_similarity
-
cosine_similarity: 使用 m3e-small 编码器将原始响应和模型生成的 rewritten_query 编码为向量,通过余弦相似度衡量语义接近度。 -
jaccard_similarity: 用于衡量重写查询和原始查询在词汇上的重合程度(即字面相似度)。 -
目标是 提高语义相似度 但 降低字面相似度,因此用差值作为奖励。
2. 格式合规性检查(硬约束)
确保模型输出严格符合 JSON 格式,便于结构化解析和后续使用。只要能解析出一个 JSON 对象,就给 0.2 的奖励,否则为 0.0。
3. 格式合规性检查(软约束)
当模型还没有学会完全生成正确 JSON 时,给予部分奖励,鼓励其逐步向正确格式靠拢。 如果文本中包含 { +0.2, 如果包含 } 且在 { 之后 +0.2, 最多奖励 0.4
实战部分:
初始化模型
这里采用unsloth 框架进行训练,模型选择Qwen3-0.6B,采用Lora 进行少量参数微调。
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
# 设置最大序列长度,可根据需要增加以处理更长的推理链
max_seq_length = 1024
# LoRA 的秩,更大的秩意味着更“智能”的模型,但训练速度会变慢
lora_rank = 64
# 从预训练模型加载 FastLanguageModel 和对应的分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "/mnt/d/wsl/work/jupyter/model_hub/Qwen3-0.6B", # 指定模型名称
max_seq_length = max_seq_length, # 设置最大序列长度
fast_inference = True, # use vLLM for fast inference!
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.5, # Reduce if out of memory
)
# 获取 PEFT (Parameter-Efficient Fine-Tuning) 模型,并应用 LoRA 配置
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank, # LoRA 的秩,建议选择 8, 16, 32, 64, 128 等值
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
], # 指定要应用 LoRA 的目标模块,如果内存不足可移除 QKVO
lora_alpha = lora_rank, # LoRA 的 alpha 参数,通常设置为与 r 相同
use_gradient_checkpointing = "unsloth", # 启用梯度检查点以进行长上下文微调
random_state = 3407, # 设置随机种子以保证结果可复现
)
加载数据
定义好SYSTEM_PROMPT ,TASK_PROMPT ,将数据格式处理成如下格式,
from datasets import load_dataset
dataset = load_dataset("parquet", data_files="/mnt/d/wsl/work/jupyter/data_hub/Chinese-QA-Agriculture/Chinese-QA-AFAF-train-v2.parquet")
split_dataset = dataset["train"].train_test_split(test_size=0.95, seed=42)
train_dataset = split_dataset['train']
SYSTEM_PROMPT = """请扮演一个搜索大师。"""
TASK_PROMPT = """
你的任务是接收一个原始查询,并对其进行改写,以提升搜索引擎的理解和匹配效果。改写后的查询应与原始查询表达相近的意图,但在用词、结构或描述上应尽量有所不同,以提高搜索多样性和命中率。请以如下 JSON 格式输出,仅包含字段 "rewritten_query"。
例如:
原始查询: "如何训练一个大型语言模型"
改写后的查询: "大型语言模型训练教程"
现在原始查询是:{}
请输出如下格式的 JSON 结果:/no_think
"""
train_dataset_format = train_dataset.map(lambda x: {
"prompt" : [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content":TASK_PROMPT.format(x["prompt"])},
],
"query": x["prompt"],
})
数据格式:
{'id': 'u1aUk83rJUWb',
'prompt': [
{'content': '请扮演一个搜索大师。', 'role': 'system'},
{'content': '\n你的任务是接收一个原始查询,并对其进行改写,以提升搜索引擎的理解和匹配效果。改写后的查询应与原始查询表达相近的意图,但在用词、结构或描述上应尽量有所不同,以提高搜索多样性和命中率。请以如下 JSON 格式输出,仅包含字段 "rewritten_query"。\n\n例如:\n原始查询: "如何训练一个大型语言模型"\n改写后的查询: "大型语言模型训练教程"\n\n现在原始查询是:多肉植物每天至少需要多少小时的光照?\n\n请输出如下格式的 JSON 结果:/no_think\n','role': 'user'}],
'response': '大多数多肉植物每天至少需要6到8小时的光照。',
'query': '多肉植物每天至少需要多少小时的光照?'
}
设计奖励函数
实现上方设计的奖励函数:
from sentence_transformers import SentenceTransformer
import json
import re
from sentence_transformers import util
# 辅助函数:从文本中提取 JSON 答案
def extract_json_answer(text: str) -> dict | None:
try:
# 查找 JSON 对象的起始和结束位置
json_start = text.find('{')
json_end = text.rfind('}')
if json_start != -1 and json_end != -1:
json_str = text[json_start : json_end + 1]
return json.loads(json_str)
except json.JSONDecodeError:
pass
return None
def _calculate_jaccard_similarity(s1: str, s2: str) -> float:
"""
Calculates the Jaccard similarity between two strings based on their word sets.
"""
# Simple tokenization by splitting on whitespace and converting to lowercase
words1 = set(s1.lower().split())
words2 = set(s2.lower().split())
if not words1 and not words2:
return 1.0 # Both empty, considered perfectly similar literally
if not words1 or not words2:
return 0.0 # One empty, one not, considered dissimilar literally
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
return intersection / union
class QueryRewriteReward:
def __init__(self):
# 用m3e embeding模型做相似度计算
self.embedding_model = SentenceTransformer("/mnt/d/wsl/work/jupyter/model_hub/m3e-small")
def calculate_similarity_reward(self, query, response, completions, **kwargs) -> list[float]:
rewritten_querys = [completion[0]['content'] for completion in completions]
rewards = []
answer_embedding = self.embedding_model.encode(response[0], convert_to_tensor=True)
for rewritten_query in rewritten_querys:
# Compute embeddings for rewritten query
rewritten_query_embedding = self.embedding_model.encode(rewritten_query, convert_to_tensor=True)
# Compute cosine similarity
cosine_similarity = util.cos_sim(answer_embedding, rewritten_query_embedding).item()
jaccard_similarity = _calculate_jaccard_similarity(rewritten_query, query[0])
rewards.append(cosine_similarity-jaccard_similarity)
return rewards
def json_format_reward_func(self, completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
rewards = []
for response in responses:
json_data = extract_json_answer(response)
if json_data:
reward = 0.2 # 有效的 JSON 格式
else:
reward = 0.0 # 无效的 JSON 格式
rewards.append(reward)
return rewards
def soft_json_format_reward_func(self, completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
rewards = []
for response in responses:
reward = 0.0
# 使用正则表达式检查是否包含 JSON 格式的内容
json_start = response.find('{')
json_end = response.rfind('}')
if json_start != -1:
reward += 0.2 # 包含 JSON 开始符
if json_end != -1 and json_end > json_start:
reward += 0.2
rewards.append(reward)
return rewards
query_rewrite_reward_func = QueryRewriteReward().calculate_similarity_reward
json_format_reward_func = QueryRewriteReward().json_format_reward_func
soft_json_format_reward_func = QueryRewriteReward().soft_json_format_reward_func
模型训练
模型训练:奖励稳定上升,证明模型正在学习。
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
use_vllm = True,
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_8bit",
logging_steps = 10,
per_device_train_batch_size = 1,
gradient_accumulation_steps = 1, # Increase to 4 for smoother training
num_generations = 4, # Decrease if out of memory
max_prompt_length = 512,
max_completion_length = 200,
num_train_epochs = 1, # Set to 1 for a full training run
save_steps = 500,
max_steps = 1000,
max_grad_norm = 0.1,
report_to = "none",
output_dir = "outputs_v1"
# log_completions = True
)
new_train_dataset_format = train_dataset_format.train_test_split(test_size = 0.01)
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
query_rewrite_reward_func,
json_format_reward_func,
soft_json_format_reward_func
],
args = training_args,
train_dataset = new_train_dataset_format["train"],
eval_dataset = new_train_dataset_format["test"],
)
trainer.train()
奖励过程可视化结果:
模型推理
看看训练结束后模型的效果:确实重写的query具备多样性,且语义信息没有偏移。
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content":TASK_PROMPT.format("Qwen3是如何实现混合推理的")},
]
text = tokenizer.apply_chat_template(
messages,
add_generation_prompt = True, # Must add for generation
tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 1.0,
top_k = 50,
max_tokens = 256,
)
for i in range(3):
output = model.fast_generate(
text,
sampling_params = sampling_params,
#lora_request = model.load_lora("/mnt/d/wsl/work/jupyter/outputs_v1/checkpoint-1000"),
)[0].outputs[0].text
print(output)
效果:
####################################
<think>
</think>
json
{
"rewritten_query": "Qwen3如何实现混合推理功能"
}
####################################
<think>
</think>
json
{
"rewritten_query": "Qwen3混合推理技术实现方式"
}
####################################
<think>
</think>
{"rewritten_query": "混合推理技术在知识图谱中的应用实例"}
####################################
关注我们!与InfraLink共赴智能未来
🔗 聚焦数据科学 | 深耕算法创新 | 赋能AI工程化
📌 技术干货持续更新,全球生态合作共建
✨ 点击关注@InfraLink,解锁更多前沿技术资讯与实践洞察

