Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-18 11:48:53 +00:00)

Merge branch 'feat/hybrid_model_sync' into support_hybrid_model_sync

Commit b2362ed33e
@@ -67,7 +67,7 @@ class BaseConsumer:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
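Note: the sizing rule this hunk encodes, sketched standalone (the helper below is illustrative, not part of the patch). Under pipeline parallelism each minibatch is split into microbatches across the pp_size stages, so the default microbatch size shrinks accordingly, floored at 1:

# Illustrative sketch of the default microbatch sizing above.
def default_microbatch_size(minibatch_size: int, plugin_config: dict) -> int:
    pp_size = plugin_config.get("pp_size", 1)  # 1 when no pipeline parallelism
    return max(1, minibatch_size // pp_size)

assert default_microbatch_size(8, {}) == 8               # data-parallel only
assert default_microbatch_size(8, {"pp_size": 2}) == 4   # two pipeline stages
assert default_microbatch_size(1, {"pp_size": 4}) == 1   # floored at 1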
@@ -7,7 +7,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
         self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
         )
         self.global_step = 0
         self.use_wandb = use_wandb
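Note: a minimal sketch of the reward-function dispatch this hunk introduces, assuming only the two names imported above (math_reward_fn for <think>/<answer>-tagged outputs, boxed_math_reward_fn for \boxed{} outputs):

def pick_reward_fn(grpo_config: dict):
    # Mirrors the conditional expression in the hunk: anything other than
    # "think_answer_tags" falls through to the boxed variant.
    if grpo_config.get("reward_fn_type") == "think_answer_tags":
        return math_reward_fn
    return boxed_math_reward_fn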
@@ -1,6 +1,6 @@
 import torch
 
-from .reward_utils import extract_solution, validate_response_structure
+from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
@@ -70,3 +70,43 @@ def gsm8k_reward_fn(input_ids, **kwargs):
     if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
         reward = reward + 9.0
     return reward
+
+
+def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
+    format_score = 0.0
+    acc_score = 10.0
+    reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
+    s, e = response_idx[0], response_idx[1]
+
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if max_length - cache_length < res_length < max_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+
+    if gt_answer is None:
+        return reward
+
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+    final_answer = extract_boxed_solution(decoded_final_answer)
+    format_valid = final_answer is not None
+    # Check format accuracy
+    if format_valid:
+        format_acc += 1
+        reward += format_score
+
+    # Check answer accuracy; the answer counts as correct only when it matches
+    # the ground truth and the format is valid
+    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
+
+    reward = reward + length_reward
+
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
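Note: a worked example of the soft over-length penalty in boxed_math_reward_fn, using its defaults (max_length=4096, cache_length=512, acc_score=10.0); the penalty ramps linearly from 0 at the threshold down to roughly -acc_score one token short of the cap:

max_length, cache_length, acc_score = 4096, 512, 10.0
for res_length in (3584, 3840, 4095):
    length_reward = 0.0
    if max_length - cache_length < res_length < max_length:
        length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
    print(res_length, length_reward)
# 3584 -> 0.0           (at the threshold, no penalty yet)
# 3840 -> -5.0          (halfway through the cache window)
# 4095 -> -9.98046875   (one token short of the cap)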
@@ -74,3 +74,51 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
 
     final_answer = matches[-1].group(1).strip()
     return final_answer, solution_str
+
+
+def extract_boxed_solution(text: str) -> Optional[str]:
+    """
+    Modified from: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
+    Retrieves the content from the last occurrence of `\boxed{}` in a LaTeX-like string.
+
+    Args:
+        text (str): A string potentially containing LaTeX-style boxed expressions.
+
+    Returns:
+        Optional[str]: The text inside the final `\boxed{}` if successfully extracted;
+        returns `None` if no properly closed box is found.
+
+    Examples:
+        >>> extract_boxed_solution("The answer is \\boxed{42}.")
+        '42'
+        >>> extract_boxed_solution("Here is an unmatched \\boxed{42")
+        None
+    """
+    try:
+        # Find the last occurrence of "\boxed{"
+        start_idx = text.rindex("\\boxed{")
+        # Move past "\boxed{" to find the start of the content
+        content_start = start_idx + len("\\boxed{")
+        open_braces = 1
+        pos = content_start
+
+        # Traverse the string to find the matching closing brace
+        while open_braces > 0 and pos < len(text):
+            if text[pos] == "{":
+                open_braces += 1
+            elif text[pos] == "}":
+                open_braces -= 1
+            pos += 1
+
+        # If all braces are matched, extract and return the content
+        if open_braces == 0:
+            return text[content_start : pos - 1].strip()
+        else:
+            return None
+
+    except ValueError:
+        # "\boxed{" not found
+        return None
+    except Exception:
+        # Any other unexpected error
+        return None
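Note: a quick behavior check for extract_boxed_solution; the brace-counting loop (rather than a greedy regex) is what makes the nested case below work:

print(extract_boxed_solution("The answer is \\boxed{42}."))       # 42
print(extract_boxed_solution("\\boxed{\\frac{1}{2}}"))            # \frac{1}{2}  (nested braces)
print(extract_boxed_solution("Here is an unmatched \\boxed{42"))  # None
print(extract_boxed_solution("no box at all"))                    # None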
@@ -86,6 +86,14 @@ if __name__ == "__main__":
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+    parser.add_argument(
+        "-rt",
+        "--reward-type",
+        type=str,
+        default="think_answer_tags",
+        choices=["think_answer_tags", "boxed"],
+        help="Reward type for GRPO.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
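Note: assuming the usual launcher for this example script (the file name below is a guess, not fixed by this diff), the new flag would be exercised as:

python rl_example.py -a GRPO -rt boxed               # \boxed{...} answers, EOS ends generation
python rl_example.py -a DAPO -rt think_answer_tags   # stop on </answer>, soft length penalty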
@@ -136,8 +144,8 @@ if __name__ == "__main__":
             max_length=args.max_new_tokens + args.max_prompt_tokens,
             do_sample=True,
             max_new_tokens=None,
-            early_stopping=False,
-            stop_strings=["</answer>"],
+            early_stopping=False if args.reward_type == "think_answer_tags" else True,
+            stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
         )
     )
 elif args.backend == "vllm":
@@ -153,9 +161,9 @@ if __name__ == "__main__":
     generate_config.update(
         dict(
             max_tokens=args.max_new_tokens,  # max new tokens
-            ignore_eos=True,
+            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
             include_stop_str_in_output=True,
-            stop=["</answer>"],
+            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
         )
     )
 else:
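Note: spelled out, the reward type now toggles the sampling cutoffs in both backends; for vLLM the two branches above reduce to these dict entries (a sketch, containing nothing beyond what the diff sets):

# "think_answer_tags": generation must run until the closing tag, so EOS is
# ignored and "</answer>" is the stop string.
think_cfg = dict(ignore_eos=True, include_stop_str_in_output=True, stop=["</answer>"])
# "boxed": the model ends naturally at EOS and no stop string is needed.
boxed_cfg = dict(ignore_eos=False, include_stop_str_in_output=True, stop=None)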
@@ -168,6 +176,7 @@ if __name__ == "__main__":
         "train_microbatch_size": args.train_microbatch_size,
         "beta": args.kl_coeff,  # KL penalty coefficient
         "loss_variation": "sample_level",
+        "reward_fn_type": args.reward_type,
     }
 elif args.algo == "DAPO":
     # DAPO variant settings
@@ -185,6 +194,7 @@ if __name__ == "__main__":
         "max_length": args.max_new_tokens + args.max_prompt_tokens,
         "cache_length": min(1024, int(args.max_new_tokens / 4)),
         "filter_truncated_response": True,
+        "reward_fn_type": args.reward_type,
     }
 else:
     raise ValueError(f"Unsupported algorithm: {args.algo}")
@@ -210,12 +220,17 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "pp_size": 2,
-            "tp_size": 2,
-            "microbatch_size": args.train_microbatch_size // 2,
-            "zero_stage": 1,
-            "max_norm": 1.0,
-        },  # for tp + pp
+            "zero_stage": 2,
+        },  # for zero
+        # plugin_config={
+        #     "tp_size": 2,
+        #     "pp_size": 2,
+        #     "microbatch_size": max(
+        #         1, args.train_microbatch_size // 2
+        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
+        #     "zero_stage": 0,
+        #     "max_norm": 1.0,
+        # },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
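Note: a hedged sketch of the two plugin configurations this hunk toggles (the helper and its mode names are illustrative, not part of the patch); the commented-out pp+tp branch keeps the microbatch_size = train_microbatch_size // pp_size rule applied in BaseConsumer at the top of this diff:

def make_plugin_config(mode: str, train_microbatch_size: int) -> dict:
    if mode == "zero":  # the active configuration in this commit
        return {"zero_stage": 2}
    if mode == "pp_tp":  # the commented-out alternative
        pp_size = 2
        return {
            "tp_size": 2,
            "pp_size": pp_size,
            # Same sizing rule as the BaseConsumer hunk at the top of this diff.
            "microbatch_size": max(1, train_microbatch_size // pp_size),
            "zero_stage": 0,
            "max_norm": 1.0,
        }
    raise ValueError(f"unknown mode: {mode}")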