fix schedualing for multi-node training

This commit is contained in:
YeAnbang 2025-05-02 19:45:07 +08:00
parent d06042b434
commit 7d658402da
7 changed files with 124 additions and 38 deletions

View File

@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
ignore_idx: int = -100, ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
if system_prompt is None:
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
system_element = { system_element = {
"role": "system", "role": "system",
"content": system_prompt, "content": system_prompt,
@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
tokens = [] tokens = []
assistant_mask = [] assistant_mask = []
for i, msg in enumerate(chat): for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True) msg_tokens = tokenizer.apply_chat_template(
[system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
)
# remove unexpected bos token # remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:] msg_tokens = msg_tokens[1:]

View File

@ -115,10 +115,10 @@ class BaseConsumer:
eval_statistics = None eval_statistics = None
eval_global_step = None eval_global_step = None
for r in range(self.num_producers): for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict( local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}" None, src=0, device=self.device, group_name=f"sync_data_{r}"
) )
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
assert "consumer_global_step" in local_eval_result assert "consumer_global_step" in local_eval_result
eval_global_step = local_eval_result.pop("consumer_global_step").item() eval_global_step = local_eval_result.pop("consumer_global_step").item()
if eval_statistics is None: if eval_statistics is None:

View File

@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
save_interval: int = 100, save_interval: int = 100,
save_dir="./model", save_dir="./model",
eval_interval: int = -1, eval_interval: int = -1,
response_format_tags: dict = None,
): ):
print(f"Using GRPO config: {grpo_config}") print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level": if grpo_config.get("loss_variation", "sample_level") == "token_level":
@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
) )
# Initialize verifiable reward. # Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
reward_model_kwargs = { reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"] k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
} }
@ -377,7 +372,7 @@ class GRPOConsumer(BaseConsumer):
kl.append(appox_kl.mean()) kl.append(appox_kl.mean())
else: else:
per_token_kl = 0.0 per_token_kl = 0.0
kl.append(0.0) kl.append(torch.zeros(1, device=action_log_probs.device))
loss, _ = self.policy_loss_fn( loss, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,

View File

@ -53,6 +53,7 @@ def launch_distributed(
eval_dataset_config: Optional[Dict[str, Any]] = None, eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100, eval_interval: int = 100,
eval_save_dir: Optional[str] = None, eval_save_dir: Optional[str] = None,
response_format_tags: Dict[Any] = None,
): ):
if core_algo not in ALGO_MAP: if core_algo not in ALGO_MAP:
@ -65,13 +66,46 @@ def launch_distributed(
dataset_path = train_dataset_config["path"] dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path) num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size num_recv_per_update = inference_batch_size // inference_microbatch_size
procs = [] # Attention: Ray use complex schedualing method that consider various factors including load-balancing.
# when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
# this go against the design principle of our implementation, and we need to manually force the schedualing,
# allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
# node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
"num_gpus": node["Resources"].get("GPU", 0),
"address": node["NodeManagerAddress"],
} # Default to 0 if no GPUs are available
for node in nodes
}
gpu_to_node_id = []
gpu_to_ip_address = []
for node_id in node_info:
for idx in range(int(node_info[node_id]["num_gpus"])):
gpu_to_node_id.append(node_id)
gpu_to_ip_address.append(node_info[node_id]["address"])
print(node_info)
producer_procs = []
for i in range(num_producers): for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote( node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
for _ in range(num_proc_per_producer):
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
producer = SimpleProducer.options(
num_gpus=num_proc_per_producer,
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
producer_idx=i, producer_idx=i,
num_producers=num_producers, num_producers=num_producers,
num_consumer_procs=num_consumer_procs, num_consumer_procs=num_consumer_procs,
@ -91,20 +125,35 @@ def launch_distributed(
evaluation_function_type=grpo_config["reward_fn_type"], evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir, eval_save_dir=eval_save_dir,
) )
procs.append(producer) producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update( generate_config_consumer.update(
dict( dict(
backend=inference_backend, backend=inference_backend,
) )
) )
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
for i in range(num_consumer_procs): for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote( node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
consumer = core_consumer.options(
num_gpus=1,
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
num_producers=num_producers, num_producers=num_producers,
num_episodes=num_episodes, num_episodes=num_episodes,
rank=i, rank=i,
world_size=num_consumer_procs, world_size=num_consumer_procs,
master_addr=master_addr, master_addr=consumer_master_ip_address,
master_port=master_port, master_port=master_port,
num_update_per_episode=num_update_per_episode, num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update, num_recv_per_update=num_recv_per_update,
@ -119,7 +168,8 @@ def launch_distributed(
save_interval=save_interval, save_interval=save_interval,
save_dir=save_dir, save_dir=save_dir,
eval_interval=eval_interval, eval_interval=eval_interval,
response_format_tags=response_format_tags,
) )
procs.append(consumer) consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in procs]) ray.get([p.setup.remote() for p in consumer_procs])
ray.get([p.loop.remote() for p in procs]) ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])

