大数跨境
0
0

DeepResearch查询优化——GRPO强化学习增强Query生成

DeepResearch查询优化——GRPO强化学习增强Query生成 InfraLink
2025-11-04
0
导读:DeepResearch爆火,各家都推荐出自己DeepResearch方案,有基于Workflow的方案,有

DeepResearch爆火,各家都推荐出自己DeepResearch方案,有基于Workflow的方案,有基于Agent的方案, 但是基础的DeepResearch的一般都包含这4个关键阶段:

  1. 查询生成(Generate Query)
  2. 网络搜索(Web search)
  3. 信息校验反思(Reflection)
  4. 结果生成(Answer Generation)

其中查询生成阶段作为第一步非常重要,既需要合理地生成多样性的查询,又需要扩展的查询不要产生语义偏离,生成一些干扰信息。而这些要求一般都写在prompt里面,但是模型指令遵循比较弱的时候,经常出现query生成的结果非常不好,比如:

“Qwen3是如何实现混合推理的” 被改写成 “混合推理技术在知识图谱中的应用实例”

今天笔者就利用Qwen0.6B,采用强化学习的方案去优化一下查询生成(Generate Query)这个阶段的效果,主要目标是引导模型生成 在语义上与原始查询相似但在字面表达上不同的查询。

GRPO算法与奖励设计:

GRPO原理

如下图所示, 在传统的近端策略优化算法(PPO)中,通常需要同时训练策略模型和价值模型,后者用于估计每个状态的期望回报,并以此作为优势函数的基线。对于大型语言模型来说,训练与策略模型规模相当的价值网络不仅增加了计算量,还会带来显著的内存开销。为了解决这一问题,GRPO 提出了利用“组内”生成数据的思路:

  • 多样本生成: 对于每个输入(例如一个问题),模型根据旧策略生成多个候选输出。
  • 奖励评估: 对每个候选输出采用特定的奖励函数进行评估,奖励可以包括答案正确性、格式符合要求、推理过程合理等指标(例如 DeepSeek 系列中常用的准确性奖励和格式奖励)
  • 组内优势计算: 将这组输出的奖励视为一个样本集,直接计算其均值和标准差,并将每个输出的奖励进行标准化(即减去均值、除以标准差),从而获得组内相对优势。这种方式能够反映出同一问题下各个候选答案的“相对好坏”,而不需要单独训练一个价值模型。

其优势省去价值网络,占用资源少,同时训练稳定性较PPO要高。GRPO有个好处可以采用Rule Based的奖励函数,只要设计好的奖励函数,这样模型就能够被奖励引导,能力变强。笔着就Query生成这个任务设计了下方3个奖励函数。

Query生成奖励函数设计

为了生成语义上与原始查询相似但在字面表达上不同query, 设计了三个奖励函数,分别从语义相似性、文本多样性和输出格式有效性多个维度对模型输出进行评估。

1️.语义相似&表达多样性奖励函数

衡量生成的 rewritten_query 是否在语义上贴近原始查询(query),但又在字面表达上具有差异性。

reward = cosine_similarity - jaccard_similarity
  • cosine_similarity: 使用 m3e-small 编码器将原始响应和模型生成的 rewritten_query 编码为向量,通过余弦相似度衡量语义接近度。
  • jaccard_similarity: 用于衡量重写查询和原始查询在词汇上的重合程度(即字面相似度)。
  • 目标是 提高语义相似度 但 降低字面相似度,因此用差值作为奖励。

2. 格式合规性检查(硬约束)

确保模型输出严格符合 JSON 格式,便于结构化解析和后续使用。只要能解析出一个 JSON 对象,就给 0.2 的奖励,否则为 0.0。

3. 格式合规性检查(软约束)

当模型还没有学会完全生成正确 JSON 时,给予部分奖励,鼓励其逐步向正确格式靠拢。 如果文本中包含 { +0.2, 如果包含 } 且在 { 之后 +0.2, 最多奖励 0.4

实战部分:

初始化模型

这里采用unsloth 框架进行训练,模型选择Qwen3-0.6B,采用Lora 进行少量参数微调。

from unsloth import FastLanguageModel, is_bfloat16_supported
import torch

# 设置最大序列长度,可根据需要增加以处理更长的推理链
max_seq_length = 1024
# LoRA 的秩,更大的秩意味着更“智能”的模型,但训练速度会变慢
lora_rank = 64

# 从预训练模型加载 FastLanguageModel 和对应的分词器
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/mnt/d/wsl/work/jupyter/model_hub/Qwen3-0.6B"# 指定模型名称
    max_seq_length = max_seq_length,      # 设置最大序列长度
    fast_inference = True, # use vLLM for fast inference!
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5, # Reduce if out of memory
)


