Merge pull request #6309 from hpcaitech/grpo-eval-dev

[feat] Support evaluation during training
YeAnbang 2025-05-16 16:11:23 +08:00 committed by GitHub
commit 3c42c0ce82
9 changed files with 312 additions and 87 deletions
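
This change adds periodic evaluation to the distributed GRPO pipeline: every `eval_interval` consumer steps, each producer runs single-sample rollouts (n = 1, temperature 0.6 by default) over one or more held-out JSONL datasets, scores them with the existing reward functions in a new eval mode, all-reduces per-task accuracy across producers, logs it to Weights & Biases, and appends the raw per-sample records to per-task JSONL files under an eval directory. Producer and consumer W&B runs share a generated group name, and the consumer's global step is piggybacked onto the weight sync so evaluation points line up with training steps.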

.gitignore vendored
View File

@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval

View File

@@ -93,7 +93,6 @@ class BaseConsumer:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         self.buffer = []
         self.recv_cnt = 0
-
     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -209,6 +208,8 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

View File

@@ -33,12 +33,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -84,6 +85,9 @@ class GRPOConsumer(BaseConsumer):
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name

         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -134,7 +138,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb

         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -145,13 +148,16 @@
     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -265,7 +271,6 @@
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
             self.total_sample_count += total_samples.item()
-
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
@@ -506,7 +511,7 @@
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-                self.wandb_run.log(metrics)
+                if self.wandb_run is not None:
+                    self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
@@ -521,7 +526,6 @@
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-
                 return loss_scalar, excessive_prompts_idx
             else:
                 return None, excessive_prompts_idx
@@ -530,4 +534,5 @@
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
+        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
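
The extra `consumer_global_step` entry piggybacks on the existing weight sync: the consumer adds it to its `state_dict()`, and the producer pops it back out before loading the weights. A minimal sketch of that round trip (the actual broadcast is elided, values are placeholders):

    import torch

    # Consumer side: append the current step to the synced state dict, as in state_dict() above.
    state_dict = {"some.weight": torch.zeros(2, 2)}
    state_dict["consumer_global_step"] = torch.tensor([123])

    # Producer side: recover the step, then hand the remaining tensors to load_state_dict().
    consumer_global_step = state_dict.pop("consumer_global_step").item()
    assert consumer_global_step == 123 and "consumer_global_step" not in state_dict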

View File

@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = SamplingParams(**generate_config)
+        self.generate_config = generate_config
+        self.sample_params = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
+        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
         out_tokens = []
         out_len = []
@@ -236,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                 log_probs.append(p)

         # pad them
-        max_len = self.generate_config.max_tokens
+        max_len = self.sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)

         for i, new_token_ids in enumerate(out_tokens):
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

View File

@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional

 import ray
@@ -34,7 +35,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,8 +51,11 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
@@ -60,12 +64,15 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -74,7 +81,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +90,14 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -108,9 +123,11 @@ def launch_distributed(
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
-            project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
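
All producer and consumer runs now share one `run_name` (derived from the backend, batch size, and sampling settings) plus a random `wandb_group_name`, so they appear as a single W&B group, while evaluation is driven by a per-task dict. A sketch of the new arguments only, with placeholder paths and values (the remaining `launch_distributed` parameters are unchanged and omitted):

    import uuid

    eval_dataset_config = {
        "math_test": {"path": "data_eval_math.jsonl", "max_length": 512, "system_prompt": None},
        "aime": {"path": "data_eval_aime.jsonl", "max_length": 512, "system_prompt": None},
    }
    run_name = "vllm_bs_64_temp_1.0_top_p_0.90"  # illustrative; built from backend/batch/sampling settings
    wandb_group_name = str(uuid.uuid4())         # shared by the producer eval run and the consumer run

    # launch_distributed(
    #     ...,
    #     train_dataset_config={"path": "data.jsonl"},
    #     eval_dataset_config=eval_dataset_config,
    #     eval_interval=100,                        # evaluate every 100 consumer steps
    #     eval_save_dir="./eval",                   # per-task JSONL dumps go here
    #     eval_generation_config={"temperature": 0.6},
    # )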

View File

@@ -1,9 +1,16 @@
+import copy
+import os
 from typing import Any, Dict, Optional

 import ray
 import ray.util.collective as cc
 import torch
+import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
@@ -11,7 +18,12 @@ from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send
+from .utils import pre_send, safe_append_to_jsonl_file
+
+try:
+    from vllm import SamplingParams
+except ImportError:
+    LLM = None


 class BaseProducer:
@@ -22,7 +34,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        dataset_config: Dict[str, Any],
+        train_dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -30,6 +42,14 @@ class BaseProducer:
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -39,11 +59,31 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.latest_eval_step = -1

-        self.dataset_config = dataset_config
+        self.train_dataset_config = train_dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
+        self.consumer_plugin_config = consumer_plugin_config
+        self.eval_interval = eval_interval
+        self.eval_save_dir = eval_save_dir
+        self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )
+
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
+            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
         if tokenizer_config is None:
@@ -55,13 +95,13 @@ class BaseProducer:
         self.tokenizer.padding_side = "left"

         # init dataloader
-        dataset_path = dataset_config.pop("path")
-        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
-        self.dataloader = DataLoader(
-            self.dataset,
+        train_dataset_path = train_dataset_config.pop("path")
+        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+        self.train_dataloader = DataLoader(
+            self.train_dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.dataset,
+                self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
@@ -71,6 +111,36 @@ class BaseProducer:
             num_workers=4,
             drop_last=True,
         )
+
+        self.eval_dataset_config = eval_dataset_config
+        if self.eval_dataset_config is not None:
+            self.eval_dataloaders = {}
+            for eval_task_name in self.eval_dataset_config:
+                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
+                eval_dataset = RawConversationDataset(
+                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+                )
+                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+                self.eval_dataloaders[eval_task_name] = DataLoader(
+                    eval_dataset,
+                    batch_size=microbatch_size,
+                    sampler=DistributedSampler(
+                        eval_dataset,
+                        num_replicas=num_producers,
+                        rank=producer_idx,
+                        shuffle=False,
+                        drop_last=False,
+                        seed=42,
+                    ),
+                )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+        else:
+            raise ValueError("eval_dataset_config is not defined")

         self.device = get_current_device()

         # init backend
@@ -82,6 +152,12 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
@@ -96,29 +172,74 @@ class BaseProducer:
         raise NotImplementedError

     def loop(self) -> None:
-        num_update_per_episode = len(self.dataloader) // self.num_microbatches
+        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
         print(
-            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
         )
         for episode in range(self.num_episodes):
-            self.dataloader.sampler.set_epoch(episode)
-            for i, batch in enumerate(self.dataloader):
+            self.train_dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.train_dataloader):
                 if i >= num_valid_microbatches:
                     break
+                if self.eval_interval > 0 and self.eval_dataset_config is not None:
+                    if (
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
+                        to_log_msg = {}
+                        self.eval_mode = True
+                        for eval_task_name in self.eval_dataloaders:
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
+                                )
+                            eval_results = []
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
+                            for eval_batch in tqdm.tqdm(
+                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+                            ):
+                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+                                eval_results = eval_results + [
+                                    self.evaluation_function(
+                                        eval_outputs["input_ids"][m][n],
+                                        eval_outputs["gt_answer"][m][n],
+                                        eval_outputs["response_idx"][m][n],
+                                        tokenizer=self.tokenizer,
+                                        eval_mode=True,
+                                    )
+                                    for m in range(eval_outputs["input_ids"].size(0))
+                                    for n in range(eval_outputs["input_ids"].size(1))
+                                ]
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
+                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
+                            # save eval results
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
+                            )
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [
-                        (
-                            self.model.generate_config["temperature"]
-                            if isinstance(self.model.generate_config.temperature, dict)
-                            else self.model.generate_config.temperature
-                        )
-                    ]
-                    * outputs["input_ids"].size(0)
+                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
@@ -142,6 +263,8 @@ class BaseProducer:
                             state_dict = ray_broadcast_tensor_dict(
                                 None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                             )
+                            if "consumer_global_step" in state_dict:
+                                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                             self.load_state_dict(state_dict)
                     else:
                         print(
@@ -150,6 +273,8 @@ class BaseProducer:
                         state_dict = ray_broadcast_tensor_dict(
                             None, self.num_producers, device=self.device, group_name="sync_model"
                         )
+                        if "consumer_global_step" in state_dict:
+                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
@@ -159,13 +284,12 @@ class BaseProducer:
                     self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
-                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                             "temperature"
                         ] + ratio * 0.9
@@ -179,7 +303,7 @@ class SimpleProducer(BaseProducer):
         num_consumer_procs,
         num_episodes,
         batch_size,
-        dataset_config,
+        train_dataset_config,
         dataloaders_config,
         model_config,
         generate_config,
@@ -188,6 +312,14 @@ class SimpleProducer(BaseProducer):
         backend="transformers",
         num_generations: int = 8,
         consumer_plugin_config=None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -195,7 +327,7 @@ class SimpleProducer(BaseProducer):
             num_consumer_procs,
             num_episodes,
             batch_size,
-            dataset_config,
+            train_dataset_config,
             dataloaders_config,
             model_config,
             generate_config,
@@ -203,15 +335,43 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
             consumer_plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=evaluation_function_type,
+            eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
+        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
+        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        if self.producer_idx == 1:
-            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+        if self.producer_idx == 0 and not self.eval_mode:
+            wandb_rollout_data = self.wandb_rollout_data + [
+                [
+                    str(self.consumer_global_step),
+                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
+                ]
+            ]
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                self.wandb_rollout_data = wandb_rollout_data
+                self.latest_rollout_log_step = self.consumer_global_step
+                self.wandb_run.log(
+                    {
+                        "rollout/rollout_examples": wandb.Table(
+                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                        )
+                    }
+                )
         return rollouts

     def load_state_dict(self, state_dict):
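
Each producer evaluates only its own shard of every eval set (via `DistributedSampler`), so per-task accuracy is reduced across the `producer_group` before logging: index 0 of the statistics tensor carries the number of correct answers, index 1 the number of samples. A self-contained sketch of that reduction with plain tensors and placeholder counts (the Ray collective call is shown as a comment):

    import torch

    # Per-producer counts for one eval task: [num_correct, num_total] on this shard.
    local_stats = torch.tensor([37.0, 50.0])
    other_stats = torch.tensor([41.0, 50.0])  # what another producer would contribute

    # In the producer this is: allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
    global_stats = local_stats + other_stats

    accuracy = global_stats[0].item() / global_stats[1].item()
    print(f"eval accuracy: {accuracy:.3f}")  # 78 correct / 100 samples = 0.780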

View File

@@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
@@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     reward = reward + length_reward

-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-
-
-def gsm8k_reward_fn(input_ids, **kwargs):
-    gt_answer = kwargs["gt_answer"]
-    tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
-    reward = torch.tensor(0.0).to(input_ids.device)
-    if gt_answer is None:
-        return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    final_answer, processed_str = extract_solution(decoded_final_answer)
-
-    is_valid = True
-    try:
-        int(final_answer.strip())
-    except Exception:
-        is_valid = False
-
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not is_valid or not format_valid:
-        return reward
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
-        reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 9.0
-        return reward
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }


 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return reward
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += acc_score

     reward = reward + length_reward

-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
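
In training mode both reward functions still return a `[reward, format_acc, ans_acc]` tensor; with `eval_mode=True` they instead return a JSON-serializable record that the producer appends to the per-task eval file. A sketch of what one eval-mode record looks like, with illustrative field values:

    # Fields returned by math_reward_fn / boxed_math_reward_fn when eval_mode=True;
    # these are exactly what gets dumped to "<task>_training_step_<step>.jsonl".
    record = {
        "prompt": "Solve: 2 + 2 = ?",        # decoded prompt tokens (input_ids[:s])
        "prediction": "... \\boxed{4}",       # decoded response tokens (input_ids[s:e+1])
        "gold": "4",                          # decoded ground-truth answer
        "parsed": "4",                        # answer extracted from the prediction
        "format_valid": 1.0,                  # format_acc.item()
        "ans_valid": 1.0,                     # ans_acc.item()
    }
    accuracy_contribution = 1 if record["ans_valid"] == 1 else 0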

View File

@@ -1,6 +1,9 @@
+import json
+import os
 from typing import Any, Dict, List

 import torch
+from filelock import FileLock

 from colossalai.shardformer.layer.loss import dist_log_prob
@@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     tensor = tensor * mask
     return tensor.sum(dim=dim)
+
+
+def safe_append_to_jsonl_file(file_path, data):
+    with FileLock(file_path + ".lock"):
+        # Ensure file exists
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        with open(file_path, "a", encoding="utf8") as f:
+            for entry in data:
+                json_line = json.dumps(entry, ensure_ascii=False)
+                f.write(json_line + "\n")
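
The helper serializes appends with `filelock.FileLock`, so several producer processes can dump eval records into the same file without interleaving lines. A minimal usage sketch (the module path is assumed to match where these utils live; the output path is a placeholder):

    from coati.distributed.utils import safe_append_to_jsonl_file  # assumed import path

    records = [
        {"prompt": "2 + 2 = ?", "prediction": "4", "ans_valid": 1.0},
        {"prompt": "3 * 3 = ?", "prediction": "8", "ans_valid": 0.0},
    ]
    # Creates ./eval/ if needed, acquires math_test_training_step_100.jsonl.lock,
    # then appends one JSON object per line.
    safe_append_to_jsonl_file("./eval/math_test_training_step_100.jsonl", records)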

View File

@@ -1,4 +1,5 @@
 import argparse
+import json
 import os

 import ray
@@ -9,7 +10,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-    parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
+    parser.add_argument(
+        "-ed",
+        "--eval-dataset",
+        type=str,
+        default='{"eval task name":"data_eval.jsonl"}',
+        help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
+        For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
+        The key is the task name, and the value is the path to the jsonl file",
+    )
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")

     # Distributed training parameters
@@ -94,11 +104,20 @@
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei training steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
     parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
+    parser.add_argument(
+        "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
+    )

     args = parser.parse_args()
     if args.train_minibatch_size is None:
@@ -148,6 +167,7 @@
                 stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
@@ -166,6 +186,7 @@
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
@@ -208,7 +229,7 @@
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        dataset_config={
+        train_dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
             "system_prompt": args.system_prompt,
@@ -238,4 +259,11 @@
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
+        eval_dataset_config={
+            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+            for k, v in json.loads(args.eval_dataset).items()
+        },
+        eval_interval=args.eval_interval,
+        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        eval_generation_config=eval_generation_config,
     )
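
The `--eval-dataset` flag takes a JSON object mapping a task name to a JSONL path; the script turns it into the per-task `eval_dataset_config` passed to `launch_distributed`. A small sketch of that conversion (paths and the prompt length are placeholders):

    import json

    eval_dataset_arg = '{"math_test": "data_eval_math.jsonl", "gsm8k": "data_eval_gsm8k.jsonl"}'
    max_prompt_tokens = 512
    system_prompt = None

    eval_dataset_config = {
        task: {"path": path, "max_length": max_prompt_tokens, "system_prompt": system_prompt}
        for task, path in json.loads(eval_dataset_arg).items()
    }
    print(eval_dataset_config["math_test"]["path"])  # data_eval_math.jsonl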