View File

@ -112,6 +112,8 @@ class BaseProducer:
drop_last=False, drop_last=False,
seed=42, seed=42,
), ),
num_workers=4,
drop_last=False,
) )
if evaluation_function_type == "think_answer_tags": if evaluation_function_type == "think_answer_tags":
self.evaluation_function = math_reward_fn self.evaluation_function = math_reward_fn
@ -166,8 +168,8 @@ class BaseProducer:
) )
eval_results = [] eval_results = []
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device) eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
for eval_batch in tqdm.tqdm( for eval_batch_id, eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0 enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
): ):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params) eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [ eval_results = eval_results + [
@ -190,9 +192,7 @@ class BaseProducer:
self.eval_save_dir, self.eval_save_dir,
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl", f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
) )
# delete the file if it exists
safe_write_jsonl(result_file_name, eval_results) safe_write_jsonl(result_file_name, eval_results)
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
eval_statistics["consumer_global_step"] = torch.tensor( eval_statistics["consumer_global_step"] = torch.tensor(
[self.consumer_global_step], device=self.device [self.consumer_global_step], device=self.device
) )
@ -230,6 +230,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict( state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
) )
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
else: else:
print( print(
@ -308,7 +310,10 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs): def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs) rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1: if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) print(
"Truncated rollout example:\n",
self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
)
return rollouts return rollouts

View File

@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
""" """
Verify if the completion is a valid math representation of the gt_answer. Verify if the completion is a valid math representation of the gt_answer.
""" """
if not completion.startswith("\\boxed{"):
completion = "\\boxed{" + completion + "}"
if not gt_answer.startswith("\\boxed{"):
gt_answer = "\\boxed{" + gt_answer + "}"
target = ( target = (
ExprExtractionConfig(), ExprExtractionConfig(),
LatexExtractionConfig( LatexExtractionConfig(
@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
if math_verify_result == CANNOT_PARSE_GT_ANSWER: if math_verify_result == CANNOT_PARSE_GT_ANSWER:
# plain text answer cannot be parsed, but is correct # plain text answer cannot be parsed, but is correct
reward += acc_score reward += acc_score
else: elif math_verify_result == CANNOT_PARSE_PREDICTION:
reward += ( reward += (
acc_score / 2 acc_score / 2
) # not a valid latex math representation, but the answer is correct, receive half of the score ) # not a valid latex math representation, but the answer is correct, receive half of the score
@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer) final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None format_valid = final_answer is not None
if "tags" in kwargs:
tags = kwargs["tags"]
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
# Check format accuracy # Check format accuracy
if format_valid: if format_valid:
format_acc += 1 format_acc += 1

View File

@ -6,6 +6,13 @@ import ray
import torch import torch
from coati.distributed.launch import launch_distributed from coati.distributed.launch import launch_distributed
DEFAULT_RESPONSE_FORMAT_TAGS = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@ -65,10 +72,22 @@ if __name__ == "__main__":
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional" "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
) )
parser.add_argument( parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" "--master_address",
type=str,
default=None,
help="Ray master address for multi-node distributed training, Optional",
) )
parser.add_argument( parser.add_argument(
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional" "--torch_ddp_master_address",
type=str,
default=None,
help="Torch DDP master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--torch_ddp_master_port",
type=int,
default=29505,
help="Torch DDP master port for multi-node distributed training, Optional",
) )
# Sampling parameters # Sampling parameters
@ -105,6 +124,9 @@ if __name__ == "__main__":
help="Reward type for GRPO.", help="Reward type for GRPO.",
) )
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.") parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
parser.add_argument(
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
)
# Logging/Checkpointing parameters # Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@ -236,17 +258,17 @@ if __name__ == "__main__":
"zero_stage": 2, "zero_stage": 2,
}, # for zero }, # for zero
# plugin_config={ # plugin_config={
# "tp_size": 1, # "tp_size": 4,
# "pp_size": 2, # "pp_size": 2,
# "microbatch_size": max( # "microbatch_size": max(
# 1, args.train_microbatch_size // 2 # 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size # ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0, # "zero_stage": 1,
# "max_norm": 1.0, # "max_norm": 1.0,
# }, # for pp, tp # }, # for pp, tp
inference_backend=args.backend, inference_backend=args.backend,
master_addr="localhost", master_addr=args.torch_ddp_master_address,
master_port=args.master_port, master_port=args.torch_ddp_master_port,
core_algo=args.algo, core_algo=args.algo,
project_name=args.project, project_name=args.project,
save_interval=args.save_interval, save_interval=args.save_interval,
@ -257,4 +279,9 @@ if __name__ == "__main__":
}, },
eval_interval=args.eval_interval, eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
response_format_tags=(
json.loads(args.response_format_tags)
if args.response_format_tags is not None
else DEFAULT_RESPONSE_FORMAT_TAGS
),
) )