This commit is contained in:
Tong Li 2025-03-06 16:26:14 +08:00
parent 0f566cc2d4
commit 0cc0c843ed
3 changed files with 18 additions and 4 deletions

View File

@@ -357,7 +357,9 @@ def apply_chat_template_and_mask(
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
-    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning.\n"
+    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
     system_element = {
         "role": "system",

View File

@@ -1,6 +1,6 @@
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
+import os
 import ray
 import ray.util.collective as cc
 import torch
@@ -33,6 +33,8 @@ class BaseConsumer:
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
         microbatch_size: int = 1,
+        save_interval: int = 100,
+        save_dir: str = "./model"
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -44,6 +46,8 @@ class BaseConsumer:
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
         self.microbatch_size = microbatch_size
+        self.save_interval = save_interval
+        self.save_dir = save_dir
         assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // microbatch_size
@@ -116,6 +120,14 @@ class BaseConsumer:
                         pbar.set_postfix({"loss": loss})
                     i += 1
                 assert len(self.buffer) == 0
+                if (step + 1) % self.save_interval == 0:
+                    if self.rank == 0:
+                        print(f"Start saving policy model at step {step + 1}.")
+                    save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+                    self.booster.save_model(self.policy_model, save_path, shard=True)
+                    if self.rank == 0:
+                        print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                     print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     state_dict = self.state_dict()

View File

@@ -32,7 +32,7 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=4,
-        use_wandb=False,
+        use_wandb=True,
     ):
         super().__init__(
             num_producers,
@@ -79,7 +79,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if self.rank == 0:
+        if use_wandb and self.rank == 0:
             self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)

     def setup(self):