mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-16 08:27:44 +00:00
fix schedualing for multi-node training
This commit is contained in:
parent
d06042b434
commit
7d658402da
@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
|
||||
|
||||
system_element = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
|
||||
tokens = []
|
||||
assistant_mask = []
|
||||
for i, msg in enumerate(chat):
|
||||
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
|
||||
msg_tokens = tokenizer.apply_chat_template(
|
||||
[system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# remove unexpected bos token
|
||||
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
|
||||
msg_tokens = msg_tokens[1:]
|
||||
|
@ -115,10 +115,10 @@ class BaseConsumer:
|
||||
eval_statistics = None
|
||||
eval_global_step = None
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||
local_eval_result = ray_broadcast_tensor_dict(
|
||||
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
||||
)
|
||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||
assert "consumer_global_step" in local_eval_result
|
||||
eval_global_step = local_eval_result.pop("consumer_global_step").item()
|
||||
if eval_statistics is None:
|
||||
|
@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
eval_interval: int = -1,
|
||||
response_format_tags: dict = None,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||
)
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
reward_model_kwargs = {
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
}
|
||||
@ -377,7 +372,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl.append(appox_kl.mean())
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl.append(0.0)
|
||||
kl.append(torch.zeros(1, device=action_log_probs.device))
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
|
@ -53,6 +53,7 @@ def launch_distributed(
|
||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||
eval_interval: int = 100,
|
||||
eval_save_dir: Optional[str] = None,
|
||||
response_format_tags: Dict[Any] = None,
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
@ -65,13 +66,46 @@ def launch_distributed(
|
||||
|
||||
dataset_path = train_dataset_config["path"]
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
|
||||
procs = []
|
||||
# Attention: Ray use complex schedualing method that consider various factors including load-balancing.
|
||||
# when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
|
||||
# this go against the design principle of our implementation, and we need to manually force the schedualing,
|
||||
# allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
|
||||
# node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
|
||||
nodes = ray.nodes()
|
||||
node_info = {
|
||||
node["NodeID"]: {
|
||||
"num_gpus": node["Resources"].get("GPU", 0),
|
||||
"address": node["NodeManagerAddress"],
|
||||
} # Default to 0 if no GPUs are available
|
||||
for node in nodes
|
||||
}
|
||||
gpu_to_node_id = []
|
||||
gpu_to_ip_address = []
|
||||
for node_id in node_info:
|
||||
for idx in range(int(node_info[node_id]["num_gpus"])):
|
||||
gpu_to_node_id.append(node_id)
|
||||
gpu_to_ip_address.append(node_info[node_id]["address"])
|
||||
print(node_info)
|
||||
|
||||
producer_procs = []
|
||||
for i in range(num_producers):
|
||||
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
|
||||
node_id = gpu_to_node_id[0]
|
||||
producer_ip_address = gpu_to_ip_address[0]
|
||||
for _ in range(num_proc_per_producer):
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
|
||||
producer = SimpleProducer.options(
|
||||
num_gpus=num_proc_per_producer,
|
||||
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
|
||||
node_id=node_id,
|
||||
soft=False,
|
||||
),
|
||||
).remote(
|
||||
producer_idx=i,
|
||||
num_producers=num_producers,
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
@ -91,20 +125,35 @@ def launch_distributed(
|
||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||
eval_save_dir=eval_save_dir,
|
||||
)
|
||||
procs.append(producer)
|
||||
producer_procs.append(producer)
|
||||
ray.get([p.setup.remote() for p in producer_procs])
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
generate_config_consumer.update(
|
||||
dict(
|
||||
backend=inference_backend,
|
||||
)
|
||||
)
|
||||
consumer_master_ip_address = gpu_to_ip_address[0]
|
||||
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
|
||||
consumer_procs = []
|
||||
for i in range(num_consumer_procs):
|
||||
consumer = core_consumer.options(num_gpus=1).remote(
|
||||
node_id = gpu_to_node_id[0]
|
||||
consumer_ip_address = gpu_to_ip_address[0]
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
|
||||
consumer = core_consumer.options(
|
||||
num_gpus=1,
|
||||
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
|
||||
node_id=node_id,
|
||||
soft=False,
|
||||
),
|
||||
).remote(
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
world_size=num_consumer_procs,
|
||||
master_addr=master_addr,
|
||||
master_addr=consumer_master_ip_address,
|
||||
master_port=master_port,
|
||||
num_update_per_episode=num_update_per_episode,
|
||||
num_recv_per_update=num_recv_per_update,
|
||||
@ -119,7 +168,8 @@ def launch_distributed(
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
eval_interval=eval_interval,
|
||||
response_format_tags=response_format_tags,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
ray.get([p.loop.remote() for p in procs])
|
||||
consumer_procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in consumer_procs])
|
||||
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
|
||||
|
@ -112,6 +112,8 @@ class BaseProducer:
|
||||
drop_last=False,
|
||||
seed=42,
|
||||
),
|
||||
num_workers=4,
|
||||
drop_last=False,
|
||||
)
|
||||
if evaluation_function_type == "think_answer_tags":
|
||||
self.evaluation_function = math_reward_fn
|
||||
@ -166,8 +168,8 @@ class BaseProducer:
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
|
||||
for eval_batch in tqdm.tqdm(
|
||||
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||
for eval_batch_id, eval_batch in tqdm.tqdm(
|
||||
enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
|
||||
):
|
||||
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
|
||||
eval_results = eval_results + [
|
||||
@ -190,9 +192,7 @@ class BaseProducer:
|
||||
self.eval_save_dir,
|
||||
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
||||
)
|
||||
# delete the file if it exists
|
||||
safe_write_jsonl(result_file_name, eval_results)
|
||||
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
|
||||
eval_statistics["consumer_global_step"] = torch.tensor(
|
||||
[self.consumer_global_step], device=self.device
|
||||
)
|
||||
@ -230,6 +230,8 @@ class BaseProducer:
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
print(
|
||||
@ -308,7 +310,10 @@ class SimpleProducer(BaseProducer):
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 1:
|
||||
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||
print(
|
||||
"Truncated rollout example:\n",
|
||||
self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
|
||||
)
|
||||
|
||||
return rollouts
|
||||
|
||||
|
@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
|
||||
"""
|
||||
Verify if the completion is a valid math representation of the gt_answer.
|
||||
"""
|
||||
if not completion.startswith("\\boxed{"):
|
||||
completion = "\\boxed{" + completion + "}"
|
||||
if not gt_answer.startswith("\\boxed{"):
|
||||
gt_answer = "\\boxed{" + gt_answer + "}"
|
||||
target = (
|
||||
ExprExtractionConfig(),
|
||||
LatexExtractionConfig(
|
||||
@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
|
||||
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
|
||||
# plain text answer cannot be parsed, but is correct
|
||||
reward += acc_score
|
||||
else:
|
||||
elif math_verify_result == CANNOT_PARSE_PREDICTION:
|
||||
reward += (
|
||||
acc_score / 2
|
||||
) # not a valid latex math representation, but the answer is correct, receive half of the score
|
||||
@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
if "tags" in kwargs:
|
||||
tags = kwargs["tags"]
|
||||
format_valid = format_valid and all(
|
||||
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
|
||||
)
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
|
@ -6,6 +6,13 @@ import ray
|
||||
import torch
|
||||
from coati.distributed.launch import launch_distributed
|
||||
|
||||
DEFAULT_RESPONSE_FORMAT_TAGS = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
@ -65,10 +72,22 @@ if __name__ == "__main__":
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
"--master_address",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Ray master address for multi-node distributed training, Optional",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
|
||||
"--torch_ddp_master_address",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Torch DDP master address for multi-node distributed training, Optional",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch_ddp_master_port",
|
||||
type=int,
|
||||
default=29505,
|
||||
help="Torch DDP master port for multi-node distributed training, Optional",
|
||||
)
|
||||
|
||||
# Sampling parameters
|
||||
@ -105,6 +124,9 @@ if __name__ == "__main__":
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||
parser.add_argument(
|
||||
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
@ -236,17 +258,17 @@ if __name__ == "__main__":
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# plugin_config={
|
||||
# "tp_size": 1,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(
|
||||
# 1, args.train_microbatch_size // 2
|
||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
# "zero_stage": 0,
|
||||
# "max_norm": 1.0,
|
||||
# "tp_size": 4,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(
|
||||
# 1, args.train_microbatch_size // 2
|
||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
# "zero_stage": 1,
|
||||
# "max_norm": 1.0,
|
||||
# }, # for pp, tp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=args.master_port,
|
||||
master_addr=args.torch_ddp_master_address,
|
||||
master_port=args.torch_ddp_master_port,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
save_interval=args.save_interval,
|
||||
@ -257,4 +279,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
eval_interval=args.eval_interval,
|
||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||
response_format_tags=(
|
||||
json.loads(args.response_format_tags)
|
||||
if args.response_format_tags is not None
|
||||
else DEFAULT_RESPONSE_FORMAT_TAGS
|
||||
),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user