[feat][merge] Support one-behind to reduce bubble time. Add profiling code.

This commit is contained in:
xysheng-colossal 2025-06-26 16:39:05 +08:00
parent 9379a89677
commit d9b5f10d82
17 changed files with 1460 additions and 268 deletions

.gitignore
View File

@ -167,3 +167,9 @@ applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
applications/ColossalChat/rollouts
applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png

View File

@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
}
# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
if "messages" in chat:
gt_answer = chat.get("gt_answer", None)
test_cases = chat.get("test_cases", None)
chat = [chat["messages"]]
tokens = []
@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
if gt_answer is not None:
gt_answer = tokenizer.encode(
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
elif test_cases is not None:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"test_cases": test_cases,
}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@ -440,3 +442,20 @@ class RawConversationDataset(Dataset):
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]
def collate_fn_grpo(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
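A minimal usage sketch (not part of the diff), assuming `train_dataset` yields the dicts produced by apply_chat_template_and_mask above; the batch size is illustrative.

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_grpo)
batch = next(iter(loader))
# input_ids / attention_mask / labels are stacked tensors; "gt_answer" and
# "test_cases" remain plain Python lists when present in the samples.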

View File

@ -6,6 +6,7 @@ import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from tqdm import tqdm
from transformers import AutoModelForCausalLM
@ -36,6 +37,8 @@ class BaseConsumer:
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
enable_profiling: bool = False,
n_behind: int = 0,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@ -49,6 +52,7 @@ class BaseConsumer:
self.minibatch_size = minibatch_size
self.save_interval = save_interval
self.save_dir = save_dir
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
@ -58,6 +62,7 @@ class BaseConsumer:
self.device = "npu"
self.lr_scheduler = None
self.generate_config = generate_config
self.n_behind = n_behind
def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@ -99,6 +104,7 @@ class BaseConsumer:
self.buffer = []
self.recv_cnt = 0
self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@ -106,6 +112,37 @@ class BaseConsumer:
def step(self, step_idx: int, **kwargs) -> Optional[float]:
raise NotImplementedError
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
"""
Prepare a mini-batch from the effective group to raw group mapping.
This method is used to create a mini-batch for training.
"""
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
return batch, raw_mini_batches_metric_dict
def calculate_effective_group_to_raw_group_mapping(self, step):
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
return effective_group_to_raw_group_mapping
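A toy sketch (not part of the diff) of the one-behind gating above: each buffer entry carries the step at which it was received, and with n_behind=1 a group only becomes effective one consumer step later, so the producer can keep generating while the consumer trains on buffered data.

n_behind = 1
# (group_data, recv_step); group_data is None for groups dropped by dynamic filtering
buffer = [({"id": 0}, 0), (None, 0), ({"id": 2}, 1)]

def effective_mapping(step):
    mapping = {}
    for buffer_idx, (group, recv_step) in enumerate(buffer):
        if group is not None and recv_step <= step - n_behind:
            mapping[len(mapping)] = buffer_idx
    return mapping

print(effective_mapping(step=1))  # {0: 0}: only the step-0 group is trainable yet
print(effective_mapping(step=2))  # {0: 0, 1: 2}: the step-1 group becomes effective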
def loop(self) -> None:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@ -117,32 +154,65 @@ class BaseConsumer:
disable=self.rank != 0,
) as pbar:
for step in pbar:
torch.npu.reset_peak_memory_stats()
i = 0
self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
for _ in range(self.num_recv_per_update):
if self.n_behind > 0:
# After syncing the model, do not wait for more data to arrive (rollouts take time); train on buffered data first.
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
self.profiler.log(
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
)
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(
effective_group_to_raw_group_mapping
)
effective_group_to_raw_group_mapping = (
self.calculate_effective_group_to_raw_group_mapping(step=step)
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before
- self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.profiler.enter(f"recv_broadcast_data_P{r}")
raw_batch = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
self.profiler.exit(f"recv_broadcast_data_P{r}")
# calculate group reward and related metrics before filtering. As only the filtered groups will be used for training (an incomplete subset),
# we need to calculate the metrics before filtering here for logging
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
raw_batch_with_reward = self.calculate_reward(
{k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
)
raw_batch_with_reward = {
raw_batch = {
k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
for k, v in raw_batch_with_reward.items()
for k, v in raw_batch.items()
}
# [batch_size, num_generations] -> [batch_size]
reward = raw_batch_with_reward["reward"][:, :, 0]
format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
reward = raw_batch["reward"][:, :, 0]
format_acc = raw_batch["format_acc"][:, :, 0]
ans_acc = raw_batch["ans_acc"][:, :, 0]
response_len = (
raw_batch_with_reward["response_idx"][:, :, 1]
- raw_batch_with_reward["response_idx"][:, :, 0]
+ 1
raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
).type(torch.float32)
effective_group_mask = None
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@ -151,8 +221,9 @@ class BaseConsumer:
effective_group_mask = torch.logical_and(
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
)
raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
raw_batch = unbind_batch(raw_batch) # List[Dict[str, torch.Tensor]]
for group_idx, group_with_reward in enumerate(raw_batch):
self.buffer.append(
[
(
@ -164,56 +235,43 @@ class BaseConsumer:
format_acc[group_idx],
ans_acc[group_idx],
response_len[group_idx],
step,
]
)
if effective_group_mask is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)
if self.n_behind == 0:
# If n_behind is 0, we start training after receiving data from producers.
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
self.profiler.log(
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
)
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
@ -231,14 +289,17 @@ class BaseConsumer:
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind
):
if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
torch.npu.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
@ -254,7 +315,14 @@ class BaseConsumer:
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
torch.npu.empty_cache()
self.profiler.exit("sync_model")
self.profiler.log(f"Peak memory usage: {torch.npu.max_memory_allocated() / 1024 ** 2:.2f} MB")
self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
def __del__(self):
if hasattr(self, "profiler"):
self.profiler.close()
@ray.remote

View File

@ -1,14 +1,12 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional
from typing import Any, Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -40,6 +38,8 @@ class GRPOConsumer(BaseConsumer):
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
enable_profiling: bool = False,
n_behind: int = 0,
):
print(f"Using GRPO config: {grpo_config}")
if (
@ -65,6 +65,8 @@ class GRPOConsumer(BaseConsumer):
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@ -119,20 +121,7 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = grpo_config.get("response_format_tags", None)
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
],
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
grpo_config.get("response_format_tags", None)
self.global_step = 0
self.lr_scheduler = CosineAnnealingWarmupLR(
@ -295,12 +284,11 @@ class GRPOConsumer(BaseConsumer):
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
@ -323,11 +311,11 @@ class GRPOConsumer(BaseConsumer):
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
action_log_probs = memory_efficient_logprob(
action_logits,
inputs["input_ids"],
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
if "reference_action_log_probs" in inputs:
per_token_kl = (
@ -370,16 +358,15 @@ class GRPOConsumer(BaseConsumer):
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
action_log_probs = memory_efficient_logprob(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
if self.policy_loss_fn.beta > 0:
@ -388,11 +375,11 @@ class GRPOConsumer(BaseConsumer):
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_action_log_probs = memory_efficient_logprob(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
@ -498,40 +485,6 @@ class GRPOConsumer(BaseConsumer):
else:
return None
def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate the group reward for the given rollout group.
Args:
rollout_group (Dict[str, Any]):
a group of samples generated by the model from the same prompt
contain the following keys:
"input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
"attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
"action_mask": torch.Tensor, [num_of_generation, response_length]
"action_log_probs": torch.Tensor, [num_of_generation, response_length]
"response_idx": int, torch.Tensor, [num_of_generation, 2]
"gt_answer": torch.Tensor, [num_of_generation, 128]
"temperature": torch.Tensor, [] (scalar)
Returns:
Dict[str, Any]: The new group data with calculated reward.
"""
reward_model_output = self.reward_model(
rollout["input_ids"],
gt_answer=rollout["gt_answer"],
response_idx=rollout["response_idx"],
)
# [num_of_generation]
reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
rollout["reward"] = reward.view((-1, 1))
rollout["format_acc"] = format_acc.view((-1, 1))
rollout["ans_acc"] = ans_acc.view((-1, 1))
return rollout
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()

View File

@ -74,9 +74,8 @@ class TransformersInferenceBackend(BaseInferenceBackend):
micro_batch_size = input_ids.size(0)
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
gt_answer = None
if "gt_answer" in kwargs:
gt_answer = kwargs.pop("gt_answer")
gt_answer = kwargs.pop("gt_answer", None)
test_cases = kwargs.pop("test_cases", None)
if self.num_generations > 1:
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@ -116,8 +115,9 @@ class TransformersInferenceBackend(BaseInferenceBackend):
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
if gt_answer is not None:
# repeat gt_answer for each prompt.
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = gt_answer
if test_cases is not None:
data["test_cases"] = test_cases
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
@ -270,11 +270,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
}
data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
if "gt_answer" in kwargs:
data["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
data["test_cases"] = kwargs["test_cases"]
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:

View File

@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines) - 1
return len(lines)
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
@ -36,7 +36,6 @@ def launch_distributed(
train_batch_size: int,
train_minibatch_size: int,
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
@ -57,6 +56,8 @@ def launch_distributed(
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
n_behind: int = 0,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@ -79,6 +80,11 @@ def launch_distributed(
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)
# Attention: Ray uses a complex scheduling method that considers various factors including load balancing.
# When requesting resources, it is not guaranteed that the resource comes from a node with a lower node id,
# which goes against the design principle of our implementation, so we need to force the scheduling manually,
# allocating the producers to nodes with lower node ids and the consumers to resources on nodes with higher
# node ids. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
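A minimal sketch (not part of the diff) of the NodeAffinitySchedulingStrategy API referenced above; the actor class and node choice are illustrative.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

@ray.remote
class Worker:
    def node(self):
        return ray.get_runtime_context().get_node_id()

target_node_id = ray.nodes()[0]["NodeID"]  # e.g. the node picked for a producer
worker = Worker.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)
).remote()
print(ray.get(worker.node.remote()))  # matches target_node_id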
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
@ -104,7 +110,6 @@ def launch_distributed(
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
producer = SimpleProducer.options(
# num_cpus=1,
# num_cpus=num_proc_per_producer,
@ -121,7 +126,6 @@ def launch_distributed(
num_episodes=num_episodes,
batch_size=inference_batch_size,
train_dataset_config=train_dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
@ -131,8 +135,7 @@ def launch_distributed(
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
response_format_tags=grpo_config["response_format_tags"],
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
@ -140,6 +143,8 @@ def launch_distributed(
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
@ -185,6 +190,8 @@ def launch_distributed(
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])

View File

@ -8,8 +8,10 @@ import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from ray.util.collective import allreduce
from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@ -34,7 +36,6 @@ class BaseProducer:
num_episodes: int,
batch_size: int,
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
@ -43,14 +44,15 @@ class BaseProducer:
consumer_plugin_config: Dict[str, Any] = None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
n_behind: int = 0,
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@ -61,6 +63,7 @@ class BaseProducer:
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
self.latest_eval_step = -1
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
self.train_dataset_config = train_dataset_config
self.model_config = model_config
@ -73,6 +76,14 @@ class BaseProducer:
self.eval_mode = False
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.n_behind = n_behind
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
if producer_idx == 0:
if os.path.exists(rollout_log_file):
raise ValueError(
@ -118,7 +129,16 @@ class BaseProducer:
),
num_workers=4,
drop_last=True,
collate_fn=collate_fn_grpo,
)
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None:
@ -140,18 +160,17 @@ class BaseProducer:
drop_last=False,
seed=42,
),
collate_fn=collate_fn_grpo,
)
if evaluation_function_type == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif evaluation_function_type == "boxed":
self.evaluation_function = boxed_math_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags
else:
print("No eval dataset provided, skip eval")
self.device = "npu"
self.reward_model = VerifiableReward(
reward_fns=[self.evaluation_function], # multiple reward functions can be added here
tokenizer=self.tokenizer,
tags=self.response_format_tags,
**reward_model_kwargs,
)
# init backend
if backend in BACKEND_MAP:
@ -212,7 +231,13 @@ class BaseProducer:
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
eval_outputs["gt_answer"][m][n],
eval_outputs[
(
"test_cases"
if self.grpo_config["reward_fn_type"] == "code"
else "gt_answer"
)
][m],
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
@ -244,23 +269,71 @@ class BaseProducer:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
self.profiler.enter("rollout")
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
response_idx=outputs["response_idx"].view((-1, 2)),
)
else:
gt_answer = []
for prompt_id in range(bs):
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
gt_answer=gt_answer,
response_idx=outputs["response_idx"].view((-1, 2)),
)
outputs["reward"] = (
torch.tensor([value[0] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["format_acc"] = (
torch.tensor([value[1] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["ans_acc"] = (
torch.tensor([value[2] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
if "gt_answer" in outputs:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
self.profiler.exit("calculate_reward")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
self.profiler.enter("send_broadcast_data")
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
self.profiler.exit("send_broadcast_data")
if (
(i + 1) % self.num_microbatches == 0
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
):
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.sleep()  # evict the KV cache to avoid OOM
# don't sync model for last iteration
torch.cuda.empty_cache()
torch.npu.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
@ -283,8 +356,9 @@ class BaseProducer:
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
del state_dict
torch.cuda.empty_cache()
torch.npu.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
@ -300,6 +374,9 @@ class BaseProducer:
"temperature"
] + ratio * 0.9
def __del__(self):
self.profiler.close()
@ray.remote
class SimpleProducer(BaseProducer):
@ -311,7 +388,6 @@ class SimpleProducer(BaseProducer):
num_episodes,
batch_size,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config=None,
@ -321,8 +397,7 @@ class SimpleProducer(BaseProducer):
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
@ -330,6 +405,8 @@ class SimpleProducer(BaseProducer):
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
n_behind: int = 0,
):
super().__init__(
producer_idx,
@ -338,7 +415,6 @@ class SimpleProducer(BaseProducer):
num_episodes,
batch_size,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config,
@ -347,14 +423,15 @@ class SimpleProducer(BaseProducer):
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
response_format_tags=response_format_tags,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)

View File

@ -0,0 +1,37 @@
import os
import time
class CustomProfiler:
def __init__(self, name, disabled=True):
self.disabled = disabled
if not disabled:
self.name = name
self.pid = os.getpid()
self.file = open(f"{name}.prof", "w")
def _log(self, message):
if self.disabled:
return
current_time = time.time()
self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()
def log(self, message):
if self.disabled:
return
current_time = time.time()
self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()
def enter(self, event_name):
self._log(f"Enter {event_name}")
def exit(self, event_name):
self._log(f"Exit {event_name}")
def close(self):
if self.disabled:
return
self.file.close()
print(f"Profiler data written to {self.name}.prof")

View File

@ -0,0 +1,669 @@
# Code from the verl Project (https://github.com/agentica-project/rllm),
# which itself is adapted from Prime (https://github.com/PRIME-RL/PRIME)
#
# Copyright 2024 ByteDance Group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import faulthandler
import json
import platform
# to run the solution files we're using a timing based approach
import signal
import sys
import traceback
# used for debugging to time steps
from datetime import datetime
from enum import Enum
# for capturing the stdout
from io import StringIO
# used for testing the code that reads from input
from unittest.mock import mock_open, patch
import numpy as np
from pyext import RuntimeModule
def truncatefn(s, length=300):
assert isinstance(s, str)
if len(s) <= length:
return s
return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
class CODE_TYPE(Enum):
call_based = 0
standard_input = 1
# used to capture stdout as a list
# from https://stackoverflow.com/a/16571630/6416660
# alternative use redirect_stdout() from contextlib
class Capturing(list):
def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = StringIO()
# Make closing the StringIO a no-op
self._stringio.close = lambda x: 1
return self
def __exit__(self, *args):
self.append(self._stringio.getvalue())
del self._stringio # free up some memory
sys.stdout = self._stdout
def only_int_check(val):
return isinstance(val, int)
def string_int_check(val):
return isinstance(val, str) and val.isdigit()
def combined_int_check(val):
return only_int_check(val) or string_int_check(val)
def clean_traceback(error_traceback):
file_start = error_traceback.find('File "<string>"')
# print(file_start)
error_traceback = "Traceback (most recent call last):\n " + error_traceback[file_start:]
return error_traceback
def run_test(in_outs, test=None, debug=False, timeout=15):
"""
if test (the generated code) is not None, it will try to run the code.
otherwise it will just return an input and output pair.
"""
# Disable functionalities that can make destructive changes to the test.
reliability_guard()
if debug:
print(f"start = {datetime.now().time()}")
if in_outs:
if in_outs.get("fn_name") is None:
which_type = CODE_TYPE.standard_input # Standard input
method_name = None
else:
which_type = CODE_TYPE.call_based # Call-based
method_name = in_outs["fn_name"]
if debug:
print(f"loaded input_output = {datetime.now().time()}")
if test is None:
raise AssertionError("should not happen: test code is none")
elif test is not None:
results = []
sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n" # noqa: E501
if debug:
print(f"loading test code = {datetime.now().time()}")
if which_type == CODE_TYPE.call_based:
sol += test
if debug:
print(f"sol = {sol}")
signal.alarm(timeout)
try:
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
tmp = tmp_sol if "class Solution" not in test else tmp_sol.Solution()
signal.alarm(0)
except Exception as e:
signal.alarm(0)
error_traceback = traceback.format_exc()
if debug:
print(f"type 0 compilation error = {e}")
results.append(-2)
return results, {
"error": repr(e),
# "error_code": -1,
# "error_message": "Compilation Error",
"traceback": clean_traceback(error_traceback),
}
signal.alarm(0)
elif which_type == CODE_TYPE.standard_input:
# sol
# if code has if __name__ == "__main__": then remove it
try:
astree = ast.parse(test)
last_block = astree.body[-1]
if isinstance(last_block, ast.If):
condition = last_block.test
if ast.unparse(condition).strip() == "__name__ == '__main__'":
test = ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body)
except Exception:
pass
tmp_test = test.split("\n")
new_test = []
for x in tmp_test:
if (not x.startswith("from ")) and (not x.startswith("import ")):
new_test.append("\t" + x + "\n")
else:
new_test.append(x + "\n")
tmp_test = new_test
new_test = ""
started = False
for i in tmp_test:
if i.startswith("\t") and not started:
new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
new_test += "def code():\n"
new_test += i
started = True
elif started and ((i.startswith("from ")) or (i.startswith("import "))):
new_test += "\t" + i
else:
new_test += i
tmp_test = new_test
sol += tmp_test
if debug:
print(f"sol = {sol}")
method_name = "code"
signal.alarm(timeout)
try:
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
tmp = tmp_sol
signal.alarm(0)
except Exception as e:
signal.alarm(0)
error_traceback = traceback.format_exc()
if debug:
print(f"type 1 compilation error = {e}")
results.append(-2)
return results, {
"error": repr(e),
# "error_code": -1,
# "error_message": "Compilation Error",
"traceback": clean_traceback(error_traceback),
}
signal.alarm(0)
if debug:
print(f"get method = {datetime.now().time()}")
try:
method = getattr(tmp, method_name) # get_attr second arg must be str
except Exception:
signal.alarm(0)
error_traceback = traceback.format_exc()
error_info = sys.exc_info()
print(f"unable to get function error = {error_info}")
results.append(-2)
return results, {
"error": repr(error_info),
# "error_code": -1,
# "error_message": "Unable to extract code",
"traceback": clean_traceback(error_traceback),
}
for index, inputs in enumerate(in_outs["inputs"]):
raw_inputs = inputs
raw_outputs = in_outs["outputs"][index]
if which_type == CODE_TYPE.call_based:
inputs = [json.loads(line) for line in inputs.split("\n")]
in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])
truncate_line_size = 300 // (raw_inputs.count("\n") + 1)
raw_inputs = "\n".join(
[truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")]
)
raw_outputs = truncatefn(raw_outputs, 200)
else:
raw_inputs = truncatefn(raw_inputs)
raw_outputs = truncatefn(raw_outputs, 200)
# JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
try:
if isinstance(inputs[0], dict):
inputs = [{int(k): v for k, v in inputs[0].items()}]
except Exception:
pass
try:
if isinstance(in_outs["outputs"][index], dict):
in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}]
except Exception:
pass
try:
if isinstance(in_outs["outputs"][index][0], dict):
in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}]
except Exception:
pass
if debug:
print(
f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}"
)
if which_type == CODE_TYPE.call_based: # Call-based
signal.alarm(timeout)
faulthandler.enable()
try:
output = method(*inputs)
raw_true_output = output
raw_true_output_copy = json.dumps(output)
raw_true_output_copy = truncatefn(raw_true_output_copy, 200)
# ground truth sequences are not tuples
if isinstance(output, tuple):
output = list(output)
tmp_result = output == in_outs["outputs"][index]
if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
# ground truth sequences are not tuples
try:
if isinstance(output[0], tuple):
tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
except Exception:
pass
results.append(tmp_result)
if tmp_result is not True:
return results, {
"output": raw_true_output_copy,
"expected": raw_outputs,
"inputs": raw_inputs,
# "error_code": -2,
"error_message": "Wrong Answer",
}
# reset the alarm
signal.alarm(0)
except Exception as e:
signal.alarm(0)
error_traceback = traceback.format_exc()
faulthandler.disable()
if debug:
print(f"Standard input runtime error or time limit exceeded error = {e}")
results.append(-1)
return results, {
"error": repr(e),
"traceback": clean_traceback(error_traceback),
}
faulthandler.disable()
signal.alarm(0)
if debug:
print(
f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
)
elif which_type == CODE_TYPE.standard_input: # Standard input
faulthandler.enable()
passed = False
if isinstance(inputs, list):
inputs = "\n".join(inputs)
if isinstance(in_outs["outputs"][index], list):
in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])
signal.alarm(timeout)
with Capturing() as output:
try:
call_method(method, inputs)
# reset the alarm
signal.alarm(0)
passed = True
except Exception as e:
# runtime error or took too long
signal.alarm(0)
error_traceback = traceback.format_exc()
print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
results.append(-1)
return results, {
"error": repr(e),
"traceback": clean_traceback(error_traceback),
}
signal.alarm(0)
raw_true_output = output[0]
raw_true_output_copy = truncatefn(raw_true_output, 200)
output = raw_true_output.splitlines()
if not passed:
if debug:
nl = "\n"
if not isinstance(inputs, list):
print(
f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
)
else:
print(
f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
)
continue
if passed and debug:
print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
if custom_compare_(output, in_outs["outputs"][index]):
tmp_result = True
results.append(tmp_result)
continue
# ground truth sequences are expressed as lists not tuples
if isinstance(output, tuple):
output = list(output)
tmp_result = False
try:
tmp_result = output == [in_outs["outputs"][index]]
if isinstance(in_outs["outputs"][index], list):
tmp_result = tmp_result or (output == in_outs["outputs"][index])
if isinstance(output[0], str):
tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
except Exception as e:
if debug:
print(f"Failed check1 exception = {e}")
if tmp_result is True:
results.append(tmp_result)
continue
# try one more time without \n
if isinstance(in_outs["outputs"][index], list):
for tmp_index, i in enumerate(in_outs["outputs"][index]):
in_outs["outputs"][index][tmp_index] = i.split("\n")
in_outs["outputs"][index][tmp_index] = [
x.strip() for x in in_outs["outputs"][index][tmp_index] if x
]
else:
in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index]))
try:
tmp_result = output == [in_outs["outputs"][index]]
if isinstance(in_outs["outputs"][index], list):
tmp_result = tmp_result or (output == in_outs["outputs"][index])
except Exception as e:
if debug:
print(f"Failed check2 exception = {e}")
if tmp_result is True:
results.append(tmp_result)
continue
# try by converting the output into a split up list too
if isinstance(output, list):
output = list(filter(len, output))
if debug:
nl = "\n"
if not isinstance(inputs, list):
print(
f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
)
else:
print(
f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
)
if debug:
print(f"{tmp_result=} @a")
try:
tmp_result = output == [in_outs["outputs"][index]]
if isinstance(in_outs["outputs"][index], list):
tmp_result = tmp_result or (output == in_outs["outputs"][index])
except Exception as e:
if debug:
print(f"Failed check3 exception = {e}")
if debug:
print(f"{tmp_result=} @b")
try:
all_ints = all(
combined_int_check(e1) and combined_int_check(e2)
for e1, e2 in zip(output, in_outs["outputs"][index])
)
if not all_ints:
if debug:
print(
[
combined_int_check(e1) and combined_int_check(e2)
for e1, e2 in zip(output, in_outs["outputs"][index])
]
)
output_float = [float(e) for e in output]
gt_float = [float(e) for e in in_outs["outputs"][index]]
tmp_result = tmp_result or (
(len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)
)
except Exception:
pass
if debug:
print(f"{tmp_result=} @c")
try:
if isinstance(output[0], list):
all_ints = all(
combined_int_check(e1) and combined_int_check(e2)
for e1, e2 in zip(output[0], in_outs["outputs"][index])
)
if not all_ints:
output_float = [float(e) for e in output[0]]
gt_float = [float(e) for e in in_outs["outputs"][index][0]]
tmp_result = tmp_result or (
(len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)
)
except Exception:
pass
if tmp_result is True:
results.append(tmp_result)
continue
if debug:
print(f"{tmp_result=} @d")
# try by converting the stuff into split up list
if isinstance(in_outs["outputs"][index], list):
for tmp_index, i in enumerate(in_outs["outputs"][index]):
in_outs["outputs"][index][tmp_index] = set(i.split())
else:
in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
if debug:
print(f"{tmp_result=} @e")
try:
tmp_result = output == in_outs["outputs"][index]
except Exception as e:
if debug:
print(f"Failed check4 exception = {e}")
continue
if tmp_result is True:
results.append(tmp_result)
continue
if debug:
print(f"{tmp_result=} @f")
# try by converting the output into a split up list too
if isinstance(output, list):
for tmp_index, i in enumerate(output):
output[tmp_index] = i.split()
output = list(filter(len, output))
for tmp_index, i in enumerate(output):
output[tmp_index] = set(i)
else:
output = output.split()
output = list(filter(len, output))
output = set(output)
if debug:
print(f"{tmp_result=} @g")
if tmp_result is True and debug:
print("PASSED")
results.append(tmp_result)
if tmp_result is not True:
return results, {
"output": raw_true_output_copy,
"expected": raw_outputs,
"inputs": raw_inputs,
# "error_code": -2,
"error_message": "Wrong Answer",
}
if debug:
nl = "\n"
if not isinstance(inputs, list):
print(
f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
)
else:
print(
f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
)
print(f"results = {results}")
return results, {}
def custom_compare_(output, ground_truth):
if isinstance(output, list):
output_1 = "\n".join(output)
if stripped_string_compare(output_1, ground_truth):
return True
if isinstance(output, list):
output_2 = [o.lstrip().rstrip() for o in output]
output_2 = "\n".join(output_2)
if stripped_string_compare(output_2, ground_truth):
return True
return False
def stripped_string_compare(s1, s2):
s1 = s1.lstrip().rstrip()
s2 = s2.lstrip().rstrip()
return s1 == s2
def call_method(method, inputs):
if isinstance(inputs, list):
inputs = "\n".join(inputs)
inputs_line_iterator = iter(inputs.split("\n"))
# sys.setrecursionlimit(10000)
# @patch('builtins.input', side_effect=inputs.split("\n"))
@patch("builtins.open", mock_open(read_data=inputs))
@patch("sys.stdin", StringIO(inputs))
@patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
@patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
@patch("sys.stdin.read", lambda *args: inputs)
# @patch('sys.stdout.write', print)
def _inner_call_method(_method):
try:
return _method()
except SystemExit:
pass
finally:
pass
return _inner_call_method(method)
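A small sketch (not part of the diff) of how call_method feeds a string through the patched stdin; `solution` is a hypothetical stand-in for the code() function that run_test compiles from a generated program.

def solution():
    import sys
    a, b = sys.stdin.read().split()
    print(int(a) + int(b))

with Capturing() as captured:
    call_method(solution, "3 4")
print(captured[0])  # "7\n"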
def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
WARNING
This function is NOT a security sandbox. Untrusted code, including model-
generated code, should not be blindly executed outside of one. See the
Codex paper for more information about OpenAI's code sandbox, and proceed
with caution.
"""
if maximum_memory_bytes is not None:
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if platform.uname().system != "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
import builtins
builtins.exit = None
builtins.quit = None
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.kill = None
os.system = None  # disabled to prevent interference with the REPL-based evaluation
os.putenv = None
os.remove = None
os.removedirs = None
os.rmdir = None
os.fchdir = None
os.setuid = None
os.fork = None
os.forkpty = None
os.killpg = None
os.rename = None
os.renames = None
os.truncate = None
os.replace = None
os.unlink = None
os.fchmod = None
os.fchown = None
os.chmod = None
os.chown = None
os.chroot = None
os.lchflags = None
os.lchmod = None
os.lchown = None
os.getcwd = None
os.chdir = None
import shutil
shutil.rmtree = None
shutil.move = None
shutil.chown = None
import subprocess
subprocess.Popen = None # type: ignore
__builtins__["help"] = None
import sys
sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None

View File

@ -0,0 +1,61 @@
# Code from the verl Project (https://github.com/agentica-project/rllm),
# which itself is adapted from Prime (https://github.com/PRIME-RL/PRIME)
#
# Copyright 2024 ByteDance Group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import os
import sys
import traceback
from typing import Optional
from .testing_util import run_test
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
with open(os.devnull, "w") as devnull:
sys.stdout = devnull
sys.stderr = devnull
try:
res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
result.append(res)
metadata_list.append(metadata)
except Exception:
# print(e) # some tracebacks are extremely long.
traceback.print_exc(10)
result.append([-1 for i in range(len(sample["inputs"]))])
metadata_list.append({})
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
"""Check correctness of code generation with a global timeout.
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""
manager = multiprocessing.Manager()
result = manager.list()
metadata_list = manager.list()
p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
# p.terminate()
if not result:
# consider that all tests failed
result = [[-1 for i in range(len(in_outs["inputs"]))]]
if debug:
print("global timeout")
return result[0], metadata_list
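A minimal sketch (not part of the diff), assuming a Unix host (run_test relies on signal.alarm): checking a standard-input style solution against a single in/out pair.

if __name__ == "__main__":
    sample = {"inputs": ["1 2\n"], "outputs": ["3\n"]}  # no "fn_name" -> standard-input mode
    generation = "a, b = map(int, input().split())\nprint(a + b)"
    results, metadata = check_correctness(in_outs=sample, generation=generation, timeout=10, debug=False)
    print(results)  # [True] if the generated program printed the expected output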

View File

@ -1,7 +1,30 @@
# Copyright 2024 ByteDance Group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some functions in this file are adapted from the verl project
under the Apache License 2.0:
https://github.com/volcengine/verl
"""
import json
import torch
from latex2sympy2_extended import NormalizationConfig
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from .code_reward.utils import check_correctness as check_correctness_code
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
CANNOT_PARSE_GT_ANSWER = -1
@ -98,15 +121,11 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
raise ValueError("no gt_answer is provided, please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
# Check format accuracy
if format_valid:
format_acc += 1
# Check answer accuracy; an answer is considered correct only if it matches the ground truth and the format is valid
if final_answer is not None:
if eval_mode or format_valid:
@ -114,6 +133,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
if not eval_mode:
reward = reward + length_reward
# Check format accuracy
if format_valid:
format_acc += 1
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
@ -138,7 +161,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
format_score = 0.0
acc_score = 10.0
reward = torch.tensor(0.0)
format_acc = torch.tensor(0.0)
@ -161,7 +183,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
if "tags" in kwargs and kwargs["tags"]:
@ -169,10 +190,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
# Check format accuracy
if format_valid:
format_acc += 1
reward += format_score
# Check answer accuracy; the answer reward is granted only when the extracted answer is correct and, outside eval mode, the format is also valid
if final_answer is not None:
@ -181,6 +198,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
if not eval_mode:
reward = reward + length_reward
# Check format accuracy
if format_valid:
format_acc += 1
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
@ -199,3 +220,78 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
"response_length": res_length,
"reward": reward.item(),
}
def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
acc_score = 10.0
reward = torch.tensor(0.0)
format_acc = torch.tensor(0.0)
ans_acc = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
res_length = e.item() - s.item() + 1
if not eval_mode:
max_new_tokens = kwargs["max_new_tokens"]
else:
max_new_tokens = -1 # for eval mode, we don't need to check the length
if not eval_mode and soft_over_length_punishment:
cache_length = kwargs["cache_length"]
if max_new_tokens - cache_length < res_length < max_new_tokens:
length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
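# Worked example of the soft over-length penalty (illustrative values; acc_score is 10.0 above):
# with max_new_tokens=4096 and cache_length=512, a response of res_length=3968 tokens gets
# length_reward = ((4096 - 512) - 3968) / 512 * 10.0 = -7.5; the penalty ramps linearly from 0
# down to -acc_score over the final cache_length tokens before the hard length cut-off.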
# Try to extract the code solution from the completion; if the completion is already pure code, the extraction leaves it unchanged.
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
solution = decoded_final_answer.split("```python")[-1].split("```")[0]
format_valid = False
if "```python" in decoded_final_answer:
format_valid = solution is not None
# Check format accuracy
if format_valid:
format_acc += 1
try:
try:
if not isinstance(test_cases, dict):
test_cases = json.loads(test_cases)
except Exception as e:
print(f"Error {e}: Cannot parse test cases.")
raise e
# Run the complete check on all input/output pairs first; if none of them fail, per-sample tests can be skipped.
try:
res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True)
metadata = dict(enumerate(metadata))[0]
success = all(map(lambda x: x is True, res))
if success:
ans_acc += 1
if eval_mode or format_valid:
reward += acc_score
if not eval_mode:
reward = reward + length_reward
except Exception:
pass
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
except Exception:
pass
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
else:
prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
return {
"prompt": prompt,
"prediction": decoded_final_answer,
"gold": test_cases["outputs"],
"parsed": solution,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
"response_length": res_length,
"reward": reward.item(),
}
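The test_cases value handed to this reward function is expected to be a dict, or a JSON string that parses into one; a minimal sketch of that payload, inferred from the json.loads call and the "inputs"/"outputs" accesses above (the exact dataset schema is an assumption):
# Illustration only: per-sample test cases as they might be stored in the dataset
example_test_cases = json.dumps(
    {
        "inputs": ["1 2\n", "5 7\n"],  # stdin fed to the generated program, one entry per case
        "outputs": ["3\n", "12\n"],  # expected stdout for each case
    }
)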

View File

@ -2,6 +2,7 @@
Function-based reward verification module.
"""
import inspect
from typing import Any, Dict, List
import torch
@ -15,7 +16,8 @@ class VerifiableReward:
def __call__(
self,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
gt_answer: List[str] = None,
test_cases: List[str] = None,
response_idx: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
@ -26,18 +28,44 @@ class VerifiableReward:
# Loop through reward functions
for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
if "gt_answer" in inspect.getfullargspec(reward_fn).args:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
elif "test_cases" in inspect.getfullargspec(reward_fn).args:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
test_cases=test_cases[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
else:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
rewards += reward_batch
return rewards
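A minimal sketch of how this signature-based dispatch might be exercised; the constructor keyword arguments are assumptions based only on the self.reward_fns and self.kwargs attributes used above:
# Illustration only
reward_model = VerifiableReward(
    reward_fns=[boxed_math_reward_fn],  # code_reward_fn would instead be fed test_cases[i] by the branch above
    tokenizer=tokenizer,
    tags=None,
    max_new_tokens=3584,
    cache_length=512,
    soft_over_length_punishment=True,
)
rewards = reward_model(input_ids, gt_answer=gt_answers, test_cases=None, response_idx=response_idx)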

View File

@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return per_label_logps.squeeze(-1)
def calc_action_log_probs(
def memory_efficient_logprob(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
inputs: torch.Tensor,
num_action: int,
chunk_size: int = 2048,
shard_config: Any = None,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.
"""
Calculate action log probs in a memory-efficient way by processing in chunks.
Args:
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config
vocab_size
inputs (torch.LongTensor): Input sequences.
num_action (int): Number of actions.
chunk_size (int, optional): Size of each chunk to process. Default is 2048.
shard_config: Shard configuration for distributed computation.
vocab_size (int, optional): Vocabulary size. Default is None.
Returns:
torch.Tensor: Action log probs.
"""
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
# logits: torch.Tensor, # [B, S, Vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]
action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
context_length = logits.size(1) - num_action
for i in range(action_log_probs.size(0)):
# loop over each sample in the micro-batch
for start in range(context_length, logits.size(1), chunk_size):
end = min(start + chunk_size, logits.size(1))
# calculate log probs in chunks to save memory
log_probs = dist_log_prob(
inputs[i : i + 1, start - 1 : end],
logits[i : i + 1, start - 1 : end],
shard_config,
vocab_size,
logits.dtype,
) # [1, chunk_size, 1]
log_probs = log_probs.squeeze(-1)
action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
return action_log_probs
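A brief usage sketch under assumed shapes; the caller, plugin.shard_config and num_new_tokens are placeholders, the point is only that [B, S, V] logits are reduced chunk by chunk into [B, num_action] action log-probs without materializing the full log-softmax at once:
# Illustration only
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B, S, V]
action_log_probs = memory_efficient_logprob(
    logits,
    input_ids,
    num_action=num_new_tokens,
    chunk_size=2048,
    shard_config=plugin.shard_config,
)  # [B, num_action]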
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:

View File

@ -0,0 +1,13 @@
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# 8K context length
# rm -rf *.prof
# MAX_NEW_TOKENS=$((8192-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
# 4K context length
rm -rf *.prof
MAX_NEW_TOKENS=$((4096-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png

View File

@ -20,6 +20,7 @@ sentencepiece==0.2.0
tiktoken
jsonlines
math-verify==0.7.0
pyext
# The following packages should be built into the image.
# torch_npu==2.5.1

View File

@ -6,6 +6,12 @@ import ray
import torch
from coati.distributed.launch import launch_distributed
DEFAUT_SYSTEM_PROMPT = {
"think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n",
"boxed": "Please reason step by step, and put your final answer within \\boxed{}.",
"code": "You are a helpful assistant.",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@ -61,66 +67,6 @@ if __name__ == "__main__":
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
parser.add_argument(
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
)
# Sampling parameters
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
"-topk",
"--top-k",
type=int,
default=None,
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-topp",
"--top-p",
type=float,
default=1.0,
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
"-rt",
"--reward-type",
type=str,
default="think_answer_tags",
choices=["think_answer_tags", "boxed"],
help="Reward type for GRPO.",
)
parser.add_argument(
"-ei",
"--eval-interval",
type=int,
default=100,
help="Interval for evaluation. Evaluate every ei training steps.",
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
@ -142,6 +88,37 @@ if __name__ == "__main__":
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
parser.add_argument(
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
)
# Sampling parameters
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-gmu", "--gpu-memory-utilization", type=float, default=0.7, help="Vllm gpu memory utilization")
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
"-topk",
"--top-k",
type=int,
default=None,
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-topp",
"--top-p",
type=float,
default=1.0,
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
@ -149,6 +126,46 @@ if __name__ == "__main__":
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
"-rt",
"--reward-type",
type=str,
default="think_answer_tags",
choices=["think_answer_tags", "boxed", "code"],
help="Reward type for GRPO.",
)
parser.add_argument(
"-ei",
"--eval-interval",
type=int,
default=100,
help="Interval for evaluation. Evaluate every ei training steps.",
)
parser.add_argument(
"-nb",
"--n-behind",
type=int,
default=0,
help="Number of producer batches to rollout to fill the data buffer before trainer starts to decrease bubble time",
)
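# Illustration (an assumption in line with this commit's one-behind goal): with -nb 1 the producer keeps
# one extra rollout batch buffered, so the trainer's optimizer step overlaps with generation of the
# next batch instead of waiting for it, which is what shrinks the pipeline bubble.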
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@ -167,10 +184,29 @@ if __name__ == "__main__":
if args.master_address is None:
# Default settings: Using single machine
ray.init(address="local", namespace="ray-example")
ray.init(
address="local",
namespace="ray-example",
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
)
else:
# For Ray distributed multi-machine training, change _node_ip_address to the IP address of your master node
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
ray.init(
_node_ip_address=args.master_address,
namespace="ray-example",
_temp_dir=args.ray_dir,
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
)
if args.top_k is None:
if args.backend == "transformers":
@ -178,6 +214,8 @@ if __name__ == "__main__":
elif args.backend == "vllm":
args.top_k = -1
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model)
train_model_config = dict(
path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
@ -204,7 +242,7 @@ if __name__ == "__main__":
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
gpu_memory_utilization=args.gpu_memory_utilization,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
@ -213,10 +251,10 @@ if __name__ == "__main__":
)
generate_config.update(
dict(
max_tokens=args.max_new_tokens + args.max_prompt_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
max_tokens=args.max_new_tokens, # max new tokens
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
@ -276,6 +314,10 @@ if __name__ == "__main__":
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.system_prompt is None:
# Default system prompt
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
@ -290,7 +332,6 @@ if __name__ == "__main__":
"max_length": args.max_prompt_tokens,
"system_prompt": args.system_prompt,
},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
@ -329,4 +370,6 @@ if __name__ == "__main__":
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_save_dir=args.rollout_save_dir,
enable_profiling=args.enable_profiling,
n_behind=args.n_behind,
)

View File

@ -0,0 +1,102 @@
# Parse CustomProfiler logs and render a per-actor timeline
import argparse
import glob
from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.pyplot as plt
# Argument parser for command line arguments
parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
parser.add_argument("--dir", type=str, default="./", help="Dir to the prof file.")
args = parser.parse_args()
# Raw log lines
log_lines = []
files = glob.glob(f"{args.dir}*.prof")
for file in files:
with open(file, "r") as f:
log_lines += f.readlines()
# Parse logs and collect function intervals grouped by actor
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}
# First, collect all timestamps to find the minimum
all_timestamps = []
parsed_lines = []
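# Assumed data-line format emitted by CustomProfiler (only fields 0, 1, 3 and 4 are read below;
# the third field is skipped, and lines starting with "[Log]" are ignored):
#   <timestamp> <actor_name> <tag> <Enter|Exit> <function_name>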
for line in log_lines:
if line.startswith("[Log]"):
continue
parts = line.split()
timestamp = float(parts[0])
actor = parts[1]
action = parts[3]
func_name = parts[4]
parsed_lines.append((timestamp, actor, action, func_name))
all_timestamps.append(timestamp)
if not all_timestamps:
raise ValueError("No valid log entries found.")
min_timestamp = min(all_timestamps)
for timestamp, actor, action, func_name in parsed_lines:
rel_timestamp = timestamp - min_timestamp
key = (actor, func_name)
if action == "Enter":
current_entries[key] = rel_timestamp
elif action == "Exit":
start_time = current_entries.pop(key, None)
if start_time is not None:
actors[actor][func_name].append((start_time, rel_timestamp))
# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
colors = cm.get_cmap("tab10", len(actors))
actor_offsets = {}
base_offset = 0
function_spacing = 0.9
yticks = []
yticklabels = []
for idx, (actor, func_dict) in enumerate(actors.items()):
actor_offsets[actor] = base_offset
color = colors(idx)
for j, (func, intervals) in enumerate(func_dict.items()):
differences = [y - x for x, y in intervals]
print(actor, func, intervals, differences)
y_val = base_offset + j * function_spacing
yticks.append(y_val)
yticklabels.append(f"{actor}:{func}")
for start, end in intervals:
if end - start < 1:
end = start + 1 # Ensure each plotted segment spans at least 1 time unit so it remains visible
ax.plot(
[start, end],
[y_val, y_val],
color=color,
linewidth=2,
label=actor if j == 0 else "",
)
base_offset += len(func_dict) * function_spacing + 1
# Formatting
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
ax.set_title("Timeline per Actor")
# Remove duplicate labels in legend
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys())
plt.tight_layout()
plt.grid(True)
plt.savefig(args.visualization, dpi=600) # Increase dpi for higher resolution
print(f"Plot saved as {args.visualization}")