mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-31 07:18:59 +00:00
Merge d4a6b6c4a7
into 17928ad84f
This commit is contained in:
commit
d369d0d1ea
@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
|
||||
|
||||
system_element = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
|
||||
tokens = []
|
||||
assistant_mask = []
|
||||
for i, msg in enumerate(chat):
|
||||
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
|
||||
msg_tokens = tokenizer.apply_chat_template(
|
||||
[system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# remove unexpected bos token
|
||||
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
|
||||
msg_tokens = msg_tokens[1:]
|
||||
|
@ -115,10 +115,10 @@ class BaseConsumer:
|
||||
eval_statistics = None
|
||||
eval_global_step = None
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||
local_eval_result = ray_broadcast_tensor_dict(
|
||||
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
||||
)
|
||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||
assert "consumer_global_step" in local_eval_result
|
||||
eval_global_step = local_eval_result.pop("consumer_global_step").item()
|
||||
if eval_statistics is None:
|
||||
@ -128,9 +128,8 @@ class BaseConsumer:
|
||||
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
|
||||
}
|
||||
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
|
||||
if dist.get_rank() == 0:
|
||||
if hasattr(self, "wandb_run"):
|
||||
self.wandb_run.log(eval_statistics, step=eval_global_step)
|
||||
if hasattr(self, "wandb_run"):
|
||||
self.wandb_run.log(eval_statistics, step=eval_global_step)
|
||||
print(f"Eval statistics: {eval_statistics}")
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
|
@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
eval_interval: int = -1,
|
||||
response_format_tags: dict = None,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||
)
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
reward_model_kwargs = {
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
}
|
||||
@ -154,9 +149,13 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
if self.use_wandb and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
if (
|
||||
self.use_wandb
|
||||
and self.dp_rank == 0
|
||||
and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
)
|
||||
):
|
||||
# Initialize wandb.
|
||||
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
||||
@ -377,7 +376,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl.append(appox_kl.mean())
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl.append(0.0)
|
||||
kl.append(torch.zeros(1, device=action_log_probs.device))
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
@ -487,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
if self.dp_rank == 0 and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
and self.tp_rank == 0
|
||||
)
|
||||
):
|
||||
to_log_msg = [
|
||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||
|
@ -44,7 +44,6 @@ def launch_distributed(
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
num_generations: int = 8,
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
project_name: Optional[str] = None,
|
||||
@ -53,6 +52,7 @@ def launch_distributed(
|
||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||
eval_interval: int = 100,
|
||||
eval_save_dir: Optional[str] = None,
|
||||
response_format_tags: Dict[str, Any] = None,
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
@ -65,13 +65,46 @@ def launch_distributed(
|
||||
|
||||
dataset_path = train_dataset_config["path"]
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
|
||||
procs = []
|
||||
# Attention: Ray use complex schedualing method that consider various factors including load-balancing.
|
||||
# when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
|
||||
# this go against the design principle of our implementation, and we need to manually force the schedualing,
|
||||
# allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
|
||||
# node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
|
||||
nodes = ray.nodes()
|
||||
node_info = {
|
||||
node["NodeID"]: {
|
||||
"num_gpus": node["Resources"].get("GPU", 0),
|
||||
"address": node["NodeManagerAddress"],
|
||||
} # Default to 0 if no GPUs are available
|
||||
for node in nodes
|
||||
}
|
||||
gpu_to_node_id = []
|
||||
gpu_to_ip_address = []
|
||||
for node_id in node_info:
|
||||
for idx in range(int(node_info[node_id]["num_gpus"])):
|
||||
gpu_to_node_id.append(node_id)
|
||||
gpu_to_ip_address.append(node_info[node_id]["address"])
|
||||
print(node_info)
|
||||
|
||||
producer_procs = []
|
||||
for i in range(num_producers):
|
||||
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
|
||||
node_id = gpu_to_node_id[0]
|
||||
producer_ip_address = gpu_to_ip_address[0]
|
||||
for _ in range(num_proc_per_producer):
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
|
||||
producer = SimpleProducer.options(
|
||||
num_gpus=num_proc_per_producer,
|
||||
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
|
||||
node_id=node_id,
|
||||
soft=False,
|
||||
),
|
||||
).remote(
|
||||
producer_idx=i,
|
||||
num_producers=num_producers,
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
@ -91,20 +124,35 @@ def launch_distributed(
|
||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||
eval_save_dir=eval_save_dir,
|
||||
)
|
||||
procs.append(producer)
|
||||
producer_procs.append(producer)
|
||||
ray.get([p.setup.remote() for p in producer_procs])
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
generate_config_consumer.update(
|
||||
dict(
|
||||
backend=inference_backend,
|
||||
)
|
||||
)
|
||||
consumer_master_ip_address = gpu_to_ip_address[0]
|
||||
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
|
||||
consumer_procs = []
|
||||
for i in range(num_consumer_procs):
|
||||
consumer = core_consumer.options(num_gpus=1).remote(
|
||||
node_id = gpu_to_node_id[0]
|
||||
consumer_ip_address = gpu_to_ip_address[0]
|
||||
gpu_to_node_id.pop(0)
|
||||
gpu_to_ip_address.pop(0)
|
||||
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
|
||||
consumer = core_consumer.options(
|
||||
num_gpus=1,
|
||||
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
|
||||
node_id=node_id,
|
||||
soft=False,
|
||||
),
|
||||
).remote(
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
world_size=num_consumer_procs,
|
||||
master_addr=master_addr,
|
||||
master_addr=consumer_master_ip_address,
|
||||
master_port=master_port,
|
||||
num_update_per_episode=num_update_per_episode,
|
||||
num_recv_per_update=num_recv_per_update,
|
||||
@ -119,7 +167,8 @@ def launch_distributed(
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
eval_interval=eval_interval,
|
||||
response_format_tags=response_format_tags,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
ray.get([p.loop.remote() for p in procs])
|
||||
consumer_procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in consumer_procs])
|
||||
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
|
||||
|
@ -112,6 +112,8 @@ class BaseProducer:
|
||||
drop_last=False,
|
||||
seed=42,
|
||||
),
|
||||
num_workers=4,
|
||||
drop_last=False,
|
||||
)
|
||||
if evaluation_function_type == "think_answer_tags":
|
||||
self.evaluation_function = math_reward_fn
|
||||
@ -166,8 +168,8 @@ class BaseProducer:
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
|
||||
for eval_batch in tqdm.tqdm(
|
||||
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||
for eval_batch_id, eval_batch in tqdm.tqdm(
|
||||
enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
|
||||
):
|
||||
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
|
||||
eval_results = eval_results + [
|
||||
@ -190,9 +192,7 @@ class BaseProducer:
|
||||
self.eval_save_dir,
|
||||
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
||||
)
|
||||
# delete the file if it exists
|
||||
safe_write_jsonl(result_file_name, eval_results)
|
||||
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
|
||||
eval_statistics["consumer_global_step"] = torch.tensor(
|
||||
[self.consumer_global_step], device=self.device
|
||||
)
|
||||
@ -230,6 +230,8 @@ class BaseProducer:
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
print(
|
||||
@ -301,14 +303,19 @@ class SimpleProducer(BaseProducer):
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(
|
||||
{"n": 1, "temperature": 0.6, "top_p": 0.95}
|
||||
) # use 1 generation for evaluation
|
||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 1:
|
||||
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||
print(
|
||||
"Truncated rollout example:\n",
|
||||
self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
|
||||
)
|
||||
|
||||
return rollouts
|
||||
|
||||
|
@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
|
||||
"""
|
||||
Verify if the completion is a valid math representation of the gt_answer.
|
||||
"""
|
||||
if not completion.startswith("\\boxed{"):
|
||||
completion = "\\boxed{" + completion + "}"
|
||||
if not gt_answer.startswith("\\boxed{"):
|
||||
gt_answer = "\\boxed{" + gt_answer + "}"
|
||||
target = (
|
||||
ExprExtractionConfig(),
|
||||
LatexExtractionConfig(
|
||||
@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
|
||||
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
|
||||
# plain text answer cannot be parsed, but is correct
|
||||
reward += acc_score
|
||||
else:
|
||||
elif math_verify_result == CANNOT_PARSE_PREDICTION:
|
||||
reward += (
|
||||
acc_score / 2
|
||||
) # not a valid latex math representation, but the answer is correct, receive half of the score
|
||||
@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
if "tags" in kwargs and kwargs["tags"]:
|
||||
tags = kwargs["tags"]
|
||||
format_valid = format_valid and all(
|
||||
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
|
||||
)
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
|
@ -65,10 +65,16 @@ if __name__ == "__main__":
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
"--master_address",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Ray master address for multi-node distributed training, Optional",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
|
||||
"--torch_ddp_master_port",
|
||||
type=int,
|
||||
default=29505,
|
||||
help="Torch DDP master port for multi-node distributed training, Optional",
|
||||
)
|
||||
|
||||
# Sampling parameters
|
||||
@ -105,6 +111,9 @@ if __name__ == "__main__":
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||
parser.add_argument(
|
||||
"-rft", "--response_format_tags", type=str, default=None, help="Optional json string of the response format tag"
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
@ -164,7 +173,7 @@ if __name__ == "__main__":
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.7,
|
||||
gpu_memory_utilization=0.5,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||
@ -232,21 +241,20 @@ if __name__ == "__main__":
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
grpo_config=grpo_config,
|
||||
plugin_config={
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# plugin_config={
|
||||
# "tp_size": 1,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(
|
||||
# 1, args.train_microbatch_size // 2
|
||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
# "zero_stage": 0,
|
||||
# "max_norm": 1.0,
|
||||
# }, # for pp, tp
|
||||
# "zero_stage": 2,
|
||||
# }, # for zero
|
||||
plugin_config={
|
||||
"tp_size": 4,
|
||||
"pp_size": 2,
|
||||
"microbatch_size": max(
|
||||
1, args.train_microbatch_size // 2
|
||||
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
"zero_stage": 1,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp, tp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=args.master_port,
|
||||
master_port=args.torch_ddp_master_port,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
save_interval=args.save_interval,
|
||||
@ -257,4 +265,5 @@ if __name__ == "__main__":
|
||||
},
|
||||
eval_interval=args.eval_interval,
|
||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||
response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user