mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-23 02:06:35 +00:00
184 lines
7.3 KiB
Python
184 lines
7.3 KiB
Python
import copy
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, Optional
|
|
|
|
import ray
|
|
|
|
from .consumer import SimpleConsumer
|
|
from .grpo_consumer import GRPOConsumer
|
|
from .producer import SimpleProducer
|
|
|
|
# Maps an algorithm name to the consumer class that implements it.
# "GRPO" and "DAPO" both resolve to GRPOConsumer.
ALGO_MAP = {
    "Simple": SimpleConsumer,
    "GRPO": GRPOConsumer,
    "DAPO": GRPOConsumer,
}
|
|
|
|
|
|
def get_jsonl_size_fast(path: str) -> int:
    """Return the number of non-blank lines in the file at ``path``.

    A line counts when it still has content after stripping whitespace, so
    empty lines and whitespace-only lines are ignored.

    Args:
        path: Path of the text file to count (opened in text mode with the
            platform default encoding).

    Returns:
        The count of non-blank lines.
    """
    with open(path) as f:
        # Stream the file line by line instead of materialising every line
        # (and then a filtered copy) in memory just to take a length.
        return sum(1 for line in f if line.strip())
|
|
|
|
|
|
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    """Return ``n_procs`` floor-divided by the product of the parallel sizes.

    The sizes are read from ``plugin_config`` under the keys ``tp_size``,
    ``pp_size``, ``ep_size`` and ``sp_size``; a missing key counts as 1.
    Any other keys in ``plugin_config`` are ignored.

    Args:
        n_procs: Total number of processes to divide.
        plugin_config: Mapping that may contain the four size keys.

    Returns:
        ``n_procs // (tp_size * pp_size * ep_size * sp_size)``.
    """
    group_size = 1
    for size_key in ("tp_size", "pp_size", "ep_size", "sp_size"):
        group_size *= plugin_config.get(size_key, 1)
    return n_procs // group_size
