add code for zero-bubble implementation

YeAnbang 2025-07-09 11:21:43 +08:00
parent b1f646c7e7
commit 509274c47e
8 changed files with 2267 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict
import copy
import ray
import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
@@ -30,11 +31,18 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
obj = c10d._tensor_to_object(obj, size_tensor.item())
return obj
def ray_broadcast_tensor_dict(
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
tensor_dict: Dict[str, torch.Tensor],
src: int = 0,
device=None,
group_name: str = "default",
backend: str = "nccl",
offload_to_cpu: bool = False,
pin_memory: bool = False,
) -> Dict[str, torch.Tensor]:
rank = cc.get_rank(group_name)
if tensor_dict is None:
tensor_dict = {}
if rank == src:
metadata = []
for k, v in tensor_dict.items():
@@ -42,16 +50,103 @@ def ray_broadcast_tensor_dict(
else:
metadata = None
metadata = ray_broadcast_object(metadata, src, device, group_name)
if rank != src:
out_dict = {}
for k, shape, dtype in metadata:
if rank == src:
tensor = tensor_dict[k]
if offload_to_cpu:
tensor = tensor_dict[k].to(device)
else:
tensor = tensor_dict[k]
else:
tensor = torch.empty(shape, dtype=dtype, device=device)
tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
if backend == "gloo" and dtype == torch.bfloat16:
# Gloo does not support bfloat16; reinterpret the buffer as float16 for transport (bitwise view, values unchanged)
tensor = tensor.view(torch.float16)
cc.broadcast(tensor, src, group_name)
if backend == "gloo" and dtype == torch.bfloat16:
# Reinterpret back as bfloat16 after the broadcast (the float16 view did not change any bits)
tensor = tensor.view(torch.bfloat16)
if rank != src:
out_dict[k] = tensor
if rank == src:
out_dict = tensor_dict
return out_dict
if offload_to_cpu:
tensor_dict[k] = tensor.cpu()
else:
tensor_dict[k] = tensor
return tensor_dict
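# A minimal, self-contained sketch of the bfloat16-over-gloo workaround above: gloo cannot
# transport bfloat16 tensors, but bfloat16 and float16 are both 16-bit wide, so the buffer
# is reinterpreted with Tensor.view(dtype) for the collective and reinterpreted back
# afterwards; no values are converted. The tensor below is made up for illustration.
def _bf16_gloo_view_roundtrip_example():
    t = torch.tensor([1.5, -2.25, 3.0, 0.5], dtype=torch.bfloat16)
    as_fp16 = t.view(torch.float16)          # bitwise reinterpretation, not a cast
    restored = as_fp16.view(torch.bfloat16)  # what the receiver does after cc.broadcast
    assert torch.equal(t, restored)          # bit pattern (and values) unchanged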
@ray.remote
class SharedVariableActor:
def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
self.data_queue = []
self.data_uid = 0
self.number_of_readers = number_of_readers
self.queue_size = 0
self.signals = {}
self.process_locks = {}
self.signal_procs_meet_count = {}
self.buffer_size_limit = buffer_size_limit
def pickup_rollout_task(self, num_tasks: int):
"""
Use the queue size to control whether producers should generate new rollouts or wait
for the consumer to consume more data. If the queue size is below the threshold,
the consumer is consuming data fast enough, so producers can generate new rollouts.
If the queue size is above the threshold, the consumer is consuming data slowly,
so producers should wait for it to consume more data.
Any free producer can pick up the task to generate a rollout and then increase queue_size
to prevent other producers from picking up the task redundantly. Note that queue_size is not
the real queue length, as some data may still be being generated. (See the usage sketch after this class.)
"""
ret = False
if self.queue_size < self.buffer_size_limit:
ret = True
self.queue_size += num_tasks
return ret
def append_data(self, data):
self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count]
self.data_uid += 1
return True
def get_data(self, data_uid: int):
# for multi-process data reading
if not self.data_queue:
# no data in the queue, return None
return None
to_pop_index = None
ret = None
for i, (uid, data, access_count) in enumerate(self.data_queue):
if uid == data_uid:
# found the data with the given uid
self.data_queue[i][2] += 1
ret = copy.deepcopy(data)
if self.data_queue[i][2] == self.number_of_readers:
to_pop_index = i
break
if to_pop_index is not None:
# remove the data from the queue if it has been accessed by all readers
self.data_queue.pop(to_pop_index)
self.queue_size -= data["input_ids"].size(0)
return ret
def acquire_process_lock(self, key: str):
# atomic lock for process
if key not in self.process_locks:
self.process_locks[key] = 1 # locked
return 0
if self.process_locks[key] == 0:
self.process_locks[key] = 1 # lock the process
return 0
else:
return 1
def release_process_lock(self, key: str):
# atomic unlock for process
assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
self.process_locks[key] = 0
def set_signal(self, key: str, signal: str):
self.signals[key] = signal
def get_signal(self):
return self.signals
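# A brief usage sketch of the backpressure and multi-reader protocol implemented by
# SharedVariableActor (the helper names below are illustrative; the real callers are the
# producer and consumer loops in the files that follow). A producer calls
# pickup_rollout_task() before generating a rollout, which optimistically bumps queue_size
# so other producers do not over-produce; each consumer process reads a given uid exactly
# once via get_data(), and the entry is dropped after all readers have seen it.
import time


def producer_side_sketch(data_actor, microbatch_size, make_rollout):
    while not ray.get(data_actor.pickup_rollout_task.remote(microbatch_size)):
        time.sleep(1)  # buffer is full; wait for consumers to catch up
    ray.get(data_actor.append_data.remote(make_rollout()))


def consumer_side_sketch(data_actor, data_uid):
    data = ray.get(data_actor.get_data.remote(data_uid))
    while data is None:  # either not produced yet or uid not in the queue
        time.sleep(1)
        data = ray.get(data_actor.get_data.remote(data_uid))
    return data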

View File

@@ -0,0 +1,306 @@
import copy
import os
import uuid
from typing import Any, Dict, Optional
import ray
from .comm import SharedVariableActor
from .zero_bubble.distributor import Distributor
from .zero_bubble.grpo_consumer import GRPOConsumer
from .zero_bubble.producer import SimpleProducer
ALGO_MAP = {"GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines)
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
tp_size = plugin_config.get("tp_size", 1)
pp_size = plugin_config.get("pp_size", 1)
ep_size = plugin_config.get("ep_size", 1)
sp_size = plugin_config.get("sp_size", 1)
return n_procs // (tp_size * pp_size * ep_size * sp_size)
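# Quick sanity check for get_dp_size_fast (hypothetical sizes): with 8 consumer processes,
# tp_size=2 and pp_size=2, the data-parallel size is 8 // (2 * 2 * 1 * 1) = 2.
assert get_dp_size_fast(8, {"tp_size": 2, "pp_size": 2}) == 2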
def launch_distributed(
num_producers: int,
num_proc_per_producer: int,
num_consumer_procs: int,
num_episodes: int,
inference_batch_size: int,
inference_microbatch_size: int,
train_batch_size: int,
train_minibatch_size: int,
train_dataset_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
project_name: Optional[str] = None,
save_interval: int = 100,
save_dir: str = "./model",
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
data_actor_buffer_size_limit: int = 0,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, GRPOConsumer)
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
if data_actor_buffer_size_limit <= 0:
# use 2 times the train_minibatch_size as the default buffer size limit
data_actor_buffer_size_limit = train_minibatch_size * train_dp_size * 2
dataset_path = train_dataset_config["path"]
train_dataset_size = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
train_dataset_size = (train_dataset_size // global_inference_batch_size) * global_inference_batch_size
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
wandb_group_name = str(uuid.uuid4())
rollout_log_file = os.path.join(
rollout_save_dir,
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)
# Attention: Ray uses a complex scheduling method that considers various factors, including load balancing.
# When requesting resources, it is not guaranteed that a resource comes from a node with a lower node id.
# This goes against the design principle of our implementation, so we manually force the scheduling,
# allocating producers to nodes with lower node ids and consumers to the resources of nodes with higher
# node ids. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
nodes = ray.nodes()
# every producer is associated with a data worker; the data worker is responsible for moving data from the producer to all consumers
shared_sync_data_actor = SharedVariableActor.remote(num_consumer_procs, data_actor_buffer_size_limit)
# all producers and consumer 0 share the same signal actor, which only provides signals for model synchronization
shared_signal_actor = SharedVariableActor.remote()
node_info = {
node["NodeID"]: {
"num_gpus": node["Resources"].get("GPU", 0),
"address": node["NodeManagerAddress"],
} # Default to 0 if no GPUs are available
for node in nodes
}
gpu_to_node_id = []
gpu_to_ip_address = []
for node_id in node_info:
for idx in range(int(node_info[node_id]["num_gpus"])):
gpu_to_node_id.append(node_id)
gpu_to_ip_address.append(node_info[node_id]["address"])
print(node_info)
producer_procs = []
for i in range(num_producers):
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
for _ in range(num_proc_per_producer):
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
producer = SimpleProducer.options(num_gpus=num_proc_per_producer, num_cpus=4).remote(
shared_sync_data_actor=shared_sync_data_actor,
shared_signal_actor=shared_signal_actor,
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=inference_batch_size,
train_dataset_config=train_dataset_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
)
producer_procs.append(producer)
# ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
if num_consumer_procs <= 1:
raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
for i in range(num_consumer_procs):
node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
consumer = core_consumer.options(num_gpus=1, num_cpus=4).remote(
shared_sync_data_actor=shared_sync_data_actor,
shared_signal_actor=shared_signal_actor,
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
master_addr=consumer_master_ip_address,
master_port=master_port,
train_dataset_size=train_dataset_size,
batch_size=train_batch_size,
model_config=train_model_config,
plugin_config=plugin_config,
minibatch_size=train_minibatch_size,
generate_config=generate_config_consumer,
grpo_config=grpo_config,
num_generations=num_generations,
save_interval=save_interval,
save_dir=save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
enable_profiling=enable_profiling,
)
consumer_procs.append(consumer)
distributor_procs = []
for i in range(num_producers):
distributor_procs.append(
Distributor.options(num_cpus=2).remote(
i,
plugin_config.get("pp_size", 1),
num_producers,
shared_signal_actor,
enable_profiling=enable_profiling,
)
)
print("=================== All processes are created, starting setup torch DDP ===================", flush=True)
ray.get([p.setup.remote() for p in consumer_procs])
print(
"=================== All processes are setup, starting initialize communication groups ===================",
flush=True,
)
remote_refs = []
# Initialize consumer communication group
for i, p in enumerate(consumer_procs):
remote_refs.append(p.init_collective_group.remote(num_consumer_procs, i, "gloo", f"consumer_pg"))
ray.get(remote_refs)
remote_refs = []
# Initialize producer communication group
for i, p in enumerate(producer_procs):
remote_refs.append(p.init_collective_group.remote(num_producers, i, "nccl", f"producer_pg"))
ray.get(remote_refs)
remote_refs = []
# Initialize distributor communication group
for i, p in enumerate(distributor_procs):
remote_refs.append(p.init_collective_group.remote(num_producers, i, "gloo", f"distributor_pg"))
ray.get(remote_refs)
remote_refs = []
# Initialize sync model communication group between consumer and sync model actor
# As tested, gloo does not support nested initialization, so all participants of a group must be initialized in the same ray.get call.
consumer_pp = plugin_config.get("pp_size", 1)
for i, p in enumerate(consumer_procs):
consumer_ddp_config = ray.get(p.get_ddp_config.remote())
if consumer_pp > 1:
if consumer_ddp_config["tp_rank"] == 0 and consumer_ddp_config["dp_rank"] == 0:
pp_rank = consumer_ddp_config["pp_rank"]
remote_refs.append(
p.init_collective_group.remote(
num_producers + 1,
0,
backend="gloo",
group_name=f"sync_model_consumer_pp_{pp_rank}",
gloo_timeout=3000000,
)
)
for distributor_id, p_distributor in enumerate(distributor_procs):
remote_refs.append(
p_distributor.init_collective_group.remote(
num_producers + 1,
1 + distributor_id,
backend="gloo",
group_name=f"sync_model_consumer_pp_{pp_rank}",
gloo_timeout=3000000,
)
)
ray.get(remote_refs)
remote_refs = []
else:
if i == 0:
remote_refs.append(
p.init_collective_group.remote(
num_producers + 1, 0, backend="gloo", group_name=f"sync_model_consumer", gloo_timeout=3000000
)
)
for distributor_id, p_distributor in enumerate(distributor_procs):
remote_refs.append(
p_distributor.init_collective_group.remote(
num_producers + 1,
1 + distributor_id,
backend="gloo",
group_name=f"sync_model_consumer",
gloo_timeout=3000000,
)
)
ray.get(remote_refs)
remote_refs = []
# Initialize sync model communication group between producer and sync model actor
for i, p in enumerate(producer_procs):
if consumer_pp > 1:
for pp_rank in range(consumer_pp):
remote_refs.append(
p.init_collective_group.remote(
2, 0, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000
)
)
remote_refs.append(
distributor_procs[i].init_collective_group.remote(
2, 1, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000
)
)
ray.get(remote_refs)
remote_refs = []
else:
remote_refs.append(
p.init_collective_group.remote(
2, 0, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000
)
)
remote_refs.append(
distributor_procs[i].init_collective_group.remote(
2, 1, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000
)
)
ray.get(remote_refs)
remote_refs = []
print("=================== All processes are set up, starting loop ===================", flush=True)
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs + distributor_procs)])
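# A small helper sketch (not used above) that lists the collective-group names wired up by
# launch_distributed for a given configuration; handy when debugging hangs during group
# initialization. The names mirror the ones used in the code above.
def sync_group_names(num_producers: int, consumer_pp_size: int):
    names = ["consumer_pg", "producer_pg", "distributor_pg"]
    if consumer_pp_size > 1:
        names += [f"sync_model_consumer_pp_{pp}" for pp in range(consumer_pp_size)]
        names += [
            f"sync_model_producer_{i}_pp_{pp}"
            for i in range(num_producers)
            for pp in range(consumer_pp_size)
        ]
    else:
        names.append("sync_model_consumer")
        names += [f"sync_model_producer_{i}" for i in range(num_producers)]
    return names
# e.g. sync_group_names(2, 1) -> ['consumer_pg', 'producer_pg', 'distributor_pg',
#      'sync_model_consumer', 'sync_model_producer_0', 'sync_model_producer_1']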

View File

@@ -0,0 +1,347 @@
import os
import threading
import time
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from tqdm import tqdm
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
class BaseConsumer:
def __init__(
self,
shared_sync_data_actor: SharedVariableActor,
shared_signal_actor: SharedVariableActor,
num_producers: int,
num_episodes: int,
rank: int,
world_size: int,
master_addr: str,
master_port: int,
train_dataset_size: int,
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
enable_profiling: bool = False,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.train_dataset_size = train_dataset_size
self.received_prompts = 0
self.batch_size = batch_size
self.minibatch_size = minibatch_size
self.save_interval = save_interval
self.save_dir = save_dir
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
self.data_uid = 0
self.sync_model_thread_started = False
self.model_config = model_config
self.plugin_config = plugin_config
self.device = get_current_device()
self.lr_scheduler = None
self.shared_sync_data_actor = shared_sync_data_actor
self.shared_signal_actor = shared_signal_actor
self.state_dict_cpu = {}
def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if (
self.plugin_config.get("pp_size", 1) > 1
and "num_microbatches" not in self.plugin_config
and "microbatch_size" not in self.plugin_config
):
plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
self.buffer = []
self.recv_cnt = 0
self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
def get_ddp_config(self) -> Dict[str, Any]:
"""
Get the DDP configuration for the consumer.
This method is used to get the DDP configuration for the consumer.
"""
return {
"dp_size": self.dp_size,
"tp_size": self.tp_size,
"pp_size": self.pp_size,
"dp_rank": self.dp_rank,
"tp_rank": self.tp_rank,
"pp_rank": self.pp_rank,
"world_size": self.world_size,
"rank": self.rank,
}
def init_collective_group(
self,
world_size: int,
rank: int,
backend: str = "nccl",
group_name: str = "default",
gloo_timeout: int = 3000000,
):
cc.init_collective_group(
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
)
print(f"[C{self.rank}] Initialized {group_name} collective group", flush=True)
def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def step(self, **kwargs) -> Optional[float]:
raise NotImplementedError
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
"""
Prepare a mini-batch from the effective group to raw group mapping.
This method is used to create a mini-batch for training.
"""
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
return batch, raw_mini_batches_metric_dict
def calculate_effective_group_to_raw_group_mapping(self):
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
return effective_group_to_raw_group_mapping
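# A standalone illustration of the mapping above: buffer entries are
# [group_or_None, reward, format_acc, ans_acc, response_len], a None group means the prompt
# group was filtered out by dynamic batching, and the mapping indexes only surviving groups.
def _effective_mapping_example():
    buffer = [["g0"], [None], ["g1"], ["g2"]]  # only element 0 matters for the mapping
    mapping = {}
    for raw_idx, entry in enumerate(buffer):
        if entry[0] is not None:
            mapping[len(mapping)] = raw_idx
    assert mapping == {0: 0, 1: 2, 2: 3}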
def loop(self) -> None:
print(f"Consumer{self.rank}, nmb: {self.num_microbatches}")
for episode in range(self.num_episodes):
with tqdm(
range(self.train_dataset_size),
desc=f"Episode {episode} with rollout step(s)",
disable=self.rank != 0,
) as pbar:
while self.received_prompts < self.train_dataset_size:
torch.cuda.reset_peak_memory_stats()
effective_group_to_raw_group_mapping = {}
self.profiler.enter(f"recv_data")
while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size:
# receive data from producers
raw_batch = ray.get(
self.shared_sync_data_actor.get_data.remote(self.data_uid)
) # get the first queued data
while raw_batch is None:
self.profiler.log(f"No data received by consumer {self.rank}, skipping")
print(
f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
)
time.sleep(1)
raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
continue
self.data_uid += 1
raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
# calculate group reward and related metrics before filtering. As only the filtered groups will be used for training (an incomplete set),
# we need to calculate these metrics before filtering here for logging
# [batch_size, num_generations] -> [batch_size]
reward = raw_batch["reward"][:, :, 0]
format_acc = raw_batch["format_acc"][:, :, 0]
ans_acc = raw_batch["ans_acc"][:, :, 0]
response_len = (
raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
).type(torch.float32)
effective_group_mask = None
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
# filter the group based on the reward and accuracy
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
)
raw_batch = unbind_batch(raw_batch) # List[Dict[str, torch.Tensor]]
self.received_prompts += len(raw_batch)
pbar.update(len(raw_batch))
for group_idx, group_with_reward in enumerate(raw_batch):
self.buffer.append(
[
(
group_with_reward
if effective_group_mask is None or effective_group_mask[group_idx]
else None
),
reward[group_idx],
format_acc[group_idx],
ans_acc[group_idx],
response_len[group_idx],
]
)
if effective_group_mask is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)
self.profiler.exit(f"recv_data")
need_sync_model = False
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# after we have enough effective groups, we can start training
# on each dp_rank, we use minibatch_size effective samples to form a batch
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
)
# cc.barrier(group_name="consumer_pg")
if loss is not None:
pbar.set_postfix({"loss": loss})
need_sync_model = True
ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
if need_sync_model and (
(self.global_step + 1) % self.save_interval == 0
or self.received_prompts >= self.train_dataset_size
):
if self.rank == 0:
print(f"Start saving policy model at step {self.global_step + 1}.")
save_path = os.path.join(
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
)
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
if need_sync_model and (
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
):
def sync_model_thread():
# sync model weights to all producers; skip syncing if there was no model update or this is the last training step
if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
torch.cuda.empty_cache()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
self.profiler.enter("sync_model")
ray.get(
self.shared_signal_actor.set_signal.remote(
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
)
)
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
backend="gloo",
)
self.profiler.exit("sync_model")
else:
if self.rank == 0:
self.profiler.enter("sync_model")
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name="sync_model_consumer",
backend="gloo",
)
self.profiler.exit("sync_model")
if not self.sync_model_thread_started:
# only sync model when the thread is not started and no other thread is broadcasting
self.sync_model_thread_started = True
state_dict_ = self.state_dict()
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
self.pp_size == 1 and self.rank == 0
):
if len(self.state_dict_cpu) == 0:
# use pinned memory to speed up the transfer
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
torch.cuda.synchronize()
for k, v in state_dict_.items():
self.state_dict_cpu[k].copy_(v, non_blocking=True)
torch.cuda.synchronize()
cc.barrier(
group_name="consumer_pg"
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
time_before_starting_thread = time.time()
threading.Thread(target=sync_model_thread).start()
# sync_model_thread()
self.profiler.log(
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
)
self.sync_model_thread_started = False
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
self.received_prompts = 0
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))
def __del__(self):
if hasattr(self, "profiler"):
self.profiler.close()
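# A compact sketch of the pinned-memory offload pattern used in sync_model_thread above:
# pinned CPU buffers are allocated once, then refreshed with non-blocking copies bracketed
# by synchronize() before the gloo broadcast thread reads them. Argument names are made up.
def refresh_pinned_cpu_state_dict(state_dict_gpu, state_dict_cpu):
    if len(state_dict_cpu) == 0:
        # first call: allocate pinned CPU buffers so later device-to-host copies can be async
        state_dict_cpu.update({k: v.cpu().pin_memory() for k, v in state_dict_gpu.items()})
    torch.cuda.synchronize()
    for k, v in state_dict_gpu.items():
        state_dict_cpu[k].copy_(v, non_blocking=True)
    torch.cuda.synchronize()
    return state_dict_cpu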

View File

@@ -0,0 +1,108 @@
import time
import ray
import ray.util.collective as cc
import torch
from coati.distributed.profiling_utils import CustomProfiler
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
@ray.remote
class Distributor:
def __init__(
self,
distributor_id,
consumer_pp_size,
num_producers,
shared_signal_actor: SharedVariableActor,
enable_profiling: bool = True,
):
self.distributor_id = distributor_id
self.consumer_pp_size = consumer_pp_size
self.state_dict_cpu = {}
self.num_producers = num_producers
self.shared_signal_actor = shared_signal_actor
self.device = get_current_device()
self.profiler = CustomProfiler(f"D{self.distributor_id}", disabled=not enable_profiling)
def init_collective_group(
self,
world_size: int,
rank: int,
backend: str = "nccl",
group_name: str = "default",
gloo_timeout: int = 3000000,
):
cc.init_collective_group(
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
)
print(f"[D] Initialized {group_name} collective group", flush=True)
def loop(self):
while True:
time.sleep(1)
signal = ray.get(self.shared_signal_actor.get_signal.remote())
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_consumer_pp_{i}")
cc.barrier(group_name="distributor_pg")
ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
# Receive the model state dict broadcast from the consumer
self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
None,
0,
device=torch.device("cpu"),
group_name=f"sync_model_consumer_pp_{i}",
backend="gloo",
)
self.profiler.exit(f"sync_model_consumer_pp_{i}")
for i in range(self.consumer_pp_size):
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
# Broadcast the cached model state dict to the paired producer
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.distributor_id}_pp_{i}", "not_ready_sync_model"
)
)
ray_broadcast_tensor_dict(
self.state_dict_cpu[i],
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
backend="gloo",
)
self.profiler.exit(f"sync_model_producer_{self.distributor_id}_pp_{i}")
else:
if signal.get("consumer", None) == "ready_sync_model":
self.profiler.enter("sync_model_consumer")
cc.barrier(group_name="distributor_pg")
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model"))
# Receive the model state dict broadcast from the consumer
self.state_dict_cpu = ray_broadcast_tensor_dict(
None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
)
self.profiler.exit("sync_model_consumer")
if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
# Broadcast the cached model state dict to the paired producer
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.distributor_id}", "not_ready_sync_model"
)
)
ray_broadcast_tensor_dict(
self.state_dict_cpu,
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.distributor_id}",
backend="gloo",
)
self.profiler.exit(f"sync_model_producer_{self.distributor_id}")
if signal.get("consumer", None) == "terminate":
self.profiler.log("terminate sync model worker")
break
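# A plain-Python sketch of the signal handshake the Distributor relays (no Ray, no gloo);
# the key names mirror the real ones for the non-pipeline-parallel case, and the "broadcast"
# comments stand in for ray_broadcast_tensor_dict calls on the corresponding groups.
def _handshake_sketch():
    signals = {}                                 # stands in for SharedVariableActor.signals
    state_dict_cpu = None
    # 1) consumer rank 0 finishes an update and announces fresh weights
    signals["consumer"] = "ready_sync_model"     # then broadcasts on "sync_model_consumer"
    # 2) distributor notices the signal, resets it, and receives the weights
    if signals.get("consumer") == "ready_sync_model":
        signals["consumer"] = "not_ready_sync_model"
        state_dict_cpu = {"dummy.weight": 0}     # received from "sync_model_consumer"
    # 3) a producer that observed a new global_step asks for the weights
    signals["producer_0"] = "ready_sync_model"   # then receives on "sync_model_producer_0"
    # 4) distributor resets the producer signal and re-broadcasts the cached weights
    if signals.get("producer_0") == "ready_sync_model":
        signals["producer_0"] = "not_ready_sync_model"
    return state_dict_cpu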

View File

@@ -0,0 +1,498 @@
from contextlib import nullcontext
from typing import Any, Optional
import ray
import torch
import wandb
from coati.distributed.comm import SharedVariableActor
from coati.distributed.zero_bubble.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@ray.remote
class GRPOConsumer(BaseConsumer):
def __init__(
self,
shared_sync_data_actor: SharedVariableActor,
shared_signal_actor: SharedVariableActor,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
train_dataset_size,
batch_size,
model_config,
plugin_config,
minibatch_size=1,
num_generations=8,
generate_config=None,
grpo_config={},
save_interval: int = 100,
save_dir="./model",
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
enable_profiling: bool = False,
):
print(f"Using GRPO config: {grpo_config}")
if (
plugin_config.get("pp_size", 1) > 1
and "num_microbatches" not in plugin_config
and "microbatch_size" not in plugin_config
):
plugin_config["microbatch_size"] = max(
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
)
super().__init__(
shared_sync_data_actor,
shared_signal_actor,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
train_dataset_size,
batch_size,
model_config,
plugin_config,
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
enable_profiling=enable_profiling,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
self.accum_count = 0
self.generate_config = generate_config
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
self.effective_prompt_count = 0
self.project_name = project_name
self.run_name = run_name
self.wandb_group_name = wandb_group_name
self.policy_loss_fn = PolicyLoss(
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
# Reference model is initialized from policy model.
if self.policy_loss_fn.beta > 0:
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = grpo_config.get("filter_range", None)
if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values."
self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
if self.filter_truncated_response:
self.max_length = 0
if "max_tokens" in self.generate_config:
self.max_length = self.generate_config["max_tokens"]
elif "max_new_tokens" in self.generate_config:
self.max_length = self.generate_config["max_new_tokens"]
else:
raise ValueError(
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
grpo_config.get("response_format_tags", None)
self.global_step = 0
def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
self.wandb_run = wandb.init(
project=self.project_name,
sync_tensorboard=False,
dir="./wandb",
name=self.run_name,
group=self.wandb_group_name,
)
self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.train_dataset_size // (self.batch_size * self.dp_size),
warmup_steps=0,
eta_min=0.1 * self.grpo_config.get("lr", 1e-6),
)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
def step(self, pbar: Any, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
"input_ids": torch.Tensor,
"attention_mask": torch.Tensor,
"action_mask": torch.Tensor,
"action_log_probs": torch.Tensor,
},
...]
Format:
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))
# [minibatch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
# filter out samples with reward outside the range
# if dynamic batching is enabled, we filter out out of range groups before training
group_ans_acc_mean = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
)
loss_mask = torch.logical_and(
loss_mask,
torch.logical_and(
group_ans_acc_mean > self.filter_range[0],
group_ans_acc_mean < self.filter_range[1],
),
)
self.effective_prompt_count += (
group_reward.size(0) * self.dp_size
) # all prompts in the batch are effective as we filtered out the bad ones before step.
mean_kl, mean_loss = [], []
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
self.effective_sample_count += effective_samples.item()
pbar.set_postfix(
{
"Global Step": self.global_step,
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
}
)
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
nullcontext()
if need_update or self.booster.plugin.zero_stage == 2
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
if self.plugin.pp_size > 1:
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
shard_config=self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
else:
reference_action_log_probs = None
data_policy_forward = {
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"source": self.rank,
}
if reference_action_log_probs is not None:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
kl = []
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
shard_config=self.plugin.shard_config,
)
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
kl.append(torch.tensor(0.0))
loss, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
total_effective_tokens_in_batch=total_effective_tokens_count,
)
return loss
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
return_outputs=False,
)
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
if len(kl) > 0:
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = memory_efficient_logprob(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
shard_config=self.plugin.shard_config,
)
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = memory_efficient_logprob(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
else:
per_token_kl = 0.0
kl = None
loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
total_effective_tokens_in_batch=total_effective_tokens_count,
)
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
# Calculate accumulate value.
if kl is not None:
kl = all_reduce_mean(kl.mean(), self.plugin)
mean_kl.append(kl.data)
mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_advantages.add_(advantages.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
if self.lr_scheduler is not None:
self.lr_scheduler.step()
# no need to run all reduce as raw_train_batch_* are not split across dp ranks
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
overlength_samples_ratio = (
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
) # not an exact figure, but a close estimate
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
f"Reward: {raw_batch_reward_mean:.4f}",
f"format Reward: {raw_batch_format_acc_mean:.4f}",
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
"metrics/reward": raw_batch_reward_mean,
"metrics/format_acc": raw_batch_format_acc_mean,
"metrics/ans_acc": raw_batch_ans_acc_mean,
"metrics/response_length": raw_batch_response_len_mean,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0:
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
if self.wandb_run is not None:
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_count = 0
return loss_scalar
else:
return None
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict
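# A self-contained numeric sketch of the GRPO advantage normalization used in step() above:
# rewards are grouped per prompt (num_generations samples per group) and standardized with
# the group mean and std. The reward values below are made up for illustration.
def _grpo_advantage_example(num_generations: int = 4):
    reward = torch.tensor(
        [1.0, 0.0, 1.0, 0.0,   # group 0: mixed rewards -> advantages of roughly +/-0.87
         1.0, 1.0, 1.0, 1.0]   # group 1: no variance  -> advantages of ~0
    )
    group_reward = reward.view(-1, num_generations)
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
    return advantages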

View File

@@ -0,0 +1,533 @@
import copy
import json
import os
import threading
import time
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from ray.util.collective import allreduce
from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.inference_backend import BACKEND_MAP
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
try:
from vllm import SamplingParams
except ImportError:
SamplingParams = None  # vllm is optional; fall back when it is unavailable
class BaseProducer:
def __init__(
self,
shared_sync_data_actor: SharedVariableActor,
shared_signal_actor: SharedVariableActor,
producer_idx: int,
num_producers: int,
num_consumer_procs: int,
num_episodes: int,
batch_size: int,
train_dataset_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
microbatch_size: int = 1,
backend: str = "transformers",
consumer_plugin_config: Dict[str, Any] = None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
):
self.producer_idx = producer_idx
self.num_producers = num_producers
self.num_consumer_procs = num_consumer_procs
self.num_episodes = num_episodes
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
self.latest_eval_step = -1
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
# for async data and model sync
self.shared_sync_data_actor = shared_sync_data_actor
self.shared_signal_actor = shared_signal_actor
self.sync_model_thread_started = False
self.train_dataset_config = train_dataset_config
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
self.consumer_plugin_config = consumer_plugin_config
self.eval_interval = eval_interval
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
self.producer_weight_version = 0
self.eval_mode = False
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
if producer_idx == 0:
if os.path.exists(rollout_log_file):
raise ValueError(
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
)
else:
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
sync_tensorboard=False,
dir="./wandb",
name=run_name + "_eval",
group=wandb_group_name,
)
if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
# init tokenizer
if tokenizer_config is None:
tokenizer_path = model_config["path"]
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
else:
tokenizer_path = tokenizer_config.pop("path")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"
# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=True,
drop_last=True,
seed=42,
),
num_workers=4,
drop_last=True,
collate_fn=collate_fn_grpo,
)
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None:
self.eval_dataloaders = {}
for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_dataset = RawConversationDataset(
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
)
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
self.eval_dataloaders[eval_task_name] = DataLoader(
eval_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
eval_dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=False,
drop_last=False,
seed=42,
),
collate_fn=collate_fn_grpo,
)
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()
self.reward_model = VerifiableReward(
reward_fns=[self.evaluation_function], # multiple reward functions can be added here
tokenizer=self.tokenizer,
tags=self.response_format_tags,
**reward_model_kwargs,
)
# init backend
if backend in BACKEND_MAP:
self.backend_cls = BACKEND_MAP[backend]
else:
raise ValueError(f"Unexpected backend {backend}")
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
self.state_dict_cpu = {i: None for i in range(self.consumer_pp_size)}
def init_collective_group(
self,
world_size: int,
rank: int,
backend: str = "nccl",
group_name: str = "default",
gloo_timeout: int = 3000000,
):
cc.init_collective_group(
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
)
print(f"[P{self.producer_idx}] Initialized {group_name} collective group", flush=True)
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
raise NotImplementedError
def loop(self) -> None:
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
)
for episode in range(self.num_episodes):
self.train_dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.train_dataloader):
self.profiler.log(f"train episode {episode} batch {i}")
if i >= num_valid_microbatches:
break
self.consumer_global_step = ray.get(self.shared_signal_actor.get_signal.remote()).get("global_step", 0)
# sync model first; model syncing runs in a separate thread and will not block the main thread
# the sync happens during inference and takes less than 10s, so the model can be updated immediately after inference
if episode != self.num_episodes - 1 or i != num_valid_microbatches - 1:
# don't sync model for last iteration
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.sleep() # evict KV cache to avoid OOM
torch.cuda.empty_cache()
# sync model thread function
def sync_model_thread():
if self.consumer_pp_size > 1:
self.profiler.enter("sync_model")
for pp_idx in range(self.consumer_pp_size):
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model"
)
)
print(
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
self.state_dict_cpu[pp_idx] = ray_broadcast_tensor_dict(
self.state_dict_cpu[pp_idx],
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.producer_idx}_pp_{pp_idx}",
backend="gloo", # use gloo for CPU communication
pin_memory=True,
)
self.profiler.exit("sync_model")
else:
self.profiler.enter("sync_model")
ray.get(
self.shared_signal_actor.set_signal.remote(
f"producer_{self.producer_idx}", "ready_sync_model"
)
)
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
time0 = time.time()
self.state_dict_cpu[0] = ray_broadcast_tensor_dict(
self.state_dict_cpu[0],
1,
device=torch.device("cpu"),
group_name=f"sync_model_producer_{self.producer_idx}",
backend="gloo", # use gloo for CPU communication
pin_memory=True,
)
self.profiler.log(f"Broadcast model state dict took {time.time() - time0:.2f} seconds")
self.profiler.exit("sync_model")
self.sync_model_thread_started = False
if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
                        # only sync the model when no sync thread is running and the global step has changed
self.sync_model_thread_started = True
self.sync_model_thread = threading.Thread(target=sync_model_thread)
self.producer_weight_version = self.consumer_global_step
self.sync_model_thread.start()
torch.cuda.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.wake_up()
if self.eval_interval > 0 and self.eval_dataset_config is not None:
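                    # Evaluate once eval_interval consumer steps have passed since the last evaluation,
                    # or immediately on the first pass (latest_eval_step == -1).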
if (
self.consumer_global_step - self.latest_eval_step >= self.eval_interval
and self.consumer_global_step > self.latest_eval_step
) or self.latest_eval_step == -1:
to_log_msg = {}
self.eval_mode = True
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
for eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
eval_outputs[
(
"test_cases"
if self.grpo_config["reward_fn_type"] == "code"
else "gt_answer"
)
][m],
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
tags=self.response_format_tags,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
]
eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
eval_statistics_tensor[1] += len(eval_results)
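                            # eval_statistics_tensor = [num correct, num total] on this producer's shard;
                            # all-reduce so the accuracy below covers the full eval set across producers.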
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_pg")
to_log_msg[f"eval/{eval_task_name}"] = (
eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
)
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
)
# save eval results
safe_append_to_jsonl_file(
os.path.join(
self.eval_save_dir,
f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
),
eval_results,
)
if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
self.profiler.enter("sleep")
while not (ray.get(self.shared_sync_data_actor.pickup_rollout_task.remote(self.microbatch_size))):
time.sleep(1)
self.profiler.exit("sleep")
self.profiler.enter("rollout")
self.profiler.log(f"rollout batch {i} episode {episode}")
# time.sleep(30) # simulate long inference time
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
response_idx=outputs["response_idx"].view((-1, 2)),
)
else:
gt_answer = []
for prompt_id in range(bs):
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
gt_answer=gt_answer,
response_idx=outputs["response_idx"].view((-1, 2)),
)
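                # Each reward_model_output entry is (reward, format_acc, ans_acc) for one flattened sample;
                # reshape the three columns back to (bs, num_gen, 1).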
outputs["reward"] = (
torch.tensor([value[0] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["format_acc"] = (
torch.tensor([value[1] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["ans_acc"] = (
torch.tensor([value[2] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
if "gt_answer" in outputs:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
self.profiler.exit("calculate_reward")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
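                # Move tensors to CPU so the payload can be stored by the Ray data actor without holding GPU memory.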
outputs = {k: v.cpu() for k, v in outputs.items()}
self.profiler.enter("send_data")
ray.get(self.shared_sync_data_actor.append_data.remote(outputs))
self.profiler.exit("send_data")
if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
):
if not self.sync_model_thread_started:
                        # Load the state dict in the main thread to avoid a race condition with the sync thread.
for pp_idx in range(self.consumer_pp_size):
if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}:
self.load_state_dict(self.state_dict_cpu[pp_idx])
                # Linearly anneal the sampling temperature from its initial value to 0.9 over the first
                # episode (ratio = i / len(train_dataloader), so the temperature starts at the configured
                # value and approaches 0.9 by the end of the episode).
if episode <= 0:
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
if isinstance(self.model, BACKEND_MAP["vllm"]):
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
def __del__(self):
self.profiler.close()
@ray.remote
class SimpleProducer(BaseProducer):
def __init__(
self,
shared_sync_data_actor: SharedVariableActor,
shared_signal_actor: SharedVariableActor,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
):
super().__init__(
shared_sync_data_actor,
shared_signal_actor,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
new_record = (
json.dumps(
{
"train_step": self.consumer_global_step,
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
}
)
+ "\n"
)
self.rollout_log_file.write(new_record)
self.rollout_log_file.flush()
self.latest_rollout_log_step = self.consumer_global_step
return rollouts
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)

View File

@ -0,0 +1,369 @@
import argparse
import json
import os
import ray
import torch
from coati.distributed.launch_zero_bubble import launch_distributed
DEFAULT_SYSTEM_PROMPT = {
    "think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n",
"boxed": "Please reason step by step, and put your final answer within \\boxed{}.",
"code": "You are a helpful assistant.",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
"--eval-dataset",
type=str,
default=None,
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
The key is the task name, and the value is the path to the jsonl file",
)
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
# Distributed training parameters
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument(
"-ibs",
"--inference-batch-size",
type=int,
default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
"-imbs",
"--inference-microbatch-size",
type=int,
default=8,
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
"-tbs",
"--train-batch-size",
type=int,
default=32,
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
"-tMbs",
"--train-minibatch-size",
type=int,
default=8,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
parser.add_argument(
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
)
# Sampling parameters
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
"-topk",
"--top-k",
type=int,
default=None,
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-topp",
"--top-p",
type=float,
default=1.0,
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
"-rt",
"--reward-type",
type=str,
default="think_answer_tags",
choices=["think_answer_tags", "boxed", "code"],
help="Reward type for GRPO.",
)
parser.add_argument(
"-ei",
"--eval-interval",
type=int,
default=100,
help="Interval for evaluation. Evaluate every ei training steps.",
)
parser.add_argument(
"-cbsl",
"--data_actor_buffer_size_limit",
type=int,
default=-1,
help="The approximate number of samples to keep in the consumer buffer. After this limit is reached, the producer will stop generating new samples and prioritize model sync until the consumer has processed some samples",
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
# Default settings: Using train batch size as mini batch size
args.train_minibatch_size = args.train_batch_size
if args.inference_batch_size is None:
        # Default settings: use the train batch size as the inference batch size and sync the inference model every train step
args.inference_batch_size = args.train_batch_size
assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert (
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
if args.master_address is None:
# Default settings: Using single machine
ray.init(
address="local",
namespace="ray-example",
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
)
else:
        # For multi-node Ray distributed training, set _node_ip_address to the IP address of your master node
ray.init(
_node_ip_address=args.master_address,
namespace="ray-example",
_temp_dir=args.ray_dir,
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
)
if args.top_k is None:
if args.backend == "transformers":
args.top_k = 50
elif args.backend == "vllm":
args.top_k = -1
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
)
)
generate_config.update(
dict(
max_length=args.max_new_tokens + args.max_prompt_tokens,
do_sample=True,
max_new_tokens=None,
early_stopping=False if args.reward_type == "think_answer_tags" else True,
stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=args.producer_tensor_parallel_size,
)
)
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
else:
raise ValueError(f"Unsupported backend: {args.backend}")
if args.algo == "GRPO":
# Default Settings
grpo_config = {
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"dynamic_batching": True,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"skip_threshold": 20.0,
"beta": 0, # no KL penalty for DAPO
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.system_prompt is None:
# Default system prompt
        args.system_prompt = DEFAULT_SYSTEM_PROMPT[args.reward_type]
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
train_dataset_config={
"path": args.dataset,
"max_length": args.max_prompt_tokens,
"system_prompt": args.system_prompt,
},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
core_algo=args.algo,
project_name=args.project,
save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config=(
{
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
for k, v in json.loads(args.eval_dataset).items()
}
if args.eval_dataset
else None
),
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_save_dir=args.rollout_save_dir,
enable_profiling=args.enable_profiling,
data_actor_buffer_size_limit=args.data_actor_buffer_size_limit,
)