Author | 404dreamer    Editor | 大模型之心Tech

Original post: https://zhuanlan.zhihu.com/p/1921972845669492319


Preface

I have recently been preparing some pilot experiments on GRPO for multimodal large models. The cold-start stage can be handled by fine-tuning Qwen2.5-VL directly with Llama-factory, so I started looking for an open-source reinforcement learning training framework for multimodal large models.

I could not find any existing walkthrough of the code of EasyR1, a multimodal reinforcement learning training framework, so I read through it starting from the launch entry point.

Since the EasyR1 code is still being updated frequently, my best estimate is that this article corresponds to the code version from around June 10:

https://github.com/hiyouga/EasyR1/tree/bae1fa073555d7b65a9d6003ecf70d5bcb4feab5

EasyR1 repository:

https://github.com/hiyouga/EasyR1

EasyR1 is adapted from verl, a text-only reinforcement learning training framework. verl:

https://github.com/volcengine/verl

I hope this (not especially detailed) code walkthrough is helpful to researchers. If there are any mistakes in understanding or wording, please let me know.

1 Launch script and configuration files

1.1 Launch script

The launch scripts live in the examples directory; we use examples/qwen2_5_7b_math_grpo.sh as an example:

 #!/bin/bash  
 set -x  
 export PYTHONUNBUFFERED=1  
 MODEL_PATH=Qwen/Qwen2.5-7B-Instruct  # replace it with your local file path  
 python3 -m verl.trainer.main \  
     config=examples/config.yaml \  
     worker.actor.model.model_path=${MODEL_PATH}
  • The training job is launched via verl/trainer/main.py.
  • The config file path must be specified. Other parameters in the config file, such as worker.actor.model.model_path, can be overridden by passing them as command-line arguments in the launch script (see the sketch below).
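To make the override mechanism concrete, here is a minimal sketch, assuming an OmegaConf-style merge (this is not EasyR1's actual entry code, and the local model path is purely hypothetical): values that come later in the merge win, so the dot-list CLI argument overrides the YAML default.

from omegaconf import OmegaConf

yaml_config = OmegaConf.load("examples/config.yaml")
cli_overrides = OmegaConf.from_dotlist(["worker.actor.model.model_path=/path/to/local/Qwen2.5-VL-7B-Instruct"])

# values appearing later in the merge take precedence over the YAML defaults
config = OmegaConf.merge(yaml_config, cli_overrides)
print(config.worker.actor.model.model_path)  # /path/to/local/Qwen2.5-VL-7B-Instruct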

1.2 Configuration file

The configuration file is examples/config.yaml. Its settings fall into four groups: data, algorithm, worker, and trainer. The corresponding Python dataclasses are defined in verl/trainer/config.py.

Since EasyR1 is built on top of verl, verl's config documentation is also a useful reference:

https://verl.readthedocs.io/en/latest/examples/config.html#ppo-trainer-yaml-for-rl-fsdp-backend

1.2.1 data configuration

1.2.1.1 Default values

verl/trainer/config.py -> DataConfig

 @dataclass
 class DataConfig:  
     train_files: str = ""  
     val_files: str = ""  
     prompt_key: str = "prompt"  
     answer_key: str = "answer"  
     image_key: str = "images"  
     image_dir: Optional[str] = None  
     max_prompt_length: int = 512  
     max_response_length: int = 512  
     rollout_batch_size: int = 512  
     mini_rollout_batch_size: Optional[int] = None  
     val_batch_size: int = -1  
     format_prompt: Optional[str] = None  
     override_chat_template: Optional[str] = None  
     shuffle: bool = True  
     seed: int = 1  
     min_pixels: Optional[int] = 262144  
     max_pixels: Optional[int] = 4194304  
     filter_overlong_prompts: bool = True

1.2.1.2 Config file

examples/config.yaml

 data:  
   train_files: hiyouga/math12k@train  
   val_files: hiyouga/math12k@test  
   prompt_key: problem  
   answer_key: answer  
   image_key: images  
   image_dir: null  
   max_prompt_length: 2048  
   max_response_length: 2048  
   rollout_batch_size: 512  # equivalent to verl's data.train_batch_size  
   mini_rollout_batch_size: null  # equivalent to verl's data.gen_batch_size  
   val_batch_size: 1024  
   format_prompt: ./examples/format_prompt/math_format.jinja  
   override_chat_template: null  
   shuffle: true  
   seed: 1  
   min_pixels: 262144  
   max_pixels: 4194304  
   filter_overlong_prompts: true

1.2.1.3 Explanation of the data parameters


1.2.2 algorithm configuration

1.2.2.1 Default values

verl/trainer/config.py -> AlgorithmConfig

 @dataclass
 class AlgorithmConfig:  
     gamma: float = 1.0  
     """discount factor for ppo gae advantage estimator"""  
     lam: float = 1.0  
     """lambda value for ppo gae advantage estimator"""  
     adv_estimator: str = "grpo"  
     """advantage estimator, support `gae`, `grpo`, `reinforce_plus_plus`, `remax`, `rloo`"""    
     disable_kl: bool = False  
     """disable reference model"""  
     use_kl_loss: bool = False  
     """use kl loss instead of kl in reward"""  
     kl_penalty: str = "kl"  
     """kl penalty type, support `kl`, `abs`, `mse`, `low_var_kl`, `full`"""    
     kl_coef: float = 1e-3  
     """kl coefficient"""  
     kl_type: str = "fixed"  
     """kl controller type, support `fixed`, `adaptive`"""    
     kl_horizon: float = 10000.0  
     """kl horizon for adaptive kl controller"""  
     kl_target: float = 0.1  
     """target kl for adaptive kl controller"""

1.2.2.2 Config file

 algorithm:  
   adv_estimator: grpo  
   disable_kl: false  
   use_kl_loss: true  
   kl_penalty: low_var_kl  
   kl_coef: 1.0e-2
  • Many multimodal RL works do not use a KL-divergence penalty to constrain policy updates. Consider setting disable_kl to true for your actual training runs.

1.2.3 worker configuration

Defined in verl/workers/config.py. The worker configuration is further split into five sub-configurations: actor, critic, ref, reward, and rollout.

1.2.3.1 actor configuration

Location: verl/workers/actor/config.py

 @dataclass  
 class ActorConfig:  
     strategy: str = "fsdp"  
     global_batch_size: int = 256  
     micro_batch_size_per_device_for_update: int = 4  
     micro_batch_size_per_device_for_experience: int = 16  
     max_grad_norm: float = 1.0  
     clip_ratio_low: float = 0.2  
     """clip ratio in PPO & DAPO"""  
     clip_ratio_high: float = 0.3  
     """clip ratio in PPO & DAPO"""  
     clip_ratio_dual: float = 3.0  
     """constant C in dual-clip PPO, clips when advantage < -C"""  
     loss_avg_mode: str = "token"  
     """loss average mode: `token`, `seq`"""    
     ppo_epochs: int = 1  
     """number of ppo epochs for each rollout batch"""  
     padding_free: bool = True  
     """use padding-free training"""  
     ulysses_size: int = 1  
     """ulysses sequence parallel size"""  
     use_torch_compile: bool = True  
     model: ModelConfig = field(default_factory=ModelConfig)  
     optim: OptimConfig = field(default_factory=OptimConfig)    
     fsdp: FSDPConfig = field(default_factory=FSDPConfig)    
     offload: OffloadConfig = field(default_factory=OffloadConfig)    
     """auto keys"""  
     global_batch_size_per_device: int = field(default=-1, init=False)  
     disable_kl: bool = field(default=False, init=False)  
     use_kl_loss: bool = field(default=False, init=False)  
     kl_penalty: str = field(default="kl", init=False)  
     kl_coef: float = field(default=0.0, init=False)
  • The same directory also defines the actor-model and optimizer related parameters (ModelConfig, OptimConfig, FSDPConfig, OffloadConfig), which are worth consulting.

Config example:

 actor:  
   global_batch_size: 128  # equivalent to verl's actor.ppo_mini_batch_size  
   micro_batch_size_per_device_for_update: 4  # equivalent to verl's actor.ppo_micro_batch_size_per_gpu  
   micro_batch_size_per_device_for_experience: 16  # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu  
   max_grad_norm: 1.0  
   padding_free: true  
   ulysses_size: 1  
   model:  
     model_path: Qwen/Qwen2.5-7B-Instruct  
     enable_gradient_checkpointing: true  
     trust_remote_code: false  
     freeze_vision_tower: false  
   optim:  
     lr: 1.0e-6  
     weight_decay: 1.0e-2  
     strategy: adamw  # {adamw, adamw_bf16}  
     lr_warmup_ratio: 0.0  
   fsdp:  
     enable_full_shard: true  
     enable_cpu_offload: false  
     enable_rank0_init: true  
   offload:  
     offload_params: true  # true: more CPU memory; false: more GPU memory  
     offload_optimizer: true  # true: more CPU memory; false: more GPU memory

1.2.3.2 critic configuration

Since the GRPO algorithm does not involve a value model, the critic parameters do not need to be configured. Location: verl/workers/critic/config.py

 @dataclass  
 class CriticConfig:  
     strategy: str = "fsdp"  
     global_batch_size: int = 256  
     micro_batch_size_per_device_for_update: int = 4  
     micro_batch_size_per_device_for_experience: int = 16  
     max_grad_norm: float = 1.0  
     cliprange_value: float = 0.5  
     """clip range for value loss"""  
     loss_avg_mode: str = "token"  
     """loss average mode: `token`, `seq`"""    
     ppo_epochs: int = 1  
     """number of ppo epochs for each rollout batch"""  
     padding_free: bool = False  
     """use padding-free training"""  
     ulysses_size: int = 1  
     """ulysses sequence parallel size"""  
     model: ModelConfig = field(default_factory=ModelConfig)  
     optim: OptimConfig = field(default_factory=OptimConfig)    
     fsdp: FSDPConfig = field(default_factory=FSDPConfig)    
     offload: OffloadConfig = field(default_factory=OffloadConfig)
     """auto keys"""
     global_batch_size_per_device: int = field(default=-1, init=False)
 

1.2.3.3 ref configuration

Location: verl/workers/actor/config.py

@dataclass  
class RefConfig:  
    strategy: str = "fsdp"  
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)  
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    """auto keys"""
    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)  
    padding_free: bool = field(default=False, init=False)  
    ulysses_size: int = field(default=1, init=False)  
    use_torch_compile: bool = field(default=True, init=False)

1.2.3.4 reward configuration

Config settings:

reward:  
  reward_type: batch  
  reward_function: ./examples/reward_function/math.py:compute_score

Location: verl/workers/reward/config.py

@dataclass  
class RewardConfig:  
    reward_type: str = "batch"  
    reward_function: Optional[str] = None  
    reward_function_kwargs: dict = field(default_factory=dict)  
    skip_special_tokens: bool = True  
    num_cpus: int = 1  
    """auto keys"""  
    reward_function_name: Optional[str] = field(default=None, init=False)  
  
    def post_init(self):  
        if self.reward_function is not None:  # support custom reward function, e.g., ./math.py:main  
            if ":" not in self.reward_function:  
                self.reward_function_name = "main"  
            else:  
                self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)  
  
            if os.path.exists(self.reward_function):  # ray job uses absolute path  
                self.reward_function = os.path.abspath(self.reward_function)  
            else:  
                print(f"Reward function {self.reward_function} not found.")  
                self.reward_function = None
  • Based on the configured reward_function: ./examples/reward_function/math.py:compute_score, post_init sets self.reward_function to the path of the reward file and self.reward_function_name to the name of the corresponding reward function (a loading sketch follows below).
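For intuition, here is a hedged sketch of how such a "path.py:function" spec can later be turned into a callable with importlib. This is an assumption about how a loader could work, not the exact EasyR1 code.

import importlib.util

def load_reward_function(file_path: str, function_name: str):
    # import the reward file as a throwaway module and fetch the named function from it
    spec = importlib.util.spec_from_file_location("custom_reward", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, function_name)

# compute_score = load_reward_function("./examples/reward_function/math.py", "compute_score")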

1.2.3.5 rollout configuration

Config settings:

rollout:  
  n: 5  # sample n sequences in parallel for the same prompt
  temperature: 1.0  
  top_p: 0.99  
  gpu_memory_utilization: 0.6  
  enforce_eager: false  
  enable_chunked_prefill: false  
  tensor_parallel_size: 2  
  limit_images: 0  
  val_override_config:  
    temperature: 0.5  
    n: 1

Location: verl/workers/rollout/config.py

@dataclass  
class RolloutConfig:  
    name: str = "vllm"  
    n: int = 1  
    temperature: float = 1.0  
    top_p: float = 1.0  
    top_k: int = -1  
    seed: int = 1  
    limit_images: int = 0  
    dtype: str = "bf16"  
    gpu_memory_utilization: float = 0.6  
    ignore_eos: bool = False  
    enforce_eager: bool = False  
    enable_chunked_prefill: bool = False  # only for v0 engine  
    tensor_parallel_size: int = 2  
    max_model_len: Optional[int] = None  
    max_num_batched_tokens: int = 8192  
    disable_log_stats: bool = True  
    val_override_config: Dict[str, Any] = field(default_factory=dict)  
    """auto keys"""  
    prompt_length: int = field(default=-1, init=False)  
    response_length: int = field(default=-1, init=False)  
    trust_remote_code: bool = field(default=False, init=False)  
  
    def to_dict(self):  
        return asdict(self)

2 Dataset

2.1 Dataset flow

The main program creates the dataloaders:

# verl/trainer/main.py
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)

Dataset: verl/utils/dataset.py -> _build_messages, __getitem__

# dataset.py ->RLHFDataset._build_messages
prompt_str: str = example[self.prompt_key]  
if self.format_prompt:  
    format_prompt = Template(self.format_prompt.strip())    
    prompt_str = format_prompt.render(content=prompt_str)

Template (examples/format_prompt/math_format.jinja):

{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
  • This is a jinja template.

  • content: the variable rendered into the template.

  • |: pipes the preceding value into the following filter.

  • trim: equivalent to Python's str.strip(): removes leading and trailing whitespace such as spaces, tabs, and newlines.

  • It appears that only single-turn Q&A is supported: the single-turn prompt is rendered into the Template (see the sketch below).
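For a concrete picture of the rendering step, the following self-contained snippet reproduces the template above with jinja2 (the example question is made up):

from jinja2 import Template

format_prompt = Template(
    "{{ content | trim }} You FIRST think about the reasoning process as an internal "
    "monologue and then provide the final answer. The reasoning process MUST BE enclosed "
    "within <think> </think> tags. The final answer MUST BE put in \\boxed{}."
)
print(format_prompt.render(content="  What is 1 + 1?  "))
# -> "What is 1 + 1? You FIRST think about the reasoning process ..."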

# dataset.py ->RLHFDataset._build_messages
if self.image_key in example:  
    # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text  
    content_list = []  
    for i, content in enumerate(prompt_str.split("<image>")):  
        if i != 0:  
            content_list.append({"type""image"})  
  
        if content:  
            content_list.append({"type""text""text": content})
  • <image>为图像占位符

2.2 Dataset format

if os.path.isdir(data_path):
    # when we use dataset builder, we should always refer to the train split
    file_type = os.path.splitext(os.listdir(data_path)[0])[-1][1:].replace("jsonl", "json")
    self.dataset = load_dataset(file_type, data_dir=data_path, split=data_split)
elif os.path.isfile(data_path):
    file_type = os.path.splitext(data_path)[-1][1:].replace("jsonl", "json")
    self.dataset = load_dataset(file_type, data_files=data_path, split=data_split)
else:
    # load remote dataset from huggingface hub
    self.dataset = load_dataset(data_path, split=data_split)
  • The dataset is loaded with the datasets.load_dataset function: https://huggingface.co/docs/datasets/loading

  • Formats such as csv, json, and parquet are supported.

Each sample must contain three keys: problem, answer, and images. Using JSON as an example:

[
  {
    "problem": "Please describe the content of this image: <image>",
    "answer": "The image shows a golden retriever wearing sunglasses.",
    "images": ["dog.jpg"]
  },
  {
    "problem": "<image> <image>\nPlease compare the differences between these two photos.",
    "answer": "The left image is a snowy mountain, the right one is a desert.",
    "images": [
      "images/mountain.png",
      "images/desert.png"
    ],
    "sample_id": 42              // extra fields are also kept
  }
]

3 Training workflow

Training entry point: verl/trainer/main.py

        trainer = RayPPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()
  • init_workers() is called first, then fit().

Trainer: verl/trainer/ray_trainer.py

3.1 trainer.init

if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
    self.use_critic = True
else:
    self.use_critic = False
  • With GRPO there is no value model, so self.use_critic = False.
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
    raise ValueError("Rollout batch size must be divisible by actor global batch size.")
  • Ensures that one rollout (sampling) batch splits evenly across the distributed actor processes.
if (
    config.data.rollout_batch_size * config.worker.rollout.n
) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
    raise ValueError(
        "Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
    )
  • Parameters:

  • rollout_batch_size: how many prompts are drawn from the dataset for rollout at each training step (e.g. 128). It determines how much experience is collected in this step.

  • rollout.n: how many responses are sampled for each prompt.

  • n = 1: standard PPO / GAE

  • n >= 2: GRPO

  • micro_batch_size_per_device_for_experience: how many experience samples each GPU on the actor side processes in one forward pass (e.g. 32). Because GPU memory is limited, a large batch is usually split into several micro-batches.

  • Purpose: the collected experience entries must also divide evenly when split into gradient-accumulation micro-batches, so that the last micro-batch does not end up with a mismatched size and trigger NCCL / mixed-precision errors. A worked check with the example config values is shown below.
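A quick worked check with the values from examples/config.yaml and the actor config example above (rollout_batch_size=512, rollout.n=5, actor.global_batch_size=128, micro_batch_size_per_device_for_experience=16) shows that both constraints hold:

rollout_batch_size = 512
rollout_n = 5
actor_global_batch_size = 128
micro_batch_size_for_experience = 16

# 512 / 128 = 4 mini-batches per rollout step
assert rollout_batch_size % actor_global_batch_size == 0
# 512 * 5 = 2560 experience samples, 2560 / 16 = 160 micro-batches
assert (rollout_batch_size * rollout_n) % micro_batch_size_for_experience == 0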

3.2 trainer.fit

Although GRPO does not involve a value model, the training framework also supports other RL algorithms such as PPO, so some steps do touch the value model. We skip those value-model-related steps here.

3.2.1 batch data

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
with timer("gen", timing_raw):
    self.actor_rollout_ref_wg.prepare_rollout_engine()
    batch = self._make_batch_data(metrics=metrics)
    self.actor_rollout_ref_wg.release_rollout_engine()

# balance the number of valid tokens on each dp rank.
# NOTE: this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
self._balance_batch(batch, metrics=metrics)

A batch is obtained from the dataloader; this flow consists of two steps: _make_batch_data and _balance_batch.

3.2.1.1 _make_batch_data

The _make_batch_data main function:

# verl/trainer/ray_trainer.py -> RayPPOTrainer._make_batch_data
def _make_batch_data(self, metrics: Dict[str, Any]) -> DataProto:
    batch = None
    while True:
        try:
            batch_dict = next(self.data_iterator)
        except StopIteration:
            self.data_iterator = iter(self.train_dataloader)
            batch_dict = next(self.data_iterator)

        meta_info = {"min_pixels": self.config.data.min_pixels, "max_pixels": self.config.data.max_pixels}
        new_batch: DataProto = DataProto.from_single_dict(batch_dict, meta_info=meta_info)

        # pop those keys for generation
        gen_batch = new_batch.pop(
            batch_keys=["input_ids", "attention_mask", "position_ids"],
            non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
            meta_info_keys=["min_pixels", "max_pixels"],
        )

        # generate a batch
        gen_batch_output = self.actor_rollout_ref_wg.generate_sequences(gen_batch)

        new_batch.non_tensor_batch["uid"] = np.array(
            [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
        )
        # repeat to align with repeated responses in rollout
        new_batch = new_batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
        new_batch = new_batch.union(gen_batch_output)

        batch = DataProto.concat([batch, new_batch]) if batch is not None else new_batch
        if len(batch) < self.config.data.rollout_batch_size * self.config.worker.rollout.n:
            continue
        else:
            return batch[: self.config.data.rollout_batch_size * self.config.worker.rollout.n]
  • Line 21: gen_batch_output = self.actor_rollout_ref_wg.generate_sequences(gen_batch)

  • Samples rollout.n responses for each prompt in one go.

  • Lines 23 - 25:

  • Assign a unique uid to each sample, which is handy for later logging or replay.

  • Line 27: new_batch = new_batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)

  • At this point new_batch still holds only one row of metadata per prompt; calling union() directly would mismatch the row counts and DataProto would raise a shape error.

  • The batch therefore has to be repeated rollout.n times so that the number of prompt rows matches the number of generated responses (see the toy illustration below).
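A toy illustration of the interleaved repetition (using numpy rather than DataProto itself):

import numpy as np

# with rollout.n = 3, each prompt-level row is duplicated in place so that it lines up
# with the 3 responses generated for that prompt
uids = np.array(["prompt_a", "prompt_b"], dtype=object)
print(np.repeat(uids, 3))
# ['prompt_a' 'prompt_a' 'prompt_a' 'prompt_b' 'prompt_b' 'prompt_b']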

self.actor_rollout_ref_wg.generate_sequences: response sampling

# verl/workers/fsdp_workers.py -> FSDPWorker.generate_sequences
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
    assert self._has_rollout

    meta_info = {
        "eos_token_id": self.generation_config.eos_token_id
        if self.generation_config is not None
        else self.tokenizer.eos_token_id,
        "pad_token_id": self.generation_config.pad_token_id
        if self.generation_config is not None
        else self.tokenizer.pad_token_id,
    }
    prompts.meta_info.update(meta_info)

    prompts = self.rollout_sharding_manager.preprocess_data(prompts)
    output = self.rollout.generate_sequences(prompts=prompts)
    output = self.rollout_sharding_manager.postprocess_data(output)

    output = output.to("cpu")
    return output
  • Response sampling consists of three steps: (1) preprocess the prompts; (2) sample responses from the prompts; (3) postprocess the responses.

  • The pre- and post-processing are simple tensor-parallelism-related operations and are not covered in detail here.

Sampling responses from the prompts:

# verl/workers/rollout/vllm_rollout_spmd.py -> vLLMRollout.generate_sequences
@torch.no_grad()
def generate_sequences(self, prompts: DataProto) -> DataProto:
    if self.rank == 0:
        print("[Rollout] Start generating sequences.")

    # left-padded attention_mask
    input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
    attention_mask: torch.Tensor = prompts.batch["attention_mask"]
    position_ids: torch.Tensor = prompts.batch["position_ids"]
    eos_token_id: int = prompts.meta_info["eos_token_id"]
    batch_size = input_ids.size(0)

    non_tensor_batch = prompts.non_tensor_batch
    batch_raw_prompt_ids = non_tensor_batch.pop("raw_prompt_ids")
    batch_multi_modal_data = non_tensor_batch.get("multi_modal_data")  # do not pop it
    if batch_size != len(batch_raw_prompt_ids):
        raise RuntimeError("vllm sharding manager is not work properly.")

    if batch_multi_modal_data is not None:
        min_pixels, max_pixels = prompts.meta_info["min_pixels"], prompts.meta_info["max_pixels"]
        vllm_inputs = []
        for raw_prompt_ids, multi_modal_data in zip(batch_raw_prompt_ids, batch_multi_modal_data):
            vllm_inputs.append(
                {
                    "prompt_token_ids": list(raw_prompt_ids),
                    "multi_modal_data": _process_multi_modal_data(multi_modal_data, min_pixels, max_pixels),
                }
            )
    else:
        vllm_inputs = [
            {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in batch_raw_prompt_ids
        ]

    # users can customize different sampling_params at different run
    with self.update_sampling_params(**prompts.meta_info):
        completions: List[RequestOutput] = self.inference_engine.generate(
            prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False
        )
        response_ids = [output.token_ids for completion in completions for output in completion.outputs]
        response_ids = VF.pad_2d_list_to_length(
            response_ids, self.pad_token_id, max_length=self.config.response_length
        ).to(input_ids.device)

        if self.sampling_params.n > 1:
            batch_size = batch_size * self.sampling_params.n
            input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
            attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
            position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
            if "multi_modal_data" in non_tensor_batch:
                non_tensor_batch["multi_modal_data"] = _repeat_interleave(
                    non_tensor_batch["multi_modal_data"], self.sampling_params.n
                )

    sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
    response_length = response_ids.size(1)
    delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
    delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
    if position_ids.dim() == 3:  # qwen2vl mrope
        delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

    # prompt: left pad + response: right pad
    # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
    # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
    response_position_ids = position_ids[..., -1:] + delta_position_id
    position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
    response_mask = VF.get_response_mask(
        response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
    )
    attention_mask = torch.cat((attention_mask, response_mask), dim=-1)

    # all the tp ranks should contain the same data here. data in all ranks are valid
    batch = TensorDict(
        {
            "prompts": input_ids,
            "responses": response_ids,
            "input_ids": sequence_ids,  # here input_ids become the whole sentences
            "attention_mask": attention_mask,
            "response_mask": response_mask,
            "position_ids": position_ids,
        },
        batch_size=batch_size,
    )
    if self.rank == 0:
        print("[Rollout] Finish generating sequences.")

    return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info=prompts.meta_info)
  • Lines 20 - 29: _process_multi_modal_data handles the multimodal (image) data.
  • Lines 46 - 53: the prompt-side tensors need to be repeated so that they align with the sampled responses. The position-id bookkeeping in the comment above is illustrated with a small example after this list.
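The attention_mask / position_ids comment above can be made concrete with a tiny torch example (values chosen arbitrarily, response_length = 4): because the prompt is left-padded, adding an arange to the last prompt position id yields monotonically increasing response position ids.

import torch

position_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 2, 3]])   # left-padded prompt, 4 pad tokens
response_length = 4
delta = torch.arange(1, response_length + 1).view(1, -1)  # [[1, 2, 3, 4]]
response_position_ids = position_ids[..., -1:] + delta    # [[4, 5, 6, 7]]
print(torch.cat([position_ids, response_position_ids], dim=-1))
# tensor([[0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]])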

_balance_batch

# verl/trainer/ray_trainer.py -> RayPPOTrainer._balance_batch
def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
    """Reorder the data on single controller such that each dp rank gets similar total tokens"""
    attention_mask = batch.batch["attention_mask"]
    batch_size = attention_mask.shape[0]
    global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
    world_size = self.actor_rollout_ref_wg.world_size
    global_partition_lst = get_seqlen_balanced_partitions(
        global_seqlen_lst, k_partitions=world_size, equal_size=True
    )
    # reorder based on index. The data will be automatically equally partitioned by dispatch function
    global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
    batch.reorder(global_idx)
    global_balance_stats = log_seqlen_unbalance(
        seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
    )
    metrics.update(global_balance_stats)
  • Reorder the data on single controller such that each dp rank gets similar total tokens
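The idea behind the reordering can be sketched with a simplified greedy partitioner. This is NOT verl's actual get_seqlen_balanced_partitions (which additionally enforces equal partition sizes), just an illustration of the goal: give each of the k dp ranks roughly the same number of valid tokens.

def greedy_balanced_partitions(seqlens, k):
    """Assign each sample (longest first) to the currently lightest partition."""
    partitions = [[] for _ in range(k)]
    totals = [0] * k
    for idx in sorted(range(len(seqlens)), key=lambda i: -seqlens[i]):
        target = totals.index(min(totals))
        partitions[target].append(idx)
        totals[target] += seqlens[idx]
    return partitions

print(greedy_balanced_partitions([100, 90, 50, 40, 30, 10], k=2))
# [[0, 3, 4], [1, 2, 5]] -> token totals 170 vs. 150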

3.2.2 compute reward

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if "token_level_scores" not in batch.batch:
    with timer("reward", timing_raw):
        reward_ref = self.reward_fn.compute_reward.remote(batch)
  • The reward for the batch is computed with the pre-configured verifiable-reward function (a hypothetical example is sketched below).
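For intuition, here is a hypothetical batch-level verifiable-reward function in the spirit of examples/reward_function/math.py:compute_score; the real function's signature and return format may differ. It checks whether the content of the last \boxed{} in each response matches the ground-truth answer.

import re
from typing import Dict, List

def compute_score(responses: List[str], answers: List[str]) -> List[Dict[str, float]]:
    """Score 1.0 if the last \\boxed{...} in the response equals the ground-truth answer."""
    scores = []
    for response, answer in zip(responses, answers):
        matches = re.findall(r"\\boxed\{([^}]*)\}", response)
        predicted = matches[-1].strip() if matches else ""
        scores.append({"overall": float(predicted == answer.strip())})
    return scores

# compute_score(["<think>...</think> The answer is \\boxed{42}."], ["42"]) -> [{"overall": 1.0}]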

3.2.3 recompute actor log_probs

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
with timer("old", timing_raw):
    old_log_probs = self.actor_rollout_ref_wg.compute_log_probs(batch)
    batch = batch.union(old_log_probs)

# verl/workers/fsdp_workers.py -> FSDPWorker.compute_log_probs
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_probs(self, data: DataProto):
    assert self._has_actor

    self._process_multi_modal_inputs(data)
    data = data.to(torch.cuda.current_device())

    if self._use_param_offload:
        load_fsdp_model(self.fsdp_module)

    # we should always recompute old_log_probs when it is HybridEngine
    data.meta_info["temperature"] = self.config.rollout.temperature
    # perform recompute log_prob
    with self.ulysses_sharding_manager:
        data = self.ulysses_sharding_manager.preprocess_data(data)
        output = self.actor.compute_log_prob(data=data)
        output = DataProto.from_dict(
            tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
        )
        output = self.ulysses_sharding_manager.postprocess_data(output)

    # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    # unshard the root FSDP module
    if self.world_size > 1:
        self.fsdp_module._handle.reshard(True)

    if self._use_param_offload:
        offload_fsdp_model(self.fsdp_module)

    output = output.to("cpu")
    return output
  • Computes the log probability of the responses given input_ids, attention_mask, and position_ids.

  • verl/workers/actor/dp_actor.py -> DataParallelPPOActor.compute_log_prob:

  • select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
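The core of this recomputation can be illustrated generically (this is not the exact dp_actor implementation): run the model over the full prompt+response sequence, take the log-softmax over the logits at the response positions, and gather the log-prob of each token that was actually sampled.

import torch
import torch.nn.functional as F

def gather_response_log_probs(logits: torch.Tensor, responses: torch.Tensor) -> torch.Tensor:
    # logits: (bs, response_length, vocab_size), aligned with the response positions
    # responses: (bs, response_length), the sampled response token ids
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, dim=-1, index=responses.unsqueeze(-1)).squeeze(-1)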

3.2.4 compute ref_log_probs

The logic is identical to recomputing the actor log_probs, except that the reference model is used instead.

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if self.use_reference_policy:
    with timer("ref", timing_raw):
        ref_log_probs = self.actor_rollout_ref_wg.compute_ref_log_probs(batch)
        batch = batch.union(ref_log_probs)

# verl/workers/fsdp_workers.py -> FSDPWorker.compute_ref_log_probs
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_probs(self, data: DataProto):
    assert self._has_ref

    self._process_multi_modal_inputs(data)
    data = data.to(torch.cuda.current_device())

    if self._use_ref_param_offload:
        load_fsdp_model(self.ref_fsdp_module)

    data.meta_info["temperature"] = self.config.rollout.temperature
    with self.ulysses_sharding_manager:
        data = self.ulysses_sharding_manager.preprocess_data(data)
        output = self.ref_policy.compute_log_prob(data=data)
        output = DataProto.from_dict(tensors={"ref_log_probs": output})
        output = self.ulysses_sharding_manager.postprocess_data(output)

    # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    # unshard the root FSDP module
    if self.world_size > 1:
        self.ref_fsdp_module._handle.reshard(True)

    if self._use_ref_param_offload:
        offload_fsdp_model(self.ref_fsdp_module)

    output = output.to("cpu")
    return output

3.2.5 compute_advantage

First, compute the KL divergence:

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
    # apply kl penalty to reward
    batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
    metrics.update(kl_metrics)
else:
    batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# verl/trainer/ray_trainer.py -> apply_kl_penalty
def apply_kl_penalty(data: DataProto, kl_ctrl: KLController, kl_penalty="kl"):
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    response_mask = data.batch["response_mask"]

    # compute kl between ref_policy and current policy
    kld = compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
    kld = kld * response_mask  # (batch_size, response_length)

    data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

    current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()
    metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}

    # According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    return data, metrics

# verl/trainer/core_algos.py -> compute_kl
def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
    """Compute KL divergence given log_probs and ref_log_probs.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: torch.Tensor
        ref_log_probs: torch.Tensor
        kl_penalty: str

    Returns:
        kl_div: torch.Tensor

    """

    log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
    if kl_penalty == "kl":
        return log_probs - ref_log_probs
    
    if kl_penalty == "abs":
        return (log_probs - ref_log_probs).abs()

    if kl_penalty == "mse":
        return 0.5 * (log_probs - ref_log_probs).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html
    if kl_penalty == "low_var_kl":
        # For numerical stability
        kl = (ref_log_probs - log_probs).clamp(-20.0, 20.0)  
        kld = (kl.exp() - kl - 1).contiguous()
        return torch.clamp(kld, min=-10.0, max=10.0)

    if kl_penalty == "full":
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
  • low_var_kl is preferred (a numeric comparison of the estimators is shown below).

  • KL-divergence constraint on the reward:

  • Not applied: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

  • Applied: data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
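A quick numeric comparison of the estimators above for a single token with log_probs - ref_log_probs = 0.5:

import torch

log_probs = torch.tensor([0.0])
ref_log_probs = torch.tensor([-0.5])
delta = log_probs - ref_log_probs                  # 0.5

kl_naive = delta                                   # "kl":         0.5000
kl_abs = delta.abs()                               # "abs":        0.5000
kl_mse = 0.5 * delta.square()                      # "mse":        0.1250
k = (ref_log_probs - log_probs).clamp(-20.0, 20.0)
kl_low_var = k.exp() - k - 1                       # "low_var_kl": exp(-0.5) + 0.5 - 1 ≈ 0.1065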

Compute the advantage:

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
batch = compute_advantage(
    batch,
    adv_estimator=self.config.algorithm.adv_estimator,
    gamma=self.config.algorithm.gamma,
    lam=self.config.algorithm.lam,
)

# verl/trainer/ray_trainer.py -> compute_advantage
def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
    token_level_rewards = data.batch["token_level_rewards"]
    response_mask = data.batch["response_mask"]
    index = data.non_tensor_batch["uid"]
    if adv_estimator == AdvantageEstimator.GAE:
        values = data.batch["values"]
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards, values, response_mask, gamma, lam
        )
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards, response_mask, gamma
        )
    elif adv_estimator == AdvantageEstimator.REMAX:
        reward_baselines = data.batch["reward_baselines"]
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards, reward_baselines, response_mask
        )
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
    else:
        raise NotImplementedError

    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data

# verl/trainer/core_algos.py -> compute_grpo_outcome_advantage
@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:


    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
        id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
        id2std[idx] = torch.std(torch.tensor(id2score[idx]))

    for i in range(bsz):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
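A worked example of the group-wise normalization: one prompt with rollout.n = 4 and sequence-level rewards [1, 0, 1, 1]. The resulting scalar is then broadcast over the response tokens via response_mask, as in the code above.

import torch

scores = torch.tensor([1.0, 0.0, 1.0, 1.0])   # sequence-level rewards of one group
mean, std = scores.mean(), scores.std()        # mean = 0.75, std = 0.5 (unbiased, as in torch.std)
advantages = (scores - mean) / (std + 1e-6)
print(advantages)                              # ≈ tensor([ 0.5, -1.5,  0.5,  0.5])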

3.2.6 update_actor

Update the policy model:

# verl/trainer/ray_trainer.py -> RayPPOTrainer.fit
if self.config.trainer.critic_warmup <= self.global_step:
    with timer("update_actor", timing_raw):
        actor_output = self.actor_rollout_ref_wg.update_actor(batch)

    actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
    metrics.update(actor_metrics)

# verl/workers/fsdp_workers.py -> FSDPWorker.update_actor 
metrics = self.actor.update_policy(data=data)
# verl/workers/actor/dp_actor.py -> DataParallelPPOActor.update_policy
def update_policy(self, data: DataProto) -> Dict[str, Any]:
    self.actor_module.train()

    temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
    select_keys = ["responses""input_ids""attention_mask""position_ids"]
    select_keys.extend(["old_log_probs""ref_log_probs""advantages"])
    non_tensor_select_keys = ["multi_modal_inputs"]

    # Split to make minibatch iterator for updating the actor
    # See PPO paper for details. https://arxiv.org/abs/1707.06347
    mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

    metrics = defaultdict(list)
    for _ in range(self.config.ppo_epochs):
        if self.rank == 0:
            mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)

        for mini_batch in mini_batches:
            gradient_accumulation = (
                self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
            )
            micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
            if self.rank == 0:
                micro_batches = tqdm(micro_batches, desc="Update policy", position=2)

            for micro_batch in micro_batches:
                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                responses = model_inputs["responses"]
                response_length = responses.size(1)
                attention_mask = model_inputs["attention_mask"]
                response_mask = attention_mask[:, -response_length:]
                old_log_probs = model_inputs["old_log_probs"]
                advantages = model_inputs["advantages"]

                # all return: (bsz, response_length)
                log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)

                pg_loss, pg_metrics = core_algos.compute_policy_loss(
                    old_log_probs=old_log_probs,
                    log_probs=log_probs,
                    advantages=advantages,
                    response_mask=response_mask,
                    clip_ratio_low=self.config.clip_ratio_low,
                    clip_ratio_high=self.config.clip_ratio_high,
                    clip_ratio_dual=self.config.clip_ratio_dual,
                    loss_avg_mode=self.config.loss_avg_mode,
                )
                if self.config.use_kl_loss and "ref_log_probs" in model_inputs:
                    ref_log_probs = model_inputs["ref_log_probs"]
                    # compute kl loss
                    kld = core_algos.compute_kl(
                        log_probs=log_probs,
                        ref_log_probs=ref_log_probs,
                        kl_penalty=self.config.kl_penalty,
                    )
                    kl_loss = VF.masked_mean(kld, response_mask)
                    pg_loss = pg_loss + kl_loss * self.config.kl_coef
                    metrics["actor/kl_loss"] = kl_loss.detach().item()
                    metrics["actor/kl_coef"] = self.config.kl_coef

                loss = pg_loss / gradient_accumulation
                loss.backward()

                batch_metrics = {
                    "actor/pg_loss": pg_loss.detach().item(),
                    "actor/pg_clipfrac_higher": pg_metrics["pg_clipfrac_higher"],
                    "actor/pg_clipfrac_lower": pg_metrics["pg_clipfrac_lower"],
                    "actor/entropy_loss": pg_metrics["entropy_loss"],
                    "actor/ppo_kl": pg_metrics["ppo_kl"],
                }
                append_to_dict(metrics, batch_metrics)

            grad_norm = self._optimizer_step()
            append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})

    return metrics
  • Outer for loop: for _ in range(self.config.ppo_epochs)

  • The PPO/GRPO epoch loop. Once a batch has been sampled, the same data can be trained on for K epochs (typically 1-8) to squeeze out more signal; this is the "multiple epochs" of the PPO paper.

  • Middle for loop: for mini_batch in mini_batches

  • Iterates over the (reordered) mini-batches, corresponding to shuffling the large batch and running backward per mini-batch in the PPO formulation.

  • Inner for loop: for micro_batch in micro_batches

  • Splits each mini-batch into smaller micro-batches. loss.backward() is called inside the loop without an immediate step(); optimizer.step() only runs after the whole mini-batch has been consumed. This keeps the effective batch size large even when GPU memory can only hold a fraction (say 1/4) of a mini-batch at a time.

Compute the policy loss:

 # verl/trainer/core_algos.py -> compute_policy_loss
 def compute_policy_loss(
     old_log_probs: torch.Tensor,
     log_probs: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     clip_ratio_low: float,
     clip_ratio_high: float,
     clip_ratio_dual: float,
      loss_avg_mode: Literal["token", "seq"],
 ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:

     """Compute the clipped policy objective and related metrics for PPO.
 
     Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
 
     Args:
         old_log_prob: `(torch.Tensor)`
             shape: (bs, response_length)
         log_prob: `(torch.Tensor)`
             shape: (bs, response_length)
         advantages: `(torch.Tensor)`
             shape: (bs, response_length)
         response_mask: `(torch.Tensor)`
             shape: (bs, response_length)
         clip_ratio_low: (float)
             The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
         clip_ratio_high: (float)
             The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
         clip_ratio_dual: (float)
             The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
         loss_avg_mode: (Literal["token", "seq"])
             "token": average the loss in the whole batch
             "seq": average the loss in each sequence then average the mean of the means
 
     Returns:
         pg_loss: `a scalar torch.Tensor`
             policy gradient loss computed via PPO
         pg_clipfrac_higher: (float)
             a float number indicating the fraction of policy gradient loss being clipped to a higher value
         pg_clipfrac_lower: (float)
             a float number indicating the fraction of policy gradient loss being clipped to a lower value
         ppo_kl: (float)
             a float number indicating the mean KL divergence between the old policy and the new policy
         entropy_loss: (float)
             a float number indicating the mean entropy loss
 
     """

     negative_approx_kl = log_probs - old_log_probs
     # clamp negative_approx_kl to avoid nan kld
      negative_approx_kl = torch.clamp(negative_approx_kl, -20.0, 20.0)
     ratio = torch.exp(negative_approx_kl)
     # clamp the ratio before exp to avoid nan grad
     # see: https://github.com/pytorch/pytorch/issues/10729
     clipped_ratio = torch.exp(
         torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
     )
 
     # pg metrics
     metrics = {"ppo_kl": -negative_approx_kl}
     # use negative log probs as an estimator of entropy loss
     metrics["entropy_loss"] = average_loss(-log_probs, response_mask, mode=loss_avg_mode)
 
     pg_loss = -advantages * ratio  # -ratio * A
     pg_loss2 = -advantages * clipped_ratio  # -clip(ratio, 1-clip_low, 1+clip_high) * A
     pg_loss3 = -advantages * clip_ratio_dual  # -clip_dual * A
 
     clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2)  # clip if pg_loss < pg_loss2
     metrics["pg_clipfrac_higher"] = (pg_loss < pg_loss2).float()
     clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3)  # clip if pg_loss > pg_loss3 and adv < 0
      # choose the lower or upper bound depending on the sign of the advantage
     final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
     metrics["pg_clipfrac_lower"] = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()
 
     final_pg_loss = average_loss(final_pg_loss, response_mask, mode=loss_avg_mode)
     metrics = {k: VF.masked_mean(v, response_mask).detach().item() for k, v in metrics.items()}
     return final_pg_loss, metrics