|
|
|
|
|
|
def launch_distributed(
    num_producers: int,
    num_proc_per_producer: int,
    num_consumer_procs: int,
    num_episodes: int,
    inference_batch_size: int,
    inference_microbatch_size: int,
    train_batch_size: int,
    train_minibatch_size: int,
    train_dataset_config: Dict[str, Any],
    inference_model_config: Dict[str, Any],
    generate_config: Dict[str, Any],
    train_model_config: Dict[str, Any],
    grpo_config: Dict[str, Any],
    plugin_config: Dict[str, Any],
    tokenizer_config: Optional[Dict[str, Any]] = None,
    inference_backend: str = "transformers",
    num_generations: int = 8,
    master_addr: str = "localhost",
    master_port: int = 29500,
    core_algo: str = "GRPO",
    project_name: Optional[str] = None,
    save_interval: int = 100,
    save_dir: str = "./model",
    eval_dataset_config: Optional[Dict[str, Any]] = None,
    eval_interval: int = 100,
    eval_save_dir: Optional[str] = None,
    eval_generation_config: Optional[Dict[str, Any]] = None,
    log_rollout_interval: int = 20,
    rollout_save_dir: str = "./rollout",
    enable_profiling: bool = False,
    n_behind: int = 0,
):
    """Create the producer and consumer Ray actors and run them to completion.

    Steps, in order: pick the consumer class from ``ALGO_MAP``; derive the
    per-episode update counts from the training dataset size; enumerate the
    GPUs reported by ``ray.nodes()``; create ``num_producers`` producer actors
    and call ``setup`` on them; create ``num_consumer_procs`` consumer actors
    and call ``setup`` on them; finally block on ``loop`` of every actor.

    Args:
        num_producers: Number of ``SimpleProducer`` actors to create.
        num_proc_per_producer: Number of GPUs requested for each producer.
        num_consumer_procs: Number of consumer actors (one GPU each); also
            passed to consumers as ``world_size``.
        num_episodes: Forwarded to producers and consumers.
        inference_batch_size: Per-producer batch size, forwarded to producers.
        inference_microbatch_size: Forwarded to producers; also divides
            ``inference_batch_size`` to obtain ``num_recv_per_update``.
        train_batch_size: Forwarded to consumers as ``batch_size``.
        train_minibatch_size: Forwarded to consumers as ``minibatch_size``.
        train_dataset_config: Forwarded to producers; its ``"path"`` entry is
            read here to count the dataset's non-blank lines.
        inference_model_config: Forwarded to producers as ``model_config``.
        generate_config: Forwarded to producers. Its ``"temperature"`` and
            ``"top_p"`` entries are read to build the run name. Consumers get
            a deep copy with ``backend`` set to ``inference_backend``.
        train_model_config: Forwarded to consumers as ``model_config``.
        grpo_config: Forwarded to producers and consumers.
        plugin_config: Forwarded to consumers (and to producers as
            ``consumer_plugin_config``); also used to compute the training
            data-parallel size.
        tokenizer_config: Forwarded to producers.
        inference_backend: Forwarded to producers as ``backend``; also part of
            the run name.
        num_generations: Forwarded to producers and consumers.
        master_addr: Not read by this function. The address handed to the
            consumers is the address of the first GPU entry left over after
            the producers have taken theirs.
        master_port: Forwarded to consumers.
        core_algo: Key into ``ALGO_MAP`` selecting the consumer class.
        project_name: Forwarded to producers and consumers. When given, it is
            also used (spaces replaced by ``_``) as the prefix of the rollout
            log file name; when ``None`` the prefix is omitted.
        save_interval: Forwarded to consumers.
        save_dir: Forwarded to consumers.
        eval_dataset_config: Forwarded to producers.
        eval_interval: Forwarded to producers.
        eval_save_dir: Forwarded to producers.
        eval_generation_config: Forwarded to producers.
        log_rollout_interval: Forwarded to producers.
        rollout_save_dir: Directory component of the rollout log file path.
        enable_profiling: Forwarded to producers and consumers.
        n_behind: Forwarded to producers and consumers.

    Raises:
        NotImplementedError: If ``core_algo`` is not a key of ``ALGO_MAP``.
        AssertionError: If ``inference_batch_size * num_producers`` is not a
            multiple of ``train_batch_size`` times the training
            data-parallel size.
        IndexError: If ``ray.nodes()`` lists fewer GPUs than the producers and
            consumers take from the list.
    """
    if core_algo not in ALGO_MAP:
        raise NotImplementedError(f"{core_algo} is not supported yet.")
    core_consumer = ALGO_MAP[core_algo]

    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

    dataset_path = train_dataset_config["path"]
    num_samples = get_jsonl_size_fast(dataset_path)
    global_inference_batch_size = inference_batch_size * num_producers
    num_update_per_episode = num_samples // global_inference_batch_size
    num_recv_per_update = inference_batch_size // inference_microbatch_size

    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
    wandb_group_name = str(uuid.uuid4())
    # project_name defaults to None, so only build the "<project>_" prefix
    # when a name was actually supplied.
    rollout_file_prefix = f"{project_name.replace(' ', '_')}_" if project_name is not None else ""
    rollout_log_file = os.path.join(
        rollout_save_dir,
        f"{rollout_file_prefix}run_{wandb_group_name}.jsonl",
    )

    # Attention: Ray uses a complex scheduling method that considers various
    # factors, including load balancing. When requesting resources, it is not
    # guaranteed that the resource comes from a node with a lower node id.
    # The design here assigns producers to the GPUs listed first and consumers
    # to the GPUs listed after them. See
    # https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
    # NOTE(review): the actors below are requested with ``num_gpus`` only, so
    # the addresses printed and used as the consumer master address describe
    # the intended placement, not one enforced through Ray. The node ids
    # remain available as the keys of ``node_info`` if node affinity needs to
    # be wired in - confirm whether that is required.
    nodes = ray.nodes()
    node_info = {
        node["NodeID"]: {
            "num_gpus": node["Resources"].get("GPU", 0),  # Default to 0 if no GPUs are available
            "address": node["NodeManagerAddress"],
        }
        for node in nodes
    }
    # One entry per GPU, in node_info order, holding that GPU's node address.
    gpu_to_ip_address = []
    for info in node_info.values():
        gpu_to_ip_address.extend([info["address"]] * int(info["num_gpus"]))
    print(node_info)

    producer_procs = []
    for i in range(num_producers):
        producer_ip_address = gpu_to_ip_address[0]
        # Consume one list entry per GPU this producer requests.
        for _ in range(num_proc_per_producer):
            gpu_to_ip_address.pop(0)
        print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
            producer_idx=i,
            num_producers=num_producers,
            num_consumer_procs=num_consumer_procs,
            num_episodes=num_episodes,
            batch_size=inference_batch_size,
            train_dataset_config=train_dataset_config,
            model_config=inference_model_config,
            generate_config=generate_config,
            tokenizer_config=tokenizer_config,
            microbatch_size=inference_microbatch_size,
            backend=inference_backend,
            num_generations=num_generations,
            consumer_plugin_config=plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
            grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            eval_generation_config=eval_generation_config,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
        )
        producer_procs.append(producer)
    ray.get([p.setup.remote() for p in producer_procs])

    # Consumers receive their own copy of the generation config so that adding
    # the backend does not modify the dict already handed to the producers.
    generate_config_consumer = copy.deepcopy(generate_config)
    generate_config_consumer.update(
        dict(
            backend=inference_backend,
        )
    )

    # The first GPU entry left after the producers took theirs provides the
    # master address given to every consumer.
    consumer_master_ip_address = gpu_to_ip_address[0]
    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
    consumer_procs = []
    for i in range(num_consumer_procs):
        consumer_ip_address = gpu_to_ip_address.pop(0)
        print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
        consumer = core_consumer.options(num_gpus=1).remote(
            num_producers=num_producers,
            num_episodes=num_episodes,
            rank=i,
            world_size=num_consumer_procs,
            master_addr=consumer_master_ip_address,
            master_port=master_port,
            num_update_per_episode=num_update_per_episode,
            num_recv_per_update=num_recv_per_update,
            batch_size=train_batch_size,
            model_config=train_model_config,
            plugin_config=plugin_config,
            minibatch_size=train_minibatch_size,
            generate_config=generate_config_consumer,
            grpo_config=grpo_config,
            num_generations=num_generations,
            save_interval=save_interval,
            save_dir=save_dir,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
        )
        consumer_procs.append(consumer)
    ray.get([p.setup.remote() for p in consumer_procs])
    # Block until every producer and consumer loop has finished.
    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
|