mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-29 14:30:40 +00:00
add code for zero-bubble implementation
This commit is contained in:
parent
b1f646c7e7
commit
509274c47e
@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import copy
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
import torch.distributed.distributed_c10d as c10d
|
||||
@ -30,11 +31,18 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
|
||||
obj = c10d._tensor_to_object(obj, size_tensor.item())
|
||||
return obj
|
||||
|
||||
|
||||
def ray_broadcast_tensor_dict(
|
||||
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
|
||||
tensor_dict: Dict[str, torch.Tensor],
|
||||
src: int = 0,
|
||||
device=None,
|
||||
group_name: str = "default",
|
||||
backend: str = "nccl",
|
||||
offload_to_cpu: bool = False,
|
||||
pin_memory: bool = False,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
rank = cc.get_rank(group_name)
|
||||
if tensor_dict is None:
|
||||
tensor_dict = {}
|
||||
if rank == src:
|
||||
metadata = []
|
||||
for k, v in tensor_dict.items():
|
||||
@ -42,16 +50,103 @@ def ray_broadcast_tensor_dict(
|
||||
else:
|
||||
metadata = None
|
||||
metadata = ray_broadcast_object(metadata, src, device, group_name)
|
||||
if rank != src:
|
||||
out_dict = {}
|
||||
for k, shape, dtype in metadata:
|
||||
if rank == src:
|
||||
tensor = tensor_dict[k]
|
||||
if offload_to_cpu:
|
||||
tensor = tensor_dict[k].to(device)
|
||||
else:
|
||||
tensor = tensor_dict[k]
|
||||
else:
|
||||
tensor = torch.empty(shape, dtype=dtype, device=device)
|
||||
tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
|
||||
if backend == "gloo" and dtype == torch.bfloat16:
|
||||
# Gloo does not support bfloat16, convert to float16
|
||||
tensor = tensor.view(torch.float16)
|
||||
cc.broadcast(tensor, src, group_name)
|
||||
if backend == "gloo" and dtype == torch.bfloat16:
|
||||
# Convert back to bfloat16 if it was converted to float16
|
||||
tensor = tensor.view(torch.bfloat16)
|
||||
if rank != src:
|
||||
out_dict[k] = tensor
|
||||
if rank == src:
|
||||
out_dict = tensor_dict
|
||||
return out_dict
|
||||
if offload_to_cpu:
|
||||
tensor_dict[k] = tensor.cpu()
|
||||
else:
|
||||
tensor_dict[k] = tensor
|
||||
return tensor_dict
|
||||
|
||||
|
||||
@ray.remote
|
||||
class SharedVariableActor:
|
||||
def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
|
||||
self.data_queue = []
|
||||
self.data_uid = 0
|
||||
self.number_of_readers = number_of_readers
|
||||
self.queue_size = 0
|
||||
self.signals = {}
|
||||
self.process_locks = {}
|
||||
self.signal_procs_meet_count = {}
|
||||
self.buffer_size_limit = buffer_size_limit
|
||||
|
||||
def pickup_rollout_task(self, num_tasks: int):
|
||||
"""
|
||||
use queue size to control whether producers should generating new rollouts or wait
|
||||
for consumer to consumer more data. if queue size is less than threshold,
|
||||
it means consumer is consuming data fast enough, so producers can generate new rollouts.
|
||||
if queue size is greater than threshold, it means consumer is consuming data slowly,
|
||||
so producers should wait for consumer to consume more data.
|
||||
|
||||
Any free producer can pick up the task to generate rollout then increase the queued_data_size
|
||||
to prevent other producer to pick up the task redundantly, Note it is not the real
|
||||
queue length as data may still be generating
|
||||
"""
|
||||
ret = False
|
||||
if self.queue_size < self.buffer_size_limit:
|
||||
ret = True
|
||||
self.queue_size += num_tasks
|
||||
return ret
|
||||
|
||||
def append_data(self, data):
|
||||
self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count]
|
||||
self.data_uid += 1
|
||||
return True
|
||||
|
||||
def get_data(self, data_uid: int):
|
||||
# for multi-process data reading
|
||||
if not self.data_queue:
|
||||
# no data in the queue, return None
|
||||
return None
|
||||
to_pop_index = None
|
||||
ret = None
|
||||
for i, (uid, data, access_count) in enumerate(self.data_queue):
|
||||
if uid == data_uid:
|
||||
# found the data with the given uid
|
||||
self.data_queue[i][2] += 1
|
||||
ret = copy.deepcopy(data)
|
||||
if self.data_queue[i][2] == self.number_of_readers:
|
||||
to_pop_index = i
|
||||
break
|
||||
if to_pop_index is not None:
|
||||
# remove the data from the queue if it has been accessed by all readers
|
||||
self.data_queue.pop(to_pop_index)
|
||||
self.queue_size -= data["input_ids"].size(0)
|
||||
return ret
|
||||
|
||||
def acquire_process_lock(self, key: str):
|
||||
# atomic lock for process
|
||||
if key not in self.process_locks:
|
||||
self.process_locks[key] = 1 # locked
|
||||
return 0
|
||||
if self.process_locks[key] == 0:
|
||||
self.process_locks[key] = 1 # lock the process
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def release_process_lock(self, key: str):
|
||||
# atomic unlock for process
|
||||
assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
|
||||
self.process_locks[key] = 0
|
||||
|
||||
def set_signal(self, key: str, signal: str):
|
||||
self.signals[key] = signal
|
||||
|
||||
def get_signal(self):
|
||||
return self.signals
|
||||
|
@ -0,0 +1,306 @@
|
||||
import copy
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
|
||||
from .comm import SharedVariableActor
|
||||
from .zero_bubble.distributor import Distributor
|
||||
from .zero_bubble.grpo_consumer import GRPOConsumer
|
||||
from .zero_bubble.producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {"GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
with open(path) as f:
|
||||
lines = f.readlines()
|
||||
lines = [line for line in lines if line.strip()]
|
||||
return len(lines)
|
||||
|
||||
|
||||
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
|
||||
tp_size = plugin_config.get("tp_size", 1)
|
||||
pp_size = plugin_config.get("pp_size", 1)
|
||||
ep_size = plugin_config.get("ep_size", 1)
|
||||
sp_size = plugin_config.get("sp_size", 1)
|
||||
return n_procs // (tp_size * pp_size * ep_size * sp_size)
|
||||
|
||||
|
||||
def launch_distributed(
|
||||
num_producers: int,
|
||||
num_proc_per_producer: int,
|
||||
num_consumer_procs: int,
|
||||
num_episodes: int,
|
||||
inference_batch_size: int,
|
||||
inference_microbatch_size: int,
|
||||
train_batch_size: int,
|
||||
train_minibatch_size: int,
|
||||
train_dataset_config: Dict[str, Any],
|
||||
inference_model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
train_model_config: Dict[str, Any],
|
||||
grpo_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
num_generations: int = 8,
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
project_name: Optional[str] = None,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||
eval_interval: int = 100,
|
||||
eval_save_dir: Optional[str] = None,
|
||||
eval_generation_config: Optional[Dict[str, Any]] = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_save_dir: str = "./rollout",
|
||||
enable_profiling: bool = False,
|
||||
data_actor_buffer_size_limit: int = 0,
|
||||
):
|
||||
if core_algo not in ALGO_MAP:
|
||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||
else:
|
||||
core_consumer = ALGO_MAP.get(core_algo, GRPOConsumer)
|
||||
|
||||
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
|
||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||
if data_actor_buffer_size_limit <= 0:
|
||||
# use 2 times the train_minibatch_size as the default buffer size limit
|
||||
data_actor_buffer_size_limit = train_minibatch_size * train_dp_size * 2
|
||||
|
||||
dataset_path = train_dataset_config["path"]
|
||||
train_dataset_size = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
train_dataset_size = (train_dataset_size // global_inference_batch_size) * global_inference_batch_size
|
||||
|
||||
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
||||
wandb_group_name = str(uuid.uuid4())
|
||||
rollout_log_file = os.path.join(
|
||||
rollout_save_dir,
|
||||
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
|
||||
)
|
||||
|
||||
# Attention: Ray use complex schedualing method that consider various factors including load-balancing.
|
||||
# when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
|
||||
# this go against the design principle of our implementation, and we need to manually force the schedualing,
|
||||
# allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
|
||||
# node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
|
||||
nodes = ray.nodes()
|
||||
|
||||
# every producer is associated with a data worker, data worker is responsible for moving data from the producer to all consumer
|
||||
shared_sync_data_actor = SharedVariableActor.remote(num_consumer_procs, data_actor_buffer_size_limit)
|
||||
# all producer and the consumer 0 share the same model actor, model actor only provide signal for model synchronization
|
||||
shared_signal_actor = SharedVariableActor.remote()
|
||||
|
||||
node_info = {
|
||||
node["NodeID"]: {
|
||||
"num_gpus": node["Resources"].get("GPU", 0),
|
||||
"address": node["NodeManagerAddress"],
|
||||
} # Default to 0 if no GPUs are available
|
||||
for node in nodes
|
||||
}
|
||||
gpu_to_node_id = []
|
||||
gpu_to_ip_address = []
|
||||
for node_id in node_info:
|
||||
for idx in range(int(node_info[node_id]["num_gpus"])):
|
||||
gpu_to_node_id.append(node_id)
|
||||
gpu_to_ip_address.append(node_info[node_id]["address"])
|
||||
print(node_info)
|
||||
|
||||
producer_procs = []
|
||||
for i in range(num_producers):
|
||||
node_id = gpu_to_node_id[0]
|
||||
producer_ip_address = gpu_to_ip_address[0]
|
||||
for _ in range(num_proc_per_producer):
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
|
||||
producer = SimpleProducer.options(num_gpus=num_proc_per_producer, num_cpus=4).remote(
|
||||
shared_sync_data_actor=shared_sync_data_actor,
|
||||
shared_signal_actor=shared_signal_actor,
|
||||
producer_idx=i,
|
||||
num_producers=num_producers,
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
num_episodes=num_episodes,
|
||||
batch_size=inference_batch_size,
|
||||
train_dataset_config=train_dataset_config,
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
tokenizer_config=tokenizer_config,
|
||||
microbatch_size=inference_microbatch_size,
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
consumer_plugin_config=plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
eval_generation_config=eval_generation_config,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
producer_procs.append(producer)
|
||||
# ray.get([p.setup.remote() for p in producer_procs])
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
generate_config_consumer.update(
|
||||
dict(
|
||||
backend=inference_backend,
|
||||
)
|
||||
)
|
||||
consumer_master_ip_address = gpu_to_ip_address[0]
|
||||
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
|
||||
consumer_procs = []
|
||||
if num_consumer_procs <= 1:
|
||||
raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
|
||||
for i in range(num_consumer_procs):
|
||||
node_id = gpu_to_node_id[0]
|
||||
consumer_ip_address = gpu_to_ip_address[0]
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
|
||||
consumer = core_consumer.options(num_gpus=1, num_cpus=4).remote(
|
||||
shared_sync_data_actor=shared_sync_data_actor,
|
||||
shared_signal_actor=shared_signal_actor,
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
world_size=num_consumer_procs,
|
||||
master_addr=consumer_master_ip_address,
|
||||
master_port=master_port,
|
||||
train_dataset_size=train_dataset_size,
|
||||
batch_size=train_batch_size,
|
||||
model_config=train_model_config,
|
||||
plugin_config=plugin_config,
|
||||
minibatch_size=train_minibatch_size,
|
||||
generate_config=generate_config_consumer,
|
||||
grpo_config=grpo_config,
|
||||
num_generations=num_generations,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
consumer_procs.append(consumer)
|
||||
|
||||
distributor_procs = []
|
||||
for i in range(num_producers):
|
||||
distributor_procs.append(
|
||||
Distributor.options(num_cpus=2).remote(
|
||||
i,
|
||||
plugin_config.get("pp_size", 1),
|
||||
num_producers,
|
||||
shared_signal_actor,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
)
|
||||
print("=================== All processes are created, starting setup torch DDP ===================", flush=True)
|
||||
ray.get([p.setup.remote() for p in consumer_procs])
|
||||
print(
|
||||
"=================== All processes are setup, starting initialize communication groups ===================",
|
||||
flush=True,
|
||||
)
|
||||
remote_refs = []
|
||||
# Initialize consumer communication group
|
||||
for i, p in enumerate(consumer_procs):
|
||||
remote_refs.append(p.init_collective_group.remote(num_consumer_procs, i, "gloo", f"consumer_pg"))
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
# Initialize producer communication group
|
||||
for i, p in enumerate(producer_procs):
|
||||
remote_refs.append(p.init_collective_group.remote(num_producers, i, "nccl", f"producer_pg"))
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
# Initialize distributor communication group
|
||||
for i, p in enumerate(distributor_procs):
|
||||
remote_refs.append(p.init_collective_group.remote(num_producers, i, "gloo", f"distributor_pg"))
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
# Initialize sync model communication group between consumer and sync model actor
|
||||
# As per tested, gloo do not support nested initialization, so we need to initialize all participants in the same group in the same ray.get call.
|
||||
consumer_pp = plugin_config.get("pp_size", 1)
|
||||
for i, p in enumerate(consumer_procs):
|
||||
consumer_ddp_config = ray.get(p.get_ddp_config.remote())
|
||||
if consumer_pp > 1:
|
||||
if consumer_ddp_config["tp_rank"] == 0 and consumer_ddp_config["dp_rank"] == 0:
|
||||
pp_rank = consumer_ddp_config["pp_rank"]
|
||||
remote_refs.append(
|
||||
p.init_collective_group.remote(
|
||||
num_producers + 1,
|
||||
0,
|
||||
backend="gloo",
|
||||
group_name=f"sync_model_consumer_pp_{pp_rank}",
|
||||
gloo_timeout=3000000,
|
||||
)
|
||||
)
|
||||
for distributor_id, p_distributor in enumerate(distributor_procs):
|
||||
remote_refs.append(
|
||||
p_distributor.init_collective_group.remote(
|
||||
num_producers + 1,
|
||||
1 + distributor_id,
|
||||
backend="gloo",
|
||||
group_name=f"sync_model_consumer_pp_{pp_rank}",
|
||||
gloo_timeout=3000000,
|
||||
)
|
||||
)
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
else:
|
||||
if i == 0:
|
||||
remote_refs.append(
|
||||
p.init_collective_group.remote(
|
||||
num_producers + 1, 0, backend="gloo", group_name=f"sync_model_consumer", gloo_timeout=3000000
|
||||
)
|
||||
)
|
||||
for distributor_id, p_distributor in enumerate(distributor_procs):
|
||||
remote_refs.append(
|
||||
p_distributor.init_collective_group.remote(
|
||||
num_producers + 1,
|
||||
1 + distributor_id,
|
||||
backend="gloo",
|
||||
group_name=f"sync_model_consumer",
|
||||
gloo_timeout=3000000,
|
||||
)
|
||||
)
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
# Initialize sync model communication group between producer and sync model actor
|
||||
for i, p in enumerate(producer_procs):
|
||||
if consumer_pp > 1:
|
||||
for pp_rank in range(consumer_pp):
|
||||
remote_refs.append(
|
||||
p.init_collective_group.remote(
|
||||
2, 0, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000
|
||||
)
|
||||
)
|
||||
remote_refs.append(
|
||||
distributor_procs[i].init_collective_group.remote(
|
||||
2, 1, backend="gloo", group_name=f"sync_model_producer_{i}_pp_{pp_rank}", gloo_timeout=3000000
|
||||
)
|
||||
)
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
else:
|
||||
remote_refs.append(
|
||||
p.init_collective_group.remote(
|
||||
2, 0, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000
|
||||
)
|
||||
)
|
||||
remote_refs.append(
|
||||
distributor_procs[i].init_collective_group.remote(
|
||||
2, 1, backend="gloo", group_name=f"sync_model_producer_{i}", gloo_timeout=3000000
|
||||
)
|
||||
)
|
||||
ray.get(remote_refs)
|
||||
remote_refs = []
|
||||
print("=================== All processes are set up, starting loop ===================", flush=True)
|
||||
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs + distributor_procs)])
|
@ -0,0 +1,347 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
|
||||
|
||||
|
||||
class BaseConsumer:
|
||||
def __init__(
|
||||
self,
|
||||
shared_sync_data_actor: SharedVariableActor,
|
||||
shared_signal_actor: SharedVariableActor,
|
||||
num_producers: int,
|
||||
num_episodes: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
master_addr: str,
|
||||
master_port: int,
|
||||
train_dataset_size: int,
|
||||
batch_size: int,
|
||||
model_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
minibatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
enable_profiling: bool = False,
|
||||
):
|
||||
self.num_producers = num_producers
|
||||
self.num_episodes = num_episodes
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.master_addr = master_addr
|
||||
self.master_port = master_port
|
||||
self.train_dataset_size = train_dataset_size
|
||||
self.received_prompts = 0
|
||||
self.batch_size = batch_size
|
||||
self.minibatch_size = minibatch_size
|
||||
self.save_interval = save_interval
|
||||
self.save_dir = save_dir
|
||||
self.enable_profiling = enable_profiling
|
||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // minibatch_size
|
||||
self.data_uid = 0
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
self.device = get_current_device()
|
||||
self.lr_scheduler = None
|
||||
|
||||
self.shared_sync_data_actor = shared_sync_data_actor
|
||||
self.shared_signal_actor = shared_signal_actor
|
||||
self.state_dict_cpu = {}
|
||||
|
||||
def setup(self) -> None:
|
||||
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
|
||||
|
||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||
if (
|
||||
self.plugin_config.get("pp_size", 1) > 1
|
||||
and "num_microbatches" not in self.plugin_config
|
||||
and "microbatch_size" not in self.plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
||||
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
||||
self.pp_rank = dist.get_rank(self.plugin.pp_group)
|
||||
|
||||
self.dp_size = dist.get_world_size(self.plugin.dp_group)
|
||||
self.tp_size = dist.get_world_size(self.plugin.tp_group)
|
||||
self.pp_size = dist.get_world_size(self.plugin.pp_group)
|
||||
|
||||
self.buffer = []
|
||||
self.recv_cnt = 0
|
||||
self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
|
||||
|
||||
def get_ddp_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the DDP configuration for the consumer.
|
||||
This method is used to get the DDP configuration for the consumer.
|
||||
"""
|
||||
return {
|
||||
"dp_size": self.dp_size,
|
||||
"tp_size": self.tp_size,
|
||||
"pp_size": self.pp_size,
|
||||
"dp_rank": self.dp_rank,
|
||||
"tp_rank": self.tp_rank,
|
||||
"pp_rank": self.pp_rank,
|
||||
"world_size": self.world_size,
|
||||
"rank": self.rank,
|
||||
}
|
||||
|
||||
def init_collective_group(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
backend: str = "nccl",
|
||||
group_name: str = "default",
|
||||
gloo_timeout: int = 3000000,
|
||||
):
|
||||
cc.init_collective_group(
|
||||
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
|
||||
)
|
||||
print(f"[C{self.rank}] Initialized {group_name} collective group", flush=True)
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self, **kwargs) -> Optional[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare a mini-batch from the effective group to raw group mapping.
|
||||
This method is used to create a mini-batch for training.
|
||||
"""
|
||||
batches = [
|
||||
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
|
||||
]
|
||||
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||
raw_mini_batches = self.buffer[
|
||||
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||
] # include the last effective sample
|
||||
raw_mini_batches_metric_dict = {
|
||||
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||
}
|
||||
batch = bind_batch([t[0] for t in batches])
|
||||
batch = post_recv(batch)
|
||||
return batch, raw_mini_batches_metric_dict
|
||||
|
||||
def calculate_effective_group_to_raw_group_mapping(self):
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
for buffer_idx in range(len(self.buffer)):
|
||||
if self.buffer[buffer_idx][0] is not None:
|
||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
|
||||
return effective_group_to_raw_group_mapping
|
||||
|
||||
def loop(self) -> None:
|
||||
print(f"Consumer{self.rank}, nmb: {self.num_microbatches}")
|
||||
for episode in range(self.num_episodes):
|
||||
with tqdm(
|
||||
range(self.train_dataset_size),
|
||||
desc=f"Episode {episode} with rollout step(s)",
|
||||
disable=self.rank != 0,
|
||||
) as pbar:
|
||||
while self.received_prompts < self.train_dataset_size:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
self.profiler.enter(f"recv_data")
|
||||
while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size:
|
||||
# receive data from producers
|
||||
raw_batch = ray.get(
|
||||
self.shared_sync_data_actor.get_data.remote(self.data_uid)
|
||||
) # get the first queued data
|
||||
while raw_batch is None:
|
||||
self.profiler.log(f"No data received by consumer {self.rank}, skipping")
|
||||
print(
|
||||
f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
|
||||
)
|
||||
time.sleep(1)
|
||||
raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
|
||||
continue
|
||||
self.data_uid += 1
|
||||
raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
|
||||
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
|
||||
# we need to calculate the metrics before filtering here for logging
|
||||
# [batch_size, num_generations] -> [batch_size]
|
||||
reward = raw_batch["reward"][:, :, 0]
|
||||
format_acc = raw_batch["format_acc"][:, :, 0]
|
||||
ans_acc = raw_batch["ans_acc"][:, :, 0]
|
||||
response_len = (
|
||||
raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
|
||||
).type(torch.float32)
|
||||
effective_group_mask = None
|
||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
||||
# filter the group based on the reward and accuracy
|
||||
group_ans_acc_mean = ans_acc.mean(dim=1)
|
||||
effective_group_mask = torch.logical_and(
|
||||
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
||||
)
|
||||
|
||||
raw_batch = unbind_batch(raw_batch) # List[Dict[str, torch.Tensor]]
|
||||
self.received_prompts += len(raw_batch)
|
||||
pbar.update(len(raw_batch))
|
||||
for group_idx, group_with_reward in enumerate(raw_batch):
|
||||
self.buffer.append(
|
||||
[
|
||||
(
|
||||
group_with_reward
|
||||
if effective_group_mask is None or effective_group_mask[group_idx]
|
||||
else None
|
||||
),
|
||||
reward[group_idx],
|
||||
format_acc[group_idx],
|
||||
ans_acc[group_idx],
|
||||
response_len[group_idx],
|
||||
]
|
||||
)
|
||||
if effective_group_mask is not None:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
||||
)
|
||||
# mapping the effective group to the raw group for indexing
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
||||
)
|
||||
self.profiler.exit(f"recv_data")
|
||||
need_sync_model = False
|
||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||
# after we have enough effective groups, we can start training
|
||||
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
self.profiler.enter("step")
|
||||
loss = self.step(pbar, **batch, **raw_mini_batches_metric_dict)
|
||||
self.profiler.exit("step")
|
||||
self.buffer = self.buffer[
|
||||
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||
]
|
||||
# recalculate the effective group to raw group mapping
|
||||
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
|
||||
assert (
|
||||
len(effective_group_to_raw_group_mapping)
|
||||
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||
)
|
||||
# cc.barrier(group_name="consumer_pg")
|
||||
if loss is not None:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
need_sync_model = True
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
|
||||
if need_sync_model and (
|
||||
(self.global_step + 1) % self.save_interval == 0
|
||||
or self.received_prompts >= self.train_dataset_size
|
||||
):
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {self.global_step + 1}.")
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
|
||||
)
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
|
||||
|
||||
if need_sync_model and (
|
||||
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
|
||||
):
|
||||
|
||||
def sync_model_thread():
|
||||
# sync model weights to all producers, if no model update or it is the last training step, skip syncing
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
torch.cuda.empty_cache()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
else:
|
||||
if self.rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name="sync_model_consumer",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
|
||||
if not self.sync_model_thread_started:
|
||||
# only sync model when the thread is not started and no other thread is broadcasting
|
||||
self.sync_model_thread_started = True
|
||||
state_dict_ = self.state_dict()
|
||||
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
|
||||
self.pp_size == 1 and self.rank == 0
|
||||
):
|
||||
if len(self.state_dict_cpu) == 0:
|
||||
# use pinned memory to speed up the transfer
|
||||
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
|
||||
torch.cuda.synchronize()
|
||||
for k, v in state_dict_.items():
|
||||
self.state_dict_cpu[k].copy_(v, non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
cc.barrier(
|
||||
group_name="consumer_pg"
|
||||
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
|
||||
time_before_starting_thread = time.time()
|
||||
threading.Thread(target=sync_model_thread).start()
|
||||
# sync_model_thread()
|
||||
self.profiler.log(
|
||||
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
|
||||
)
|
||||
self.sync_model_thread_started = False
|
||||
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
|
||||
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
self.received_prompts = 0
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "profiler"):
|
||||
self.profiler.close()
|
@ -0,0 +1,108 @@
|
||||
import time
|
||||
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Distributor:
|
||||
def __init__(
|
||||
self,
|
||||
distributor_id,
|
||||
consumer_pp_size,
|
||||
num_producers,
|
||||
shared_signal_actor: SharedVariableActor,
|
||||
enable_profiling: bool = True,
|
||||
):
|
||||
self.distributor_id = distributor_id
|
||||
self.consumer_pp_size = consumer_pp_size
|
||||
self.state_dict_cpu = {}
|
||||
self.num_producers = num_producers
|
||||
self.shared_signal_actor = shared_signal_actor
|
||||
self.device = get_current_device()
|
||||
self.profiler = CustomProfiler(f"D{self.distributor_id}", disabled=not enable_profiling)
|
||||
|
||||
def init_collective_group(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
backend: str = "nccl",
|
||||
group_name: str = "default",
|
||||
gloo_timeout: int = 3000000,
|
||||
):
|
||||
cc.init_collective_group(
|
||||
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
|
||||
)
|
||||
print(f"[D] Initialized {group_name} collective group", flush=True)
|
||||
|
||||
def loop(self):
|
||||
while True:
|
||||
time.sleep(1)
|
||||
signal = ray.get(self.shared_signal_actor.get_signal.remote())
|
||||
if self.consumer_pp_size > 1:
|
||||
for i in range(self.consumer_pp_size):
|
||||
if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
|
||||
self.profiler.enter(f"sync_model_consumer_pp_{i}")
|
||||
cc.barrier(group_name="distributor_pg")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
|
||||
# Broadcast the model state dict from consumer to shared variable actor
|
||||
self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
|
||||
None,
|
||||
0,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_consumer_pp_{i}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit(f"sync_model_consumer_pp_{i}")
|
||||
for i in range(self.consumer_pp_size):
|
||||
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
|
||||
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
|
||||
# Broadcast the model state dict to all producers
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"producer_{self.distributor_id}_pp_{i}", "not_ready_sync_model"
|
||||
)
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu[i],
|
||||
1,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit(f"sync_model_producer_{self.distributor_id}_pp_{i}")
|
||||
else:
|
||||
if signal.get("consumer", None) == "ready_sync_model":
|
||||
self.profiler.enter("sync_model_consumer")
|
||||
cc.barrier(group_name="distributor_pg")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model"))
|
||||
# Broadcast the model state dict from consumer to shared variable actor
|
||||
self.state_dict_cpu = ray_broadcast_tensor_dict(
|
||||
None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
|
||||
)
|
||||
self.profiler.exit("sync_model_consumer")
|
||||
if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
|
||||
self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
|
||||
# Broadcast the model state dict to all producers
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"producer_{self.distributor_id}", "not_ready_sync_model"
|
||||
)
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
1,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_producer_{self.distributor_id}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit(f"sync_model_producer_{self.distributor_id}")
|
||||
if signal.get("consumer", None) == "terminate":
|
||||
self.profiler.log("terminate sync model worker")
|
||||
break
|
@ -0,0 +1,498 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Optional
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import wandb
|
||||
from coati.distributed.comm import SharedVariableActor
|
||||
from coati.distributed.zero_bubble.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.utils import memory_efficient_logprob
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GRPOConsumer(BaseConsumer):
|
||||
def __init__(
|
||||
self,
|
||||
shared_sync_data_actor: SharedVariableActor,
|
||||
shared_signal_actor: SharedVariableActor,
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
train_dataset_size,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
minibatch_size=1,
|
||||
num_generations=8,
|
||||
generate_config=None,
|
||||
grpo_config={},
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
enable_profiling: bool = False,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if (
|
||||
plugin_config.get("pp_size", 1) > 1
|
||||
and "num_microbatches" not in plugin_config
|
||||
and "microbatch_size" not in plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(
|
||||
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
|
||||
)
|
||||
super().__init__(
|
||||
shared_sync_data_actor,
|
||||
shared_signal_actor,
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
train_dataset_size,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
minibatch_size,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.raw_train_batch_reward = []
|
||||
self.raw_train_batch_format_acc = []
|
||||
self.raw_train_batch_ans_acc = []
|
||||
self.raw_train_batch_response_len = []
|
||||
self.accum_count = 0
|
||||
self.generate_config = generate_config
|
||||
self.grpo_config = grpo_config
|
||||
self.project_name = project_name
|
||||
self.effective_sample_count = 0
|
||||
self.effective_prompt_count = 0
|
||||
self.project_name = project_name
|
||||
self.run_name = run_name
|
||||
self.wandb_group_name = wandb_group_name
|
||||
|
||||
self.policy_loss_fn = PolicyLoss(
|
||||
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
|
||||
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
|
||||
beta=grpo_config.get("beta", 0.01),
|
||||
loss_variation=grpo_config.get("loss_variation", "sample_level"),
|
||||
)
|
||||
|
||||
# Reference model is initialized from policy model.
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.reference_model.eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
self.filter_range = grpo_config.get("filter_range", None)
|
||||
if self.filter_range is not None:
|
||||
assert len(self.filter_range) == 2, "Filter range should have 2 values."
|
||||
|
||||
self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
|
||||
if self.filter_truncated_response:
|
||||
self.max_length = 0
|
||||
if "max_tokens" in self.generate_config:
|
||||
self.max_length = self.generate_config["max_tokens"]
|
||||
elif "max_new_tokens" in self.generate_config:
|
||||
self.max_length = self.generate_config["max_new_tokens"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||
)
|
||||
# Initialize verifiable reward.
|
||||
grpo_config.get("response_format_tags", None)
|
||||
self.global_step = 0
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
and self.tp_rank == 0
|
||||
and self.dp_rank == 0
|
||||
):
|
||||
self.wandb_run = wandb.init(
|
||||
project=self.project_name,
|
||||
sync_tensorboard=False,
|
||||
dir="./wandb",
|
||||
name=self.run_name,
|
||||
group=self.wandb_group_name,
|
||||
)
|
||||
|
||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=self.optimizer,
|
||||
total_steps=min(self.num_episodes, 4) * self.train_dataset_size // (self.batch_size * self.dp_size),
|
||||
warmup_steps=0,
|
||||
eta_min=0.1 * self.grpo_config.get("lr", 1e-6),
|
||||
)
|
||||
|
||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||
)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
self.plugin.logger.set_level("ERROR")
|
||||
|
||||
def step(self, pbar: Any, **kwargs) -> Optional[float]:
|
||||
"""
|
||||
Step data from policy model:
|
||||
[{
|
||||
"input_ids": torch.Tensor,
|
||||
"attention_mask": torch.Tensor,
|
||||
"action_mask": torch.Tensor,
|
||||
"action_log_probs": torch.Tensor,
|
||||
},
|
||||
...]
|
||||
Format:
|
||||
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||
"""
|
||||
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
|
||||
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
|
||||
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
|
||||
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
|
||||
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
|
||||
action_mask = data["action_mask"]
|
||||
num_action = action_mask.shape[1]
|
||||
old_action_log_probs = data["action_log_probs"]
|
||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||
|
||||
reward = data["reward"].view((-1))
|
||||
format_acc = data["format_acc"].view((-1))
|
||||
ans_acc = data["ans_acc"].view((-1))
|
||||
|
||||
# [minibatch_size, num_generations]
|
||||
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
reward_mean = group_reward.mean(dim=1)
|
||||
# [minibatch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [minibatch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
|
||||
# [minibatch_size x num_of_generation]
|
||||
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
||||
|
||||
# filter out overlength samples
|
||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
|
||||
# filter out samples with reward outside the range
|
||||
# if dynamic batching is enabled, we filter out out of range groups before training
|
||||
group_ans_acc_mean = (
|
||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
|
||||
)
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
torch.logical_and(
|
||||
group_ans_acc_mean > self.filter_range[0],
|
||||
group_ans_acc_mean < self.filter_range[1],
|
||||
),
|
||||
)
|
||||
self.effective_prompt_count += (
|
||||
group_reward.size(0) * self.dp_size
|
||||
) # all prompts in the batch are effective as we filtered out the bad ones before step.
|
||||
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
|
||||
|
||||
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
|
||||
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
|
||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||
self.effective_sample_count += effective_samples.item()
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"Global Step": self.global_step,
|
||||
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
|
||||
}
|
||||
)
|
||||
|
||||
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if need_update or self.booster.plugin.zero_stage == 2
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
|
||||
if self.plugin.pp_size > 1:
|
||||
# Support training with PP.
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
with torch.no_grad():
|
||||
reference_model_outputs = self.booster.execute_pipeline(
|
||||
iter(
|
||||
[
|
||||
{
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
}
|
||||
]
|
||||
),
|
||||
self.reference_model,
|
||||
criterion=lambda outputs, inputs: torch.tensor(
|
||||
[0.0], device=action_mask.device
|
||||
), # dummy criterion
|
||||
optimizer=None,
|
||||
return_loss=False,
|
||||
return_outputs=True,
|
||||
)
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_action_log_probs = memory_efficient_logprob(
|
||||
reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
)
|
||||
else:
|
||||
# Dummy reference logprobs for data iterator.
|
||||
reference_action_log_probs = None
|
||||
else:
|
||||
reference_action_log_probs = None
|
||||
|
||||
data_policy_forward = {
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
"action_mask": action_mask_forward_micro_batch,
|
||||
"advantages": advantages_forward_micro_batch,
|
||||
"loss_mask": loss_mask_forward_micro_batch,
|
||||
"source": self.rank,
|
||||
}
|
||||
if reference_action_log_probs is not None:
|
||||
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
|
||||
|
||||
kl = []
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
)
|
||||
if "reference_action_log_probs" in inputs:
|
||||
per_token_kl = (
|
||||
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
||||
inputs["action_mask"], dim=-1
|
||||
)
|
||||
kl.append(appox_kl.mean())
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl.append(torch.tensor(0.0))
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
action_log_probs,
|
||||
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
loss_mask=inputs["loss_mask"],
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
return loss
|
||||
|
||||
policy_model_outputs = self.booster.execute_pipeline(
|
||||
iter([data_policy_forward]),
|
||||
self.policy_model,
|
||||
criterion=_criterion,
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=False,
|
||||
)
|
||||
loss = policy_model_outputs["loss"]
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
if len(kl) > 0:
|
||||
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
|
||||
mean_kl.append(kl)
|
||||
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
|
||||
else:
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
)
|
||||
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = memory_efficient_logprob(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl = None
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
# Calculate accumulate value.
|
||||
if kl is not None:
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
and self.tp_rank == 0
|
||||
and self.dp_rank == 0
|
||||
):
|
||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
|
||||
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
self.accum_count += 1
|
||||
if need_update:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.global_step += 1
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
|
||||
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
|
||||
self.effective_prompt_count = 0
|
||||
self.effective_sample_count = 0
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
|
||||
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
|
||||
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
|
||||
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
|
||||
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
|
||||
overlength_samples_ratio = (
|
||||
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
|
||||
) # not an exact figure, but a close estimate
|
||||
self.raw_train_batch_reward = []
|
||||
self.raw_train_batch_format_acc = []
|
||||
self.raw_train_batch_ans_acc = []
|
||||
self.raw_train_batch_response_len = []
|
||||
to_log_msg = [
|
||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||
f"Reward: {raw_batch_reward_mean:.4f}",
|
||||
f"format Reward: {raw_batch_format_acc_mean:.4f}",
|
||||
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
|
||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||
f"Sample_utilization: {sample_utilization:.4f}",
|
||||
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
"metrics/reward": raw_batch_reward_mean,
|
||||
"metrics/format_acc": raw_batch_format_acc_mean,
|
||||
"metrics/ans_acc": raw_batch_ans_acc_mean,
|
||||
"metrics/response_length": raw_batch_response_len_mean,
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"train/sample_utilization": sample_utilization,
|
||||
"train/overlength_samples_ratio": overlength_samples_ratio,
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
||||
if self.wandb_run is not None:
|
||||
self.wandb_run.log(metrics)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
else:
|
||||
return None
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
model = self.policy_model.unwrap()
|
||||
state_dict = model.state_dict()
|
||||
return state_dict
|
@ -0,0 +1,533 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
import tqdm
|
||||
import wandb
|
||||
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from ray.util.collective import allreduce
|
||||
from ray.util.collective.types import ReduceOp
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.inference_backend import BACKEND_MAP
|
||||
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
|
||||
|
||||
try:
|
||||
from vllm import SamplingParams
|
||||
except ImportError:
|
||||
LLM = None
|
||||
|
||||
|
||||
class BaseProducer:
|
||||
def __init__(
|
||||
self,
|
||||
shared_sync_data_actor: SharedVariableActor,
|
||||
shared_signal_actor: SharedVariableActor,
|
||||
producer_idx: int,
|
||||
num_producers: int,
|
||||
num_consumer_procs: int,
|
||||
num_episodes: int,
|
||||
batch_size: int,
|
||||
train_dataset_config: Dict[str, Any],
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
microbatch_size: int = 1,
|
||||
backend: str = "transformers",
|
||||
consumer_plugin_config: Dict[str, Any] = None,
|
||||
eval_dataset_config=None,
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
):
|
||||
self.producer_idx = producer_idx
|
||||
self.num_producers = num_producers
|
||||
self.num_consumer_procs = num_consumer_procs
|
||||
self.num_episodes = num_episodes
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
assert batch_size % microbatch_size == 0
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
self.latest_eval_step = -1
|
||||
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
|
||||
|
||||
# for async data and model sync
|
||||
self.shared_sync_data_actor = shared_sync_data_actor
|
||||
self.shared_signal_actor = shared_signal_actor
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
self.train_dataset_config = train_dataset_config
|
||||
self.model_config = model_config
|
||||
self.generate_config = generate_config
|
||||
self.tokenizer_config = tokenizer_config
|
||||
self.consumer_plugin_config = consumer_plugin_config
|
||||
self.eval_interval = eval_interval
|
||||
self.eval_save_dir = eval_save_dir
|
||||
self.consumer_global_step = 0
|
||||
self.producer_weight_version = 0
|
||||
self.eval_mode = False
|
||||
self.log_rollout_interval = log_rollout_interval
|
||||
self.latest_rollout_log_step = -1
|
||||
self.grpo_config = grpo_config
|
||||
reward_model_kwargs = {
|
||||
k: v
|
||||
for k, v in grpo_config.items()
|
||||
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
|
||||
}
|
||||
self.response_format_tags = grpo_config.get("response_format_tags", None)
|
||||
if producer_idx == 0:
|
||||
if os.path.exists(rollout_log_file):
|
||||
raise ValueError(
|
||||
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
|
||||
)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
|
||||
self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run = wandb.init(
|
||||
project=project_name,
|
||||
sync_tensorboard=False,
|
||||
dir="./wandb",
|
||||
name=run_name + "_eval",
|
||||
group=wandb_group_name,
|
||||
)
|
||||
|
||||
if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
|
||||
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
|
||||
|
||||
# init tokenizer
|
||||
if tokenizer_config is None:
|
||||
tokenizer_path = model_config["path"]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
else:
|
||||
tokenizer_path = tokenizer_config.pop("path")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
# init dataloader
|
||||
train_dataset_path = train_dataset_config.pop("path")
|
||||
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
|
||||
self.train_dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=microbatch_size,
|
||||
sampler=DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=num_producers,
|
||||
rank=producer_idx,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
seed=42,
|
||||
),
|
||||
num_workers=4,
|
||||
drop_last=True,
|
||||
collate_fn=collate_fn_grpo,
|
||||
)
|
||||
if grpo_config["reward_fn_type"] == "think_answer_tags":
|
||||
self.evaluation_function = math_reward_fn
|
||||
elif grpo_config["reward_fn_type"] == "boxed":
|
||||
self.evaluation_function = boxed_math_reward_fn
|
||||
elif grpo_config["reward_fn_type"] == "code":
|
||||
self.evaluation_function = code_reward_fn
|
||||
else:
|
||||
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
|
||||
|
||||
self.eval_dataset_config = eval_dataset_config
|
||||
if self.eval_dataset_config is not None:
|
||||
self.eval_dataloaders = {}
|
||||
for eval_task_name in self.eval_dataset_config:
|
||||
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
|
||||
eval_dataset = RawConversationDataset(
|
||||
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
|
||||
)
|
||||
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
|
||||
self.eval_dataloaders[eval_task_name] = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=microbatch_size,
|
||||
sampler=DistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=num_producers,
|
||||
rank=producer_idx,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
seed=42,
|
||||
),
|
||||
collate_fn=collate_fn_grpo,
|
||||
)
|
||||
else:
|
||||
print("No eval dataset provided, skip eval")
|
||||
self.device = get_current_device()
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[self.evaluation_function], # multiple reward functions can be added here
|
||||
tokenizer=self.tokenizer,
|
||||
tags=self.response_format_tags,
|
||||
**reward_model_kwargs,
|
||||
)
|
||||
|
||||
# init backend
|
||||
if backend in BACKEND_MAP:
|
||||
self.backend_cls = BACKEND_MAP[backend]
|
||||
else:
|
||||
raise ValueError(f"Unexpected backend {backend}")
|
||||
|
||||
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
|
||||
self.state_dict_cpu = {i: None for i in range(self.consumer_pp_size)}
|
||||
|
||||
def init_collective_group(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
backend: str = "nccl",
|
||||
group_name: str = "default",
|
||||
gloo_timeout: int = 3000000,
|
||||
):
|
||||
cc.init_collective_group(
|
||||
world_size=world_size, rank=rank, backend=backend, group_name=group_name, gloo_timeout=gloo_timeout
|
||||
)
|
||||
print(f"[P{self.producer_idx}] Initialized {group_name} collective group", flush=True)
|
||||
|
||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def loop(self) -> None:
|
||||
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
|
||||
num_valid_microbatches = num_update_per_episode * self.num_microbatches
|
||||
|
||||
print(
|
||||
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
|
||||
)
|
||||
for episode in range(self.num_episodes):
|
||||
self.train_dataloader.sampler.set_epoch(episode)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
self.profiler.log(f"train episode {episode} batch {i}")
|
||||
if i >= num_valid_microbatches:
|
||||
break
|
||||
|
||||
self.consumer_global_step = ray.get(self.shared_signal_actor.get_signal.remote()).get("global_step", 0)
|
||||
# sync model first, as the model syncing runs in a separate thread, will not block the main thread
|
||||
# sync model during inference, which takes less than 10s, so that the model can be updated immediately after inference
|
||||
if episode != self.num_episodes - 1 or i != num_valid_microbatches - 1:
|
||||
# don't sync model for last iteration
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.sleep() # revict KV_cache to avoid OOM
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# sync model thread function
|
||||
def sync_model_thread():
|
||||
if self.consumer_pp_size > 1:
|
||||
self.profiler.enter("sync_model")
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
self.state_dict_cpu[pp_idx] = ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu[pp_idx],
|
||||
1,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_producer_{self.producer_idx}_pp_{pp_idx}",
|
||||
backend="gloo", # use gloo for CPU communication
|
||||
pin_memory=True,
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
else:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"producer_{self.producer_idx}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
time0 = time.time()
|
||||
self.state_dict_cpu[0] = ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu[0],
|
||||
1,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_producer_{self.producer_idx}",
|
||||
backend="gloo", # use gloo for CPU communication
|
||||
pin_memory=True,
|
||||
)
|
||||
self.profiler.log(f"Broadcast model state dict took {time.time() - time0:.2f} seconds")
|
||||
self.profiler.exit("sync_model")
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
|
||||
# only sync model when the thread is not started and global step is changed
|
||||
self.sync_model_thread_started = True
|
||||
self.sync_model_thread = threading.Thread(target=sync_model_thread)
|
||||
self.producer_weight_version = self.consumer_global_step
|
||||
self.sync_model_thread.start()
|
||||
torch.cuda.empty_cache()
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.wake_up()
|
||||
|
||||
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
||||
if (
|
||||
self.consumer_global_step - self.latest_eval_step >= self.eval_interval
|
||||
and self.consumer_global_step > self.latest_eval_step
|
||||
) or self.latest_eval_step == -1:
|
||||
to_log_msg = {}
|
||||
self.eval_mode = True
|
||||
for eval_task_name in self.eval_dataloaders:
|
||||
if self.producer_idx == 0:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
|
||||
for eval_batch in tqdm.tqdm(
|
||||
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||
):
|
||||
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
|
||||
eval_results = eval_results + [
|
||||
self.evaluation_function(
|
||||
eval_outputs["input_ids"][m][n],
|
||||
eval_outputs[
|
||||
(
|
||||
"test_cases"
|
||||
if self.grpo_config["reward_fn_type"] == "code"
|
||||
else "gt_answer"
|
||||
)
|
||||
][m],
|
||||
eval_outputs["response_idx"][m][n],
|
||||
tokenizer=self.tokenizer,
|
||||
eval_mode=True,
|
||||
tags=self.response_format_tags,
|
||||
)
|
||||
for m in range(eval_outputs["input_ids"].size(0))
|
||||
for n in range(eval_outputs["input_ids"].size(1))
|
||||
]
|
||||
eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
|
||||
eval_statistics_tensor[1] += len(eval_results)
|
||||
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_pg")
|
||||
to_log_msg[f"eval/{eval_task_name}"] = (
|
||||
eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
|
||||
)
|
||||
if self.producer_idx == 0:
|
||||
print(
|
||||
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
|
||||
)
|
||||
# save eval results
|
||||
safe_append_to_jsonl_file(
|
||||
os.path.join(
|
||||
self.eval_save_dir,
|
||||
f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
|
||||
),
|
||||
eval_results,
|
||||
)
|
||||
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
|
||||
self.eval_mode = False
|
||||
self.latest_eval_step = self.consumer_global_step
|
||||
self.profiler.enter("sleep")
|
||||
while not (ray.get(self.shared_sync_data_actor.pickup_rollout_task.remote(self.microbatch_size))):
|
||||
time.sleep(1)
|
||||
self.profiler.exit("sleep")
|
||||
self.profiler.enter("rollout")
|
||||
self.profiler.log(f"rollout batch {i} episode {episode}")
|
||||
# time.sleep(30) # simulate long inference time
|
||||
outputs = self.rollout(**batch)
|
||||
self.profiler.exit("rollout")
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
|
||||
self.profiler.enter("calculate_reward")
|
||||
if self.grpo_config["reward_fn_type"] == "code":
|
||||
test_cases = []
|
||||
for prompt_id in range(bs):
|
||||
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
|
||||
reward_model_output = self.reward_model(
|
||||
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
|
||||
test_cases=test_cases,
|
||||
response_idx=outputs["response_idx"].view((-1, 2)),
|
||||
)
|
||||
else:
|
||||
gt_answer = []
|
||||
for prompt_id in range(bs):
|
||||
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
|
||||
reward_model_output = self.reward_model(
|
||||
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
|
||||
gt_answer=gt_answer,
|
||||
response_idx=outputs["response_idx"].view((-1, 2)),
|
||||
)
|
||||
outputs["reward"] = (
|
||||
torch.tensor([value[0] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
outputs["format_acc"] = (
|
||||
torch.tensor([value[1] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
outputs["ans_acc"] = (
|
||||
torch.tensor([value[2] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
if "gt_answer" in outputs:
|
||||
outputs.pop("gt_answer")
|
||||
if "test_cases" in outputs:
|
||||
outputs.pop("test_cases")
|
||||
self.profiler.exit("calculate_reward")
|
||||
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
outputs = pre_send(outputs)
|
||||
outputs = {k: v.cpu() for k, v in outputs.items()}
|
||||
self.profiler.enter("send_data")
|
||||
|
||||
ray.get(self.shared_sync_data_actor.append_data.remote(outputs))
|
||||
self.profiler.exit("send_data")
|
||||
|
||||
if (i + 1) % self.num_microbatches == 0 and (
|
||||
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
|
||||
):
|
||||
if not self.sync_model_thread_started:
|
||||
# load state dict, note this should be done in the main thread to avoid race condition
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}:
|
||||
self.load_state_dict(self.state_dict_cpu[pp_idx])
|
||||
|
||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||
if episode <= 0:
|
||||
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]):
|
||||
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
|
||||
def __del__(self):
|
||||
self.profiler.close()
|
||||
|
||||
|
||||
@ray.remote
|
||||
class SimpleProducer(BaseProducer):
|
||||
def __init__(
|
||||
self,
|
||||
shared_sync_data_actor: SharedVariableActor,
|
||||
shared_signal_actor: SharedVariableActor,
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
tokenizer_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
consumer_plugin_config=None,
|
||||
eval_dataset_config=None,
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
eval_generation_config={},
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
shared_sync_data_actor,
|
||||
shared_signal_actor,
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
tokenizer_config,
|
||||
microbatch_size,
|
||||
backend,
|
||||
consumer_plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(eval_generation_config)
|
||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 0 and not self.eval_mode:
|
||||
if (
|
||||
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
|
||||
or self.latest_rollout_log_step == -1
|
||||
):
|
||||
new_record = (
|
||||
json.dumps(
|
||||
{
|
||||
"train_step": self.consumer_global_step,
|
||||
"rollout": self.tokenizer.batch_decode(
|
||||
rollouts["input_ids"][:, 0], skip_special_tokens=True
|
||||
),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
self.rollout_log_file.write(new_record)
|
||||
self.rollout_log_file.flush()
|
||||
self.latest_rollout_log_step = self.consumer_global_step
|
||||
return rollouts
|
||||
|
||||
def __del__(self):
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run.finish()
|
||||
if hasattr(self, "rollout_log_file"):
|
||||
self.rollout_log_file.close()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.model.load_state_dict(state_dict)
|
369
applications/ColossalChat/rl_example_zero_bubble.py
Normal file
369
applications/ColossalChat/rl_example_zero_bubble.py
Normal file
@ -0,0 +1,369 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.distributed.launch_zero_bubble import launch_distributed
|
||||
|
||||
DEFAUT_SYSTEM_PROMPT = {
|
||||
"think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n",
|
||||
"boxed": "Please reason step by step, and put your final answer within \\boxed{}.",
|
||||
"code": "You are a helpful assistant.",
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument(
|
||||
"-ed",
|
||||
"--eval-dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
|
||||
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
|
||||
The key is the task name, and the value is the path to the jsonl file",
|
||||
)
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
|
||||
|
||||
# Distributed training parameters
|
||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||
parser.add_argument(
|
||||
"-ibs",
|
||||
"--inference-batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-imbs",
|
||||
"--inference-microbatch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tbs",
|
||||
"--train-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tMbs",
|
||||
"--train-minibatch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tmbs",
|
||||
"--train-microbatch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pp",
|
||||
"--pipeline-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-zero",
|
||||
"--zero-stage",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
|
||||
)
|
||||
|
||||
# Sampling parameters
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
|
||||
parser.add_argument(
|
||||
"-topk",
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-topp",
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
|
||||
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
|
||||
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
||||
parser.add_argument(
|
||||
"-ptp",
|
||||
"--producer-tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
|
||||
# GRPO parameters
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
|
||||
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
|
||||
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
|
||||
parser.add_argument(
|
||||
"-rt",
|
||||
"--reward-type",
|
||||
type=str,
|
||||
default="think_answer_tags",
|
||||
choices=["think_answer_tags", "boxed", "code"],
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ei",
|
||||
"--eval-interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Interval for evaluation. Evaluate every ei training steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-cbsl",
|
||||
"--data_actor_buffer_size_limit",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="The approximate number of samples to keep in the consumer buffer. After this limit is reached, the producer will stop generating new samples and prioritize model sync until the consumer has processed some samples",
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
|
||||
parser.add_argument(
|
||||
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.train_minibatch_size is None:
|
||||
# Default settings: Using train batch size as mini batch size
|
||||
args.train_minibatch_size = args.train_batch_size
|
||||
if args.inference_batch_size is None:
|
||||
# Default settings: Using train batch size as inference batch size, sync every inference model every train step
|
||||
args.inference_batch_size = args.train_batch_size
|
||||
assert (
|
||||
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
|
||||
and args.train_microbatch_size > 0
|
||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||
assert (
|
||||
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
|
||||
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
|
||||
|
||||
if args.master_address is None:
|
||||
# Default settings: Using single machine
|
||||
ray.init(
|
||||
address="local",
|
||||
namespace="ray-example",
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
||||
ray.init(
|
||||
_node_ip_address=args.master_address,
|
||||
namespace="ray-example",
|
||||
_temp_dir=args.ray_dir,
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if args.top_k is None:
|
||||
if args.backend == "transformers":
|
||||
args.top_k = 50
|
||||
elif args.backend == "vllm":
|
||||
args.top_k = -1
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||
|
||||
if args.backend == "transformers":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_length=args.max_new_tokens + args.max_prompt_tokens,
|
||||
do_sample=True,
|
||||
max_new_tokens=None,
|
||||
early_stopping=False if args.reward_type == "think_answer_tags" else True,
|
||||
stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||
)
|
||||
)
|
||||
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.7,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||
tensor_parallel_size=args.producer_tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=args.max_new_tokens, # max new tokens
|
||||
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||
)
|
||||
)
|
||||
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
if args.algo == "GRPO":
|
||||
# Default Settings
|
||||
grpo_config = {
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
|
||||
"beta": args.kl_coeff, # KL penalty coefficient
|
||||
"loss_variation": "sample_level",
|
||||
"reward_fn_type": args.reward_type,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"response_format_tags": (
|
||||
{
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
if args.reward_type == "think_answer_tags"
|
||||
else None
|
||||
),
|
||||
}
|
||||
elif args.algo == "DAPO":
|
||||
# DAPO variant settings
|
||||
grpo_config = {
|
||||
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"dynamic_batching": True,
|
||||
"clip_eps_low": 0.2,
|
||||
"clip_eps_high": 0.28,
|
||||
"skip_threshold": 20.0,
|
||||
"beta": 0, # no KL penalty for DAPO
|
||||
"loss_variation": "token_level",
|
||||
"soft_over_length_punishment": True,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||
"filter_truncated_response": True,
|
||||
"reward_fn_type": args.reward_type,
|
||||
"response_format_tags": (
|
||||
{
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
if args.reward_type == "think_answer_tags"
|
||||
else None
|
||||
),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||
|
||||
if args.system_prompt is None:
|
||||
# Default system prompt
|
||||
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
|
||||
|
||||
launch_distributed(
|
||||
num_producers=args.num_inferencer,
|
||||
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
|
||||
num_consumer_procs=args.num_trainers,
|
||||
num_episodes=args.num_episodes,
|
||||
inference_batch_size=args.inference_batch_size,
|
||||
inference_microbatch_size=args.inference_microbatch_size,
|
||||
train_batch_size=args.train_batch_size,
|
||||
train_minibatch_size=args.train_minibatch_size,
|
||||
train_dataset_config={
|
||||
"path": args.dataset,
|
||||
"max_length": args.max_prompt_tokens,
|
||||
"system_prompt": args.system_prompt,
|
||||
},
|
||||
inference_model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
grpo_config=grpo_config,
|
||||
plugin_config={
|
||||
"tp_size": args.tensor_parallel_size,
|
||||
"pp_size": args.pipeline_parallel_size,
|
||||
"microbatch_size": max(
|
||||
1, args.train_microbatch_size // args.pipeline_parallel_size
|
||||
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
"zero_stage": args.zero_stage,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp, tp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=args.master_port,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
||||
eval_dataset_config=(
|
||||
{
|
||||
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
|
||||
for k, v in json.loads(args.eval_dataset).items()
|
||||
}
|
||||
if args.eval_dataset
|
||||
else None
|
||||
),
|
||||
eval_interval=args.eval_interval,
|
||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||
eval_generation_config=eval_generation_config,
|
||||
log_rollout_interval=20,
|
||||
rollout_save_dir=args.rollout_save_dir,
|
||||
enable_profiling=args.enable_profiling,
|
||||
data_actor_buffer_size_limit=args.data_actor_buffer_size_limit,
|
||||
)
|
Loading…
Reference in New Issue
Block a user