fix schedualing for multi-node training

This commit is contained in:
YeAnbang 2025-05-02 19:45:07 +08:00
parent d06042b434
commit 7d658402da
7 changed files with 124 additions and 38 deletions

View File

@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
if system_prompt is None:
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
system_element = {
"role": "system",
"content": system_prompt,
@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
msg_tokens = tokenizer.apply_chat_template(
[system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]

View File

@ -115,10 +115,10 @@ class BaseConsumer:
eval_statistics = None
eval_global_step = None
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
assert "consumer_global_step" in local_eval_result
eval_global_step = local_eval_result.pop("consumer_global_step").item()
if eval_statistics is None:

View File

@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
response_format_tags: dict = None,
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
@ -377,7 +372,7 @@ class GRPOConsumer(BaseConsumer):
kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
kl.append(0.0)
kl.append(torch.zeros(1, device=action_log_probs.device))
loss, _ = self.policy_loss_fn(
action_log_probs,

View File

@ -53,6 +53,7 @@ def launch_distributed(
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
response_format_tags: Dict[Any] = None,
):
if core_algo not in ALGO_MAP:
@ -65,13 +66,46 @@ def launch_distributed(
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
procs = []
# Attention: Ray use complex schedualing method that consider various factors including load-balancing.
# when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
# this go against the design principle of our implementation, and we need to manually force the schedualing,
# allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
# node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
"num_gpus": node["Resources"].get("GPU", 0),
"address": node["NodeManagerAddress"],
} # Default to 0 if no GPUs are available
for node in nodes
}
gpu_to_node_id = []
gpu_to_ip_address = []
for node_id in node_info:
for idx in range(int(node_info[node_id]["num_gpus"])):
gpu_to_node_id.append(node_id)
gpu_to_ip_address.append(node_info[node_id]["address"])
print(node_info)
producer_procs = []
for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
for _ in range(num_proc_per_producer):
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
producer = SimpleProducer.options(
num_gpus=num_proc_per_producer,
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@ -91,20 +125,35 @@ def launch_distributed(
evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir,
)
procs.append(producer)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote(
node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
consumer = core_consumer.options(
num_gpus=1,
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
master_addr=master_addr,
master_addr=consumer_master_ip_address,
master_port=master_port,
num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update,
@ -119,7 +168,8 @@ def launch_distributed(
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
response_format_tags=response_format_tags,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
ray.get([p.loop.remote() for p in procs])
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])

View File

@ -112,6 +112,8 @@ class BaseProducer:
drop_last=False,
seed=42,
),
num_workers=4,
drop_last=False,
)
if evaluation_function_type == "think_answer_tags":
self.evaluation_function = math_reward_fn
@ -166,8 +168,8 @@ class BaseProducer:
)
eval_results = []
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
for eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
for eval_batch_id, eval_batch in tqdm.tqdm(
enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [
@ -190,9 +192,7 @@ class BaseProducer:
self.eval_save_dir,
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
)
# delete the file if it exists
safe_write_jsonl(result_file_name, eval_results)
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
eval_statistics["consumer_global_step"] = torch.tensor(
[self.consumer_global_step], device=self.device
)
@ -230,6 +230,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(
@ -308,7 +310,10 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
print(
"Truncated rollout example:\n",
self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
)
return rollouts

View File

@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
"""
Verify if the completion is a valid math representation of the gt_answer.
"""
if not completion.startswith("\\boxed{"):
completion = "\\boxed{" + completion + "}"
if not gt_answer.startswith("\\boxed{"):
gt_answer = "\\boxed{" + gt_answer + "}"
target = (
ExprExtractionConfig(),
LatexExtractionConfig(
@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
# plain text answer cannot be parsed, but is correct
reward += acc_score
else:
elif math_verify_result == CANNOT_PARSE_PREDICTION:
reward += (
acc_score / 2
) # not a valid latex math representation, but the answer is correct, receive half of the score
@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
if "tags" in kwargs:
tags = kwargs["tags"]
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
# Check format accuracy
if format_valid:
format_acc += 1

View File

@ -6,6 +6,13 @@ import ray
import torch
from coati.distributed.launch import launch_distributed
DEFAULT_RESPONSE_FORMAT_TAGS = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@ -65,10 +72,22 @@ if __name__ == "__main__":
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
"--master_address",
type=str,
default=None,
help="Ray master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
"--torch_ddp_master_address",
type=str,
default=None,
help="Torch DDP master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--torch_ddp_master_port",
type=int,
default=29505,
help="Torch DDP master port for multi-node distributed training, Optional",
)
# Sampling parameters
@ -105,6 +124,9 @@ if __name__ == "__main__":
help="Reward type for GRPO.",
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
parser.add_argument(
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@ -236,17 +258,17 @@ if __name__ == "__main__":
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 1,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# "tp_size": 4,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 1,
# "max_norm": 1.0,
# }, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
master_addr=args.torch_ddp_master_address,
master_port=args.torch_ddp_master_port,
core_algo=args.algo,
project_name=args.project,
save_interval=args.save_interval,
@ -257,4 +279,9 @@ if __name__ == "__main__":
},
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
response_format_tags=(
json.loads(args.response_format_tags)
if args.response_format_tags is not None
else DEFAULT_RESPONSE_FORMAT_TAGS
),
)