Multimodal Large Model Reinforcement Learning Training Framework: EasyR1 Code Walkthrough (GRPO)
2025-07-16 08:00:00
Author | 404dreamer   Editor | 大模型之心Tech
Original link: https://zhuanlan.zhihu.com/p/1921972845669492319
Preface
The author has recently been preparing some pilot experiments with GRPO on multimodal large models. For the cold-start stage, LLaMA-Factory can be used to fine-tune Qwen2.5-VL directly, but for the RL stage an open-source training framework for multimodal reinforcement learning was needed.
Since no code walkthrough of the multimodal RL training framework EasyR1 could be found online, the author read through the code starting from the launch entry point.
Because the EasyR1 code is still being updated frequently, this article roughly corresponds to the code version around June 10:
https://github.com/hiyouga/EasyR1/tree/bae1fa073555d7b65a9d6003ecf70d5bcb4feab5
EasyR1 repository:
https://github.com/hiyouga/EasyR1
EasyR1 is adapted from verl, a text-only (pure-language) RL training framework. verl:
https://github.com/volcengine/verl
Hopefully this not-overly-detailed code walkthrough is helpful to researchers. If there are any mistakes in understanding or wording, feedback is welcome.
1 Launch Script and Configuration Files
1.1 Launch Script
The launch scripts are located in the examples directory. We take examples/qwen2_5_7b_math_grpo.sh as an example:
#!/bin/bash
set -x
export PYTHONUNBUFFERED=1
MODEL_PATH=Qwen/Qwen2.5-7B-Instruct # replace it with your local file path
python3 -m verl.trainer.main \
config=examples/config.yaml \
worker.actor.model.model_path=${MODEL_PATH}
The training job is launched via verl/trainer/main.py. The config file path must be specified. Other parameters in the config file, such as worker.actor.model.model_path, can be overridden by passing them as arguments in the launch script.
1.2 Configuration Files
The configuration file is located at examples/config.yaml. The configuration is divided into 4 categories: data, algorithm, worker, trainer.
The corresponding Python dataclasses are defined in verl/trainer/config.py.
Since EasyR1 is built on top of verl, verl's config documentation can also be used as a reference:
https://verl.readthedocs.io/en/latest/examples/config.html#ppo-trainer-yaml-for-rl-fsdp-backend
1.2.1 data-related configuration
1.2.1.1 Default values
verl/trainer/config.py -> DataConfig
@dataclass
class DataConfig:
train_files: str = ""
val_files: str = ""
prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
image_dir: Optional[str] = None
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
mini_rollout_batch_size: Optional[int] = None
val_batch_size: int = -1
format_prompt: Optional[str] = None
override_chat_template: Optional[str] = None
shuffle: bool = True
seed: int = 1
min_pixels: Optional[int] = 262144
max_pixels: Optional[int] = 4194304
filter_overlong_prompts: bool = True
1.2.1.2 config file
examples/config.yaml
data:
train_files: hiyouga/math12k@train
val_files: hiyouga/math12k@test
prompt_key: problem
answer_key: answer
image_key: images
image_dir: null
max_prompt_length: 2048
max_response_length: 2048
rollout_batch_size: 512 # equivalent to verl's data.train_batch_size
mini_rollout_batch_size: null # equivalent to verl's data.gen_batch_size
val_batch_size: 1024
format_prompt: ./examples/format_prompt/math_format.jinja
override_chat_template: null
shuffle: true
seed: 1
min_pixels: 262144
max_pixels: 4194304
filter_overlong_prompts: true
1.2.1.3 Interpretation of data-related configuration parameters
1.2.2 algorithm-related configuration
1.2.2.1 Default values
verl/trainer/config.py -> AlgorithmConfig
@dataclass
class AlgorithmConfig:
gamma: float = 1.0
"""discount factor for ppo gae advantage estimator"""
lam: float = 1.0
"""lambda value for ppo gae advantage estimator"""
adv_estimator: str = "grpo"
"""advantage estimator, support `gae`, `grpo`, `reinforce_plus_plus`, `remax`, `rloo`"""
disable_kl: bool = False
"""disable reference model"""
use_kl_loss: bool = False
"""use kl loss instead of kl in reward"""
kl_penalty: str = "kl"
"""kl penalty type, support `kl`, `abs`, `mse`, `low_var_kl`, `full`"""
kl_coef: float = 1e-3
"""kl coefficient"""
kl_type: str = "fixed"
"""kl controller type, support `fixed`, `adaptive`"""
kl_horizon: float = 10000.0
"""kl horizon for adaptive kl controller"""
kl_target: float = 0.1
"""target kl for adaptive kl controller"""
1.2.2.2 config file
algorithm:
adv_estimator: grpo
disable_kl: false
use_kl_loss: true
kl_penalty: low_var_kl
kl_coef: 1.0e-2
Many multimodal RL works do not use a KL-divergence constraint on model updates. Consider setting disable_kl to true in your actual training runs.
1.2.3 worker-related configuration
Config location: verl/workers/config.py
The worker configuration is further divided into 5 sub-configs: actor, critic, ref, reward, rollout.
1.2.3.1 actor-related configuration
Location: verl/workers/actor/config.py
@dataclass
class ActorConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = 4
micro_batch_size_per_device_for_experience: int = 16
max_grad_norm: float = 1.0
clip_ratio_low: float = 0.2
"""clip ratio in PPO & DAPO"""
clip_ratio_high: float = 0.3
"""clip ratio in PPO & DAPO"""
clip_ratio_dual: float = 3.0
"""constant C in dual-clip PPO, clips when advantage < -C"""
loss_avg_mode: str = "token"
"""loss average mode: `token`, `seq`"""
ppo_epochs: int = 1
"""number of ppo epochs for each rollout batch"""
padding_free: bool = True
"""use padding-free training"""
ulysses_size: int = 1
"""ulysses sequence parallel size"""
use_torch_compile: bool = True
model: ModelConfig = field(default_factory=ModelConfig)
optim: OptimConfig = field(default_factory=OptimConfig)
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
global_batch_size_per_device: int = field(default=-1, init=False)
disable_kl: bool = field(default=False, init=False)
use_kl_loss: bool = field(default=False, init=False)
kl_penalty: str = field(default="kl", init=False)
kl_coef: float = field(default=0.0, init=False)
In the same directory there are further parameters related to the actor model and optimizer (ModelConfig, OptimConfig, FSDPConfig, OffloadConfig) that can be consulted.
config example:
actor:
global_batch_size: 128 # equivalent to verl's actor.ppo_mini_batch_size
micro_batch_size_per_device_for_update: 4 # equivalent to verl's actor.ppo_micro_batch_size_per_gpu
micro_batch_size_per_device_for_experience: 16 # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu
max_grad_norm: 1.0
padding_free: true
ulysses_size: 1
model:
model_path: Qwen/Qwen2.5-7B-Instruct
enable_gradient_checkpointing: true
trust_remote_code: false
freeze_vision_tower: false
optim:
lr: 1.0e-6
weight_decay: 1.0e-2
strategy: adamw # {adamw, adamw_bf16}
lr_warmup_ratio: 0.0
fsdp:
enable_full_shard: true
enable_cpu_offload: false
enable_rank0_init: true
offload:
offload_params: true # true: more CPU memory; false: more GPU memory
offload_optimizer: true # true: more CPU memory; false: more GPU memory
1.2.3.2 critic-related configuration
Since the GRPO algorithm does not involve a value model, the critic parameters do not need to be configured. Location: verl/workers/critic/config.py
@dataclass
class CriticConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = 4
micro_batch_size_per_device_for_experience: int = 16
max_grad_norm: float = 1.0
cliprange_value: float = 0.5
"""clip range for value loss"""
loss_avg_mode: str = "token"
"""loss average mode: `token`, `seq`"""
ppo_epochs: int = 1
"""number of ppo epochs for each rollout batch"""
padding_free: bool = False
"""use padding-free training"""
ulysses_size: int = 1
"""ulysses sequence parallel size"""
model: ModelConfig = field(default_factory=ModelConfig)
optim: OptimConfig = field(default_factory=OptimConfig)
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    """auto keys"""
global_batch_size_per_device: int = field(default=-1, init=False)
1.2.3.3 ref-related configuration
Location: verl/workers/actor/config.py
@dataclass
class RefConfig:
strategy: str = "fsdp"
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    """auto keys"""
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
padding_free: bool = field(default=False, init=False)
ulysses_size: int = field(default=1, init=False)
use_torch_compile: bool = field(default=True, init=False)
1.2.3.4 reward-related configuration
config settings:
reward:
reward_type: batch
reward_function: ./examples/reward_function/math.py:compute_score
Location: verl/workers/reward/config.py
@dataclass
class RewardConfig:
reward_type: str = "batch"
reward_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
num_cpus: int = 1
"""auto keys"""
reward_function_name: Optional[str] = field(default=None, init=False)
def post_init(self):
if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main
if ":" not in self.reward_function:
self.reward_function_name = "main"
else:
self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)
if os.path.exists(self.reward_function): # ray job uses absolute path
self.reward_function = os.path.abspath(self.reward_function)
else:
print(f"Reward function {self.reward_function} not found.")
self.reward_function = None
In post_init, based on the passed-in reward_function: ./examples/reward_function/math.py:compute_score, self.reward_function is set to the path of the reward file and self.reward_function_name to the name of the corresponding reward function.
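For illustration, here is a minimal, hypothetical sketch of what such a batch-type custom reward function could look like. The argument name reward_inputs, the input keys "response" / "ground_truth", and the returned score keys are assumptions made for this sketch; the real schema should be checked against examples/reward_function/math.py.
import re
from typing import Any, Dict, List

def compute_score(reward_inputs: List[Dict[str, Any]]) -> List[Dict[str, float]]:
    scores = []
    for reward_input in reward_inputs:  # one dict per sampled response in the batch
        response = reward_input["response"]          # assumed key
        ground_truth = reward_input["ground_truth"]  # assumed key
        # format reward: the response should contain a <think> block and a \boxed{} answer
        format_ok = float(bool(re.search(r"<think>.*</think>.*\\boxed\{.*\}", response, re.DOTALL)))
        # accuracy reward: naive string match of the last boxed answer against the ground truth
        boxed = re.findall(r"\\boxed\{(.*?)\}", response)
        accuracy = float(bool(boxed) and boxed[-1].strip() == str(ground_truth).strip())
        scores.append({"overall": 0.9 * accuracy + 0.1 * format_ok, "format": format_ok, "accuracy": accuracy})
    return scores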
1.2.3.5 rollout-related config
config settings:
rollout:
  n: 5 # sample n responses in parallel for the same prompt
temperature: 1.0
top_p: 0.99
gpu_memory_utilization: 0.6
enforce_eager: false
enable_chunked_prefill: false
tensor_parallel_size: 2
limit_images: 0
val_override_config:
temperature: 0.5
n: 1
Location: verl/workers/rollout/config.py
@dataclass
class RolloutConfig:
name: str = "vllm"
n: int = 1
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
seed: int = 1
limit_images: int = 0
dtype: str = "bf16"
gpu_memory_utilization: float = 0.6
ignore_eos: bool = False
enforce_eager: bool = False
enable_chunked_prefill: bool = False # only for v0 engine
tensor_parallel_size: int = 2
max_model_len: Optional[int] = None
max_num_batched_tokens: int = 8192
disable_log_stats: bool = True
val_override_config: Dict[str, Any] = field(default_factory=dict)
"""auto keys"""
prompt_length: int = field(default=-1, init=False)
response_length: int = field(default=-1, init=False)
trust_remote_code: bool = field(default=False, init=False)
def to_dict(self):
return asdict(self)
2 Dataset
2.1 Dataset flow
The main program creates the dataloaders:
# verl/trainer/main.py
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)
dataset: verl/utils/dataset.py -> _build_messages, __getitem__
# dataset.py ->RLHFDataset._build_messages
prompt_str: str = example[self.prompt_key]
if self.format_prompt:
format_prompt = Template(self.format_prompt.strip())
prompt_str = format_prompt.render(content=prompt_str)
The template (./examples/format_prompt/math_format.jinja):
{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
This is a Jinja template: content is the variable rendered into the template; | pipes the value on its left into the filter on its right; trim is equivalent to Python's str.strip(), stripping leading/trailing whitespace (spaces, tabs, newlines). It looks like only single-turn QA is supported: the single-turn prompt is rendered into the Template.
# dataset.py ->RLHFDataset._build_messages
if self.image_key in example:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
if i != 0:
content_list.append({"type": "image"})
if content:
content_list.append({"type": "text", "text": content})
<image> is the placeholder for an image.
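A small runnable sketch (assuming jinja2 is installed; the template text is the math_format.jinja content shown above) of how the format prompt is rendered and then split on the <image> placeholder:
from jinja2 import Template

template_text = (
    "{{ content | trim }} You FIRST think about the reasoning process as an internal monologue "
    "and then provide the final answer. The reasoning process MUST BE enclosed within "
    "<think> </think> tags. The final answer MUST BE put in \\boxed{}."
)
prompt_str = Template(template_text).render(content="<image>How many apples are in the picture? ")

# mirror RLHFDataset._build_messages: split on <image> and build the content list
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
    if i != 0:
        content_list.append({"type": "image"})  # image placeholder, filled in by the processor later
    if content:
        content_list.append({"type": "text", "text": content})

print(content_list)
# [{'type': 'image'}, {'type': 'text', 'text': 'How many apples are in the picture? You FIRST think ...'}]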
2.2 Dataset format
if os.path.isdir(data_path):
    # when we use dataset builder, we should always refer to the train split
    file_type = os.path.splitext(os.listdir(data_path)[0])[-1][1:].replace("jsonl", "json")
    self.dataset = load_dataset(file_type, data_dir=data_path, split=data_split)
elif os.path.isfile(data_path):
    file_type = os.path.splitext(data_path)[-1][1:].replace("jsonl", "json")
    self.dataset = load_dataset(file_type, data_files=data_path, split=data_split)
else:
    # load remote dataset from huggingface hub
    self.dataset = load_dataset(data_path, split=data_split)
The dataset is loaded with the datasets.load_dataset function (https://huggingface.co/docs/datasets/loading), which supports csv, json, parquet, and other formats.
The data must contain three keys: problem, answer, images. Taking JSON as an example:
[
    {
        "problem": "Please describe the content of this image: <image>",
        "answer": "The image shows a golden retriever wearing sunglasses.",
        "images": ["dog.jpg"]
    },
    {
        "problem": "<image> <image>\nPlease compare the differences between these two photos.",
        "answer": "The left image is a snowy mountain; the right image is a desert.",
        "images": [
            "images/mountain.png",
            "images/desert.png"
        ],
        "sample_id": 42 // extra fields are also preserved
    }
]
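A quick sketch, assuming the example above is saved (without the // comment, which is not valid JSON) as a hypothetical local file my_vlm_rl_data.json, of how it would be loaded by the os.path.isfile branch shown earlier:
from datasets import load_dataset

# ".json" -> file_type = "json", matching the branch in RLHFDataset
dataset = load_dataset("json", data_files="my_vlm_rl_data.json", split="train")
print(dataset[0]["problem"], dataset[0]["answer"], dataset[0]["images"])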
3 Training workflow
Training entry point: verl/trainer/main.py
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
First init_workers, then fit.
Trainer: verl/trainer/ray_trainer.py
3.1 trainer.init
if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
else:
self.use_critic = False
When using GRPO, no value model is involved, hence self.use_critic = False.
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by actor global batch size.")
This ensures that one rollout (environment-sampling) batch divides evenly when split into mini-batches for the distributed actor update.
if (
config.data.rollout_batch_size * config.worker.rollout.n
) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
raise ValueError(
"Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
)
Parameters involved in these checks:
rollout_batch_size: how many prompts are drawn from the dataset into the environment at each training step (e.g., 128 prompts); it determines how much experience is collected per step.
rollout.n: how many responses are sampled per prompt; n = 1 corresponds to standard PPO / GAE, while n >= 2 is required for GRPO.
micro_batch_size_per_device_for_experience: how many experience samples each GPU on the actor side processes in a single forward pass (e.g., 32). Because GPU memory is limited, a large batch is usually split into several micro-batches.
Purpose of the second check: the sampled experience must also split evenly into micro-batches, so that an odd-sized final micro-batch does not trigger NCCL / FP16 errors. The sketch below plugs in the example config values from Section 1.2.
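Plugging in the example config values (rollout_batch_size=512, rollout.n=5, actor.global_batch_size=128, micro_batch_size_per_device_for_experience=16), both checks pass:
rollout_batch_size = 512                # data.rollout_batch_size
rollout_n = 5                           # worker.rollout.n
global_batch_size = 128                 # worker.actor.global_batch_size
micro_batch_size_for_experience = 16    # worker.actor.micro_batch_size_per_device_for_experience

# 512 prompts split into 512 / 128 = 4 actor mini-batches
assert rollout_batch_size % global_batch_size == 0
# 512 * 5 = 2560 experience samples split into 2560 / 16 = 160 micro-batches
assert (rollout_batch_size * rollout_n) % micro_batch_size_for_experience == 0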
3.2 trainer.fit
The GRPO algorithm does not involve a value model, but since the training framework is also compatible with other RL algorithms such as PPO, some steps reference the value model. We skip the value-model-related steps below.
3.2.1 batch data
# verl/trainer/ray_trainer.py ->RayPPOTrainer.fit
with timer("gen", timing_raw):
self.actor_rollout_ref_wg.prepare_rollout_engine()
batch = self._make_batch_data(metrics=metrics)
self.actor_rollout_ref_wg.release_rollout_engine()
# balance the number of valid tokens on each dp rank.
# NOTE: this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
self._balance_batch(batch, metrics=metrics)
A batch is obtained from the dataloader. This flow consists of two steps: _make_batch_data and _balance_batch.
3.2.1.1 _make_batch_data
The main body of _make_batch_data:
# verl/trainer/ray_trainer.py -> RayPPOTrainer._make_batch_data
def _make_batch_data(self, metrics: Dict[str, Any]) -> DataProto:
batch = None
while True:
try:
batch_dict = next(self.data_iterator)
except StopIteration:
self.data_iterator = iter(self.train_dataloader)
batch_dict = next(self.data_iterator)
meta_info = {"min_pixels": self.config.data.min_pixels, "max_pixels": self.config.data.max_pixels}
new_batch: DataProto = DataProto.from_single_dict(batch_dict, meta_info=meta_info)
# pop those keys for generation
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
meta_info_keys=["min_pixels", "max_pixels"],
)
# generate a batch
gen_batch_output = self.actor_rollout_ref_wg.generate_sequences(gen_batch)
new_batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
batch = DataProto.concat([batch, new_batch]) if batch is not None else new_batch
if len(batch) < self.config.data.rollout_batch_size * self.config.worker.rollout.n:
continue
else:
return batch[: self.config.data.rollout_batch_size * self.config.worker.rollout.n]
gen_batch_output = self.actor_rollout_ref_wg.generate_sequences(gen_batch): samples rollout.n responses for each prompt in one shot.
new_batch.non_tensor_batch["uid"] = ...: assigns a unique id to every sample, which is convenient for later logging or replay.
new_batch = new_batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True): at this point new_batch still only has one row of metadata per prompt; calling union() directly would mismatch the row counts and DataProto would raise a shape error. The batch therefore has to be repeated rollout.n times so the number of prompt rows aligns with the number of generated responses (a toy illustration follows below).
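To make the alignment concrete, here is a toy illustration in plain Python lists (not the actual DataProto API) of what repeat(repeat_times=n, interleave=True) does to the prompt-level rows when rollout.n = 3:
prompt_rows = ["prompt_0", "prompt_1"]    # prompt-level metadata, one row per prompt
n = 3                                     # rollout.n responses generated per prompt
repeated = [row for row in prompt_rows for _ in range(n)]
print(repeated)  # ['prompt_0', 'prompt_0', 'prompt_0', 'prompt_1', 'prompt_1', 'prompt_1']
# the 2 * 3 = 6 rows now line up one-to-one with the 6 generated responses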
self.actor_rollout_ref_wg.generate_sequences: response sampling
# verl/workers/fsdp_workers.py -> FSDPWorker.generate_sequences
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
assert self._has_rollout
meta_info = {
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
output = self.rollout_sharding_manager.postprocess_data(output)
output = output.to("cpu")
return output
Sampling responses consists of three steps: (1) preprocess the prompts; (2) sample responses for the prompts; (3) postprocess the responses. The pre- and post-processing are simple TP (tensor-parallel)-related operations and are not covered here.
Sampling responses for the prompts:
# verl/workers/rollout/vllm_rollout_spmd.py -> vLLMRollout.generate_sequences
@torch.no_grad()
def generate_sequences(self, prompts: DataProto) -> DataProto:
if self.rank == 0:
print("[Rollout] Start generating sequences.")
# left-padded attention_mask
input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length)
attention_mask: torch.Tensor = prompts.batch["attention_mask"]
position_ids: torch.Tensor = prompts.batch["position_ids"]
eos_token_id: int = prompts.meta_info["eos_token_id"]
batch_size = input_ids.size(0)
non_tensor_batch = prompts.non_tensor_batch
batch_raw_prompt_ids = non_tensor_batch.pop("raw_prompt_ids")
batch_multi_modal_data = non_tensor_batch.get("multi_modal_data") # do not pop it
if batch_size != len(batch_raw_prompt_ids):
raise RuntimeError("vllm sharding manager is not work properly.")
if batch_multi_modal_data is not None:
min_pixels, max_pixels = prompts.meta_info["min_pixels"], prompts.meta_info["max_pixels"]
vllm_inputs = []
for raw_prompt_ids, multi_modal_data in zip(batch_raw_prompt_ids, batch_multi_modal_data):
vllm_inputs.append(
{
"prompt_token_ids": list(raw_prompt_ids),
"multi_modal_data": _process_multi_modal_data(multi_modal_data, min_pixels, max_pixels),
}
)
else:
vllm_inputs = [
{"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in batch_raw_prompt_ids
]
# users can customize different sampling_params at different run
with self.update_sampling_params(**prompts.meta_info):
completions: List[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False
)
response_ids = [output.token_ids for completion in completions for output in completion.outputs]
response_ids = VF.pad_2d_list_to_length(
response_ids, self.pad_token_id, max_length=self.config.response_length
).to(input_ids.device)
if self.sampling_params.n > 1:
batch_size = batch_size * self.sampling_params.n
input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
if "multi_modal_data" in non_tensor_batch:
non_tensor_batch["multi_modal_data"] = _repeat_interleave(
non_tensor_batch["multi_modal_data"], self.sampling_params.n
)
sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
response_length = response_ids.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
if position_ids.dim() == 3: # qwen2vl mrope
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[..., -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_mask = VF.get_response_mask(
response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
)
attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid
batch = TensorDict(
{
"prompts": input_ids,
"responses": response_ids,
"input_ids": sequence_ids, # here input_ids become the whole sentences
"attention_mask": attention_mask,
"response_mask": response_mask,
"position_ids": position_ids,
},
batch_size=batch_size,
)
if self.rank == 0:
print("[Rollout] Finish generating sequences.")
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info=prompts.meta_info)
_process_multi_modal_data handles the multimodal (image) inputs, applying the min_pixels / max_pixels limits. When sampling_params.n > 1, the prompt-side tensors and multi_modal_data are repeated so that the prompt-related information aligns with the sampled responses.
_balance_batch
# verl/trainer/ray_trainer.py -> RayPPOTrainer._balance_batch
def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch["attention_mask"]
batch_size = attention_mask.shape[0]
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
world_size = self.actor_rollout_ref_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(
global_seqlen_lst, k_partitions=world_size, equal_size=True
)
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
global_balance_stats = log_seqlen_unbalance(
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
)
metrics.update(global_balance_stats)
Reorder the data on the single controller so that each DP rank gets a similar total number of tokens.
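A toy greedy illustration of the idea (the real get_seqlen_balanced_partitions solves a balanced partition with equal-sized groups; this simplified greedy assignment only conveys the intuition):
seqlens = [100, 900, 300, 700, 500, 500]   # valid-token count of each sample
world_size = 2
partitions = [[] for _ in range(world_size)]
totals = [0] * world_size

# assign the longest remaining sequence to the rank with the smallest running total
for idx, length in sorted(enumerate(seqlens), key=lambda x: -x[1]):
    rank = totals.index(min(totals))
    partitions[rank].append(idx)
    totals[rank] += length

print(partitions, totals)  # [[1, 5, 0], [3, 4, 2]] [1500, 1500]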
3.2.2 compute reward
# verl/trainer/ray_trainer.py ->RayPPOTrainer.fit
if "token_level_scores" not in batch.batch:
with timer("reward", timing_raw):
reward_ref = self.reward_fn.compute_reward.remote(batch)
The reward for the batch is computed with the verifiable-reward function configured earlier (reward_function in the reward config).
3.2.3 recompute actor log_probs
# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
with timer("old", timing_raw):
old_log_probs = self.actor_rollout_ref_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)
# verl/workers/fsdp_workers.py -> FSDPWorker.compute_log_probs
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_probs(self, data: DataProto):
assert self._has_actor
self._process_multi_modal_inputs(data)
data = data.to(torch.cuda.current_device())
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(
tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
)
output = self.ulysses_sharding_manager.postprocess_data(output)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.fsdp_module._handle.reshard(True)
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
output = output.to("cpu")
return output
Compute the log probability of the responses given input_ids, attention_mask, and position_ids.
verl/workers/actor/dp_actor.py -> DataParallelPPOActor.compute_log_prob: select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
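A minimal sketch, with illustrative shapes, of what computing the response log-probs amounts to (the real computation happens inside DataParallelPPOActor._forward_micro_batch):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 32000)             # (bsz, prompt_len + response_len, vocab), toy values
responses = torch.randint(0, 32000, (2, 4))    # (bsz, response_len), the generated token ids
temperature = 1.0

# logits at position t predict token t + 1, so take the slice right before each response token
response_logits = logits[:, -responses.size(1) - 1 : -1, :] / temperature
log_probs = torch.gather(
    F.log_softmax(response_logits, dim=-1), dim=-1, index=responses.unsqueeze(-1)
).squeeze(-1)                                   # (bsz, response_len)
print(log_probs.shape)                          # torch.Size([2, 4])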
3.2.4 compute ref_log_probs
The logic is identical to computing the actor log_probs, just using the reference model instead.
# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if self.use_reference_policy:
with timer("ref", timing_raw):
ref_log_probs = self.actor_rollout_ref_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)
# verl/workers/fsdp_workers.py -> FSDPWorker.compute_ref_log_probs
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_probs(self, data: DataProto):
assert self._has_ref
self._process_multi_modal_inputs(data)
data = data.to(torch.cuda.current_device())
if self._use_ref_param_offload:
load_fsdp_model(self.ref_fsdp_module)
data.meta_info["temperature"] = self.config.rollout.temperature
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_probs": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.ref_fsdp_module._handle.reshard(True)
if self._use_ref_param_offload:
offload_fsdp_model(self.ref_fsdp_module)
output = output.to("cpu")
return output
3.2.5 compute_advantage
First, compute the KL divergence:
# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# verl/trainer/ray_trainer.py -> apply_kl_penalty
def apply_kl_penalty(data: DataProto, kl_ctrl: KLController, kl_penalty="kl"):
token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0]
response_mask = data.batch["response_mask"]
# compute kl between ref_policy and current policy
kld = compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
kld = kld * response_mask # (batch_size, response_length)
data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}
# According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
return data, metrics
# verl/trainer/core_algos.py -> compute_kl
def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given log_probs and ref_log_probs.
Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args:
log_probs: torch.Tensor
ref_log_probs: torch.Tensor
kl_penalty: str
Returns:
kl_div: torch.Tensor
"""
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl":
return log_probs - ref_log_probs
if kl_penalty == "abs":
return (log_probs - ref_log_probs).abs()
if kl_penalty == "mse":
return 0.5 * (log_probs - ref_log_probs).square()
# J. Schulman. Approximating kl divergence, 2020.
# URL http://joschu.net/blog/kl-approx.html
if kl_penalty == "low_var_kl":
# For numerical stability
kl = (ref_log_probs - log_probs).clamp(-20.0, 20.0)
kld = (kl.exp() - kl - 1).contiguous()
return torch.clamp(kld, min=-10.0, max=10.0)
if kl_penalty == "full":
return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)
raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
The low_var_kl estimator is preferred for the KL-divergence constraint.
When the KL penalty is not applied to the reward: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] (the verifiable reward is used directly)
When it is applied: data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
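A tiny numeric sketch of apply_kl_penalty with the low_var_kl (k3) estimator:
import torch

old_log_probs = torch.tensor([[-1.0, -0.5, -2.0]])     # log-probs under the current (old) policy
ref_log_probs = torch.tensor([[-1.2, -0.4, -2.0]])     # log-probs under the reference policy
token_level_scores = torch.tensor([[0.0, 0.0, 1.0]])   # verifiable reward, typically on the last token
kl_coef = 1e-2

delta = (ref_log_probs - old_log_probs).clamp(-20.0, 20.0)
k3 = delta.exp() - delta - 1                            # low-variance KL estimator, always >= 0
token_level_rewards = token_level_scores - kl_coef * k3
print(token_level_rewards)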
Compute the advantage:
# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
)
# verl/trainer/ray_trainer.py -> compute_advantage
def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
token_level_rewards = data.batch["token_level_rewards"]
response_mask = data.batch["response_mask"]
index = data.non_tensor_batch["uid"]
if adv_estimator == AdvantageEstimator.GAE:
values = data.batch["values"]
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards, values, response_mask, gamma, lam
)
elif adv_estimator == AdvantageEstimator.GRPO:
advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards, response_mask, gamma
)
elif adv_estimator == AdvantageEstimator.REMAX:
reward_baselines = data.batch["reward_baselines"]
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards, reward_baselines, response_mask
)
elif adv_estimator == AdvantageEstimator.RLOO:
advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
else:
raise NotImplementedError
data.batch["advantages"] = advantages
data.batch["returns"] = returns
return data
# verl/trainer/core_algos.py -> compute_grpo_outcome_advantage
@torch.no_grad()
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean, id2std = {}, {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor(id2score[idx]))
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
returns = scores.unsqueeze(-1) * response_mask
return returns, returns
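A quick numeric sketch of the group-relative advantage: with rollout.n = 4 responses to the same prompt and sequence-level rewards [1, 0, 0, 1], each response gets advantage (r_i - mean) / (std + eps), which is then broadcast over its response tokens via response_mask:
import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])   # one scalar reward per sampled response in the group
eps = 1e-6
advantages = (rewards - rewards.mean()) / (rewards.std() + eps)
print(advantages)  # tensor([ 0.8660, -0.8660, -0.8660,  0.8660]); torch.std is unbiased by default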
3.2.6 update_actor
Update the policy model:
# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if self.config.trainer.critic_warmup <= self.global_step:
with timer("update_actor", timing_raw):
actor_output = self.actor_rollout_ref_wg.update_actor(batch)
actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
metrics.update(actor_metrics)
# verl/workers/fsdp_workers.py -> FSDPWorker.update_actor
metrics = self.actor.update_policy(data=data)
# verl/workers/actor/dp_actor.py -> DataParallelPPOActor.update_policy
def update_policy(self, data: DataProto) -> Dict[str, Any]:
self.actor_module.train()
    temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
select_keys.extend(["old_log_probs", "ref_log_probs", "advantages"])
non_tensor_select_keys = ["multi_modal_inputs"]
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
for _ in range(self.config.ppo_epochs):
if self.rank == 0:
mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)
for mini_batch in mini_batches:
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Update policy", position=2)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
response_length = responses.size(1)
attention_mask = model_inputs["attention_mask"]
response_mask = attention_mask[:, -response_length:]
old_log_probs = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
# all return: (bsz, response_length)
log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
pg_loss, pg_metrics = core_algos.compute_policy_loss(
old_log_probs=old_log_probs,
log_probs=log_probs,
advantages=advantages,
response_mask=response_mask,
clip_ratio_low=self.config.clip_ratio_low,
clip_ratio_high=self.config.clip_ratio_high,
clip_ratio_dual=self.config.clip_ratio_dual,
loss_avg_mode=self.config.loss_avg_mode,
)
if self.config.use_kl_loss and "ref_log_probs" in model_inputs:
ref_log_probs = model_inputs["ref_log_probs"]
# compute kl loss
kld = core_algos.compute_kl(
log_probs=log_probs,
ref_log_probs=ref_log_probs,
kl_penalty=self.config.kl_penalty,
)
kl_loss = VF.masked_mean(kld, response_mask)
pg_loss = pg_loss + kl_loss * self.config.kl_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_coef
loss = pg_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac_higher": pg_metrics["pg_clipfrac_higher"],
"actor/pg_clipfrac_lower": pg_metrics["pg_clipfrac_lower"],
"actor/entropy_loss": pg_metrics["entropy_loss"],
"actor/ppo_kl": pg_metrics["ppo_kl"],
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
return metrics
Outer for loop: for _ in range(self.config.ppo_epochs) is the PPO/GRPO epoch loop. After one round of sampling, the same batch of data can be trained on for K epochs (typically 1-8) to squeeze more signal out of it; this is the "multiple epochs" from the PPO paper.
Middle for loop: for mini_batch in mini_batches iterates over the (reordered) mini-batches, corresponding to shuffling the large batch and doing the backward passes split by split in the PPO formulation.
Inner for loop: for micro_batch in micro_batches splits each mini-batch further. loss.backward() is called for every micro-batch, but optimizer.step() is not called immediately; gradients accumulate until the whole mini_batch has been processed, and only then does the optimizer step. This preserves a large effective batch size even if GPU memory can only hold, say, a quarter of a mini-batch at a time. A toy version of this pattern is sketched below.
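The micro-batch / gradient-accumulation pattern boils down to the following toy sketch (toy model and data, not the actual actor):
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

mini_batch = torch.randn(8, 4)
micro_batch_size = 2
gradient_accumulation = mini_batch.size(0) // micro_batch_size  # 4 micro-batches per mini-batch

optimizer.zero_grad()
for micro_batch_x in mini_batch.split(micro_batch_size):
    loss = model(micro_batch_x).pow(2).mean() / gradient_accumulation  # scale so the gradient averages correctly
    loss.backward()                                                    # gradients accumulate across micro-batches
optimizer.step()                                                       # one optimizer step per mini-batch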
Compute the policy loss:
# verl/trainer/core_algos.py -> compute_policy_loss
def compute_policy_loss(
old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
clip_ratio_low: float,
clip_ratio_high: float,
clip_ratio_dual: float,
loss_avg_mode: Literal["token", "seq"],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Compute the clipped policy objective and related metrics for PPO.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
Args:
old_log_prob: `(torch.Tensor)`
shape: (bs, response_length)
log_prob: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
clip_ratio_low: (float)
The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
clip_ratio_high: (float)
The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
clip_ratio_dual: (float)
The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
loss_avg_mode: (Literal["token", "seq"])
"token": average the loss in the whole batch
"seq": average the loss in each sequence then average the mean of the means
Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac_higher: (float)
a float number indicating the fraction of policy gradient loss being clipped to a higher value
pg_clipfrac_lower: (float)
a float number indicating the fraction of policy gradient loss being clipped to a lower value
ppo_kl: (float)
a float number indicating the mean KL divergence between the old policy and the new policy
entropy_loss: (float)
a float number indicating the mean entropy loss
"""
negative_approx_kl = log_probs - old_log_probs
# clamp negative_approx_kl to avoid nan kld
negative_approx_kl = torch.clamp(negative_approx_kl, -20.0, 20.0)
ratio = torch.exp(negative_approx_kl)
# clamp the ratio before exp to avoid nan grad
# see: https://github.com/pytorch/pytorch/issues/10729
clipped_ratio = torch.exp(
torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
)
# pg metrics
metrics = {"ppo_kl": -negative_approx_kl}
# use negative log probs as an estimator of entropy loss
metrics["entropy_loss"] = average_loss(-log_probs, response_mask, mode=loss_avg_mode)
pg_loss = -advantages * ratio # -ratio * A
pg_loss2 = -advantages * clipped_ratio # -clip(ratio, 1-clip_low, 1+clip_high) * A
pg_loss3 = -advantages * clip_ratio_dual # -clip_dual * A
clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2) # clip if pg_loss < pg_loss2
metrics["pg_clipfrac_higher"] = (pg_loss < pg_loss2).float()
clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3) # clip if pg_loss > pg_loss3 and adv < 0
    # choose the lower or upper bound depending on the sign of the advantage
final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
metrics["pg_clipfrac_lower"] = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()
final_pg_loss = average_loss(final_pg_loss, response_mask, mode=loss_avg_mode)
metrics = {k: VF.masked_mean(v, response_mask).detach().item() for k, v in metrics.items()}
return final_pg_loss, metrics
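To make the clipping behavior concrete, here is a tiny numeric sketch mirroring the logic above, using the default clip ratios (clip_ratio_low=0.2, clip_ratio_high=0.3, clip_ratio_dual=3.0); the real code clips the log-ratio and exponentiates for numerical stability, which is equivalent here:
import torch

ratio = torch.tensor([0.5, 1.5, 0.5, 5.0])         # pi / pi_old per token
advantages = torch.tensor([1.0, 1.0, -1.0, -1.0])
clip_low, clip_high, clip_dual = 0.2, 0.3, 3.0

pg_loss = -advantages * ratio
pg_loss2 = -advantages * ratio.clamp(1.0 - clip_low, 1.0 + clip_high)
clipped_higher = torch.max(pg_loss, pg_loss2)       # standard PPO clipping
clipped_lower = torch.min(clipped_higher, -advantages * clip_dual)
final = torch.where(advantages < 0, clipped_lower, clipped_higher)  # dual clip only for negative advantages
print(final)  # tensor([-0.5000, -1.3000,  0.8000,  3.0000])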