# 获取 PEFT (Parameter-Efficient Fine-Tuning) 模型,并应用 LoRA 配置
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,                         # LoRA 的秩,建议选择 8, 16, 32, 64, 128 等值
    target_modules = [
        "q_proj""k_proj""v_proj""o_proj",
        "gate_proj""up_proj""down_proj",
    ], # 指定要应用 LoRA 的目标模块,如果内存不足可移除 QKVO
    lora_alpha = lora_rank,                # LoRA 的 alpha 参数,通常设置为与 r 相同
    use_gradient_checkpointing = "unsloth"# 启用梯度检查点以进行长上下文微调
    random_state = 3407,                   # 设置随机种子以保证结果可复现
)

加载数据

定义好SYSTEM_PROMPT ,TASK_PROMPT ,将数据格式处理成如下格式,

from datasets import load_dataset
dataset = load_dataset("parquet", data_files="/mnt/d/wsl/work/jupyter/data_hub/Chinese-QA-Agriculture/Chinese-QA-AFAF-train-v2.parquet")
split_dataset = dataset["train"].train_test_split(test_size=0.95, seed=42)
train_dataset = split_dataset['train']


SYSTEM_PROMPT = """请扮演一个搜索大师。"""


TASK_PROMPT =  """
你的任务是接收一个原始查询,并对其进行改写,以提升搜索引擎的理解和匹配效果。改写后的查询应与原始查询表达相近的意图,但在用词、结构或描述上应尽量有所不同,以提高搜索多样性和命中率。请以如下 JSON 格式输出,仅包含字段 "
rewritten_query"。

例如:
原始查询: "
如何训练一个大型语言模型"
改写后的查询: "
大型语言模型训练教程"

现在原始查询是:{}

请输出如下格式的 JSON 结果:/no_think
"
""


train_dataset_format = train_dataset.map(lambda x: {
    "prompt" : [
        {"role""system""content": SYSTEM_PROMPT},
        {"role""user",   "content":TASK_PROMPT.format(x["prompt"])},
    ],
    "query": x["prompt"],
})

数据格式:

{'id''u1aUk83rJUWb',
 'prompt': [
  {'content''请扮演一个搜索大师。''role''system'},
  {'content''\n你的任务是接收一个原始查询,并对其进行改写,以提升搜索引擎的理解和匹配效果。改写后的查询应与原始查询表达相近的意图,但在用词、结构或描述上应尽量有所不同,以提高搜索多样性和命中率。请以如下 JSON 格式输出,仅包含字段 "rewritten_query"。\n\n例如:\n原始查询: "如何训练一个大型语言模型"\n改写后的查询: "大型语言模型训练教程"\n\n现在原始查询是:多肉植物每天至少需要多少小时的光照?\n\n请输出如下格式的 JSON 结果:/no_think\n','role''user'}],
 'response''大多数多肉植物每天至少需要6到8小时的光照。',
 'query''多肉植物每天至少需要多少小时的光照?'
}

设计奖励函数

实现上方设计的奖励函数:

from sentence_transformers import SentenceTransformer
import json
import re
from sentence_transformers import util

# 辅助函数:从文本中提取 JSON 答案
def extract_json_answer(text: str) -> dict | None:
    try:
        # 查找 JSON 对象的起始和结束位置
        json_start = text.find('{')
        json_end = text.rfind('}')
        if json_start != -1 and json_end != -1:
            json_str = text[json_start : json_end + 1]
            return json.loads(json_str)
    except json.JSONDecodeError:
        pass
    return None

def _calculate_jaccard_similarity(s1: str, s2: str) -> float:
    """
    Calculates the Jaccard similarity between two strings based on their word sets.
    "
""
    # Simple tokenization by splitting on whitespace and converting to lowercase
    words1 = set(s1.lower().split())
    words2 = set(s2.lower().split())

    if not words1 and not words2:
        return 1.0  # Both empty, considered perfectly similar literally
    if not words1 or not words2:
        return 0.0  # One empty, one not, considered dissimilar literally

    intersection = len(words1.intersection(words2))
    union = len(words1.union(words2))
    return intersection / union

