This commit is contained in:
Tong Li 2025-03-06 16:26:14 +08:00
parent 0f566cc2d4
commit 0cc0c843ed
3 changed files with 18 additions and 4 deletions

View File

@ -357,7 +357,9 @@ def apply_chat_template_and_mask(
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning.\n"
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
system_element = {
"role": "system",

View File

@ -1,6 +1,6 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional
import os
import ray
import ray.util.collective as cc
import torch
@ -33,6 +33,8 @@ class BaseConsumer:
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model"
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@ -44,6 +46,8 @@ class BaseConsumer:
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
self.save_interval = save_interval
self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
@ -116,6 +120,14 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
state_dict = self.state_dict()

View File

@ -32,7 +32,7 @@ class GRPOConsumer(BaseConsumer):
plugin_config,
microbatch_size=1,
num_generations=4,
use_wandb=False,
use_wandb=True,
):
super().__init__(
num_producers,
@ -79,7 +79,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if self.rank == 0:
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
def setup(self):