class QueryRewriteReward:
    def __init__(self): 
        # 用m3e embeding模型做相似度计算
        self.embedding_model = SentenceTransformer("/mnt/d/wsl/work/jupyter/model_hub/m3e-small")

    def calculate_similarity_reward(self, query, response, completions, **kwargs) -> list[float]:
        rewritten_querys = [completion[0]['content'for completion in completions]
        rewards = []
        answer_embedding = self.embedding_model.encode(response[0], convert_to_tensor=True)
        for rewritten_query in rewritten_querys:
            # Compute embeddings for rewritten query
            rewritten_query_embedding = self.embedding_model.encode(rewritten_query, convert_to_tensor=True)
            # Compute cosine similarity
            cosine_similarity = util.cos_sim(answer_embedding, rewritten_query_embedding).item()
            jaccard_similarity = _calculate_jaccard_similarity(rewritten_query, query[0])
            rewards.append(cosine_similarity-jaccard_similarity)
        
        return rewards

    def json_format_reward_func(self, completions, **kwargs) -> list[float]:
        responses = [completion[0]['content'for completion in completions]
        rewards = []
        for response in responses:
            json_data = extract_json_answer(response)
            if json_data:
                reward = 0.2  # 有效的 JSON 格式
            else:
                reward = 0.0  # 无效的 JSON 格式
            rewards.append(reward)
        return rewards

    def soft_json_format_reward_func(self, completions, **kwargs) -> list[float]:
        responses = [completion[0]['content'for completion in completions]
        rewards = []
        for response in responses:
            reward = 0.0
            # 使用正则表达式检查是否包含 JSON 格式的内容
            json_start = response.find('{')
            json_end = response.rfind('}')
            if json_start != -1:
                reward += 0.2  # 包含 JSON 开始符
            if json_end != -1 and json_end > json_start:
                reward += 0.2
            rewards.append(reward)
        return rewards




query_rewrite_reward_func = QueryRewriteReward().calculate_similarity_reward
json_format_reward_func = QueryRewriteReward().json_format_reward_func
soft_json_format_reward_func = QueryRewriteReward().soft_json_format_reward_func

模型训练

模型训练:奖励稳定上升,证明模型正在学习。

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, 
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 10,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = 512,
    max_completion_length = 200,
    num_train_epochs = 1, # Set to 1 for a full training run
    save_steps = 500,
    max_steps = 1000,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs_v1"
    # log_completions = True
)
new_train_dataset_format = train_dataset_format.train_test_split(test_size = 0.01)


trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        query_rewrite_reward_func,
        json_format_reward_func,
        soft_json_format_reward_func
    ],
    args = training_args,
    train_dataset = new_train_dataset_format["train"],
    eval_dataset = new_train_dataset_format["test"],
)
trainer.train()

奖励过程可视化结果:

模型推理

看看训练结束后模型的效果:确实重写的query具备多样性,且语义信息没有偏移。

messages = [
    {"role""system""content": SYSTEM_PROMPT},
    {"role""user",   "content":TASK_PROMPT.format("Qwen3是如何实现混合推理的")},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 256,
)
for i in range(3):
    output = model.fast_generate(
        text,
        sampling_params = sampling_params,
        #lora_request = model.load_lora("/mnt/d/wsl/work/jupyter/outputs_v1/checkpoint-1000"),
    )[0].outputs[0].text

    print(output)

效果:

####################################
<think>
</think>
json
{
"rewritten_query""Qwen3如何实现混合推理功能"
}
####################################
<think>
</think>
json
{
"rewritten_query""Qwen3混合推理技术实现方式"
}
####################################
<think>
</think>
{"rewritten_query""混合推理技术在知识图谱中的应用实例"}
####################################
 


关注我们!与InfraLink共赴智能未来



🔗 聚焦数据科学 | 深耕算法创新 | 赋能AI工程化

📌 技术干货持续更新,全球生态合作共建

✨ 点击关注@InfraLink,解锁更多前沿技术资讯与实践洞察


【声明】内容源于网络
0
0
InfraLink
链接技术基建,共筑智能未来。 在数据智能重塑产业格局的时代,InfraLink 以「构建技术基础设施的全球连接枢纽」为使命,聚焦 数据科学、算法创新、AI 工程化 三大核心领域,打造集技术资讯、实践经验、生态合作为一体的全球化社区平台。
内容 109
粉丝 0
InfraLink 链接技术基建,共筑智能未来。 在数据智能重塑产业格局的时代,InfraLink 以「构建技术基础设施的全球连接枢纽」为使命,聚焦 数据科学、算法创新、AI 工程化 三大核心领域,打造集技术资讯、实践经验、生态合作为一体的全球化社区平台。
总阅读18
粉丝0
内容109