refactored

YeAnbang 2025-04-25 16:54:20 +08:00
parent f4bdbf6b5d
commit 1f31f4f4fb
6 changed files with 144 additions and 206 deletions

View File

@@ -33,7 +33,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
-        microbatch_size: int = 1,
+        minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -46,11 +46,11 @@ class BaseConsumer:
         self.num_update_per_episode = num_update_per_episode
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
-        self.microbatch_size = microbatch_size
+        self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
-        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-        self.num_microbatches = batch_size // microbatch_size
+        assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
+        self.num_microbatches = batch_size // minibatch_size

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -67,7 +67,7 @@ class BaseConsumer:
             plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+            plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
@@ -105,22 +105,22 @@ class BaseConsumer:
                             )
                         )
                     )
-                while len(self.buffer) >= self.dp_size * self.microbatch_size:
+                while len(self.buffer) >= self.dp_size * self.minibatch_size:
                     batches = self.buffer[
-                        self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+                        self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                     ]
                     batch = pad_batch(
                         batches
                     )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
-                    loss, num_excessive_rollouts = self.step(i, pbar, **batch)
+                    loss, num_excessive_prompts = self.step(i, pbar, **batch)
                     self.buffer = (
                         self.buffer[
-                            (self.dp_rank + 1) * self.microbatch_size
-                            - num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
+                            (self.dp_rank + 1) * self.minibatch_size
+                            - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
                         ]
-                        + self.buffer[self.dp_size * self.microbatch_size :]
+                        + self.buffer[self.dp_size * self.minibatch_size :]
                     )
                     if loss is not None:
                         pbar.set_postfix({"loss": loss})
@@ -162,7 +162,8 @@ class SimpleConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
+        save_interval: int = 100,
         save_dir="./model",
     ):
         super().__init__(
@@ -177,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
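
For readers skimming the hunk above, here is a small self-contained sketch of the buffer bookkeeping that the rename touches (illustrative only, not part of the commit; the function name, list contents, and the num_excessive_prompts value are stand-ins): each data-parallel rank consumes its own minibatch_size slice, and any excessive prompts reported by step() are pushed back for the next round.

# Illustrative sketch of the consumer-side buffer slicing (assumed names, not from the commit).
def consume_buffer(buffer, dp_rank, dp_size, minibatch_size, num_excessive_prompts):
    # Slice owned by this data-parallel rank for the current minibatch.
    batches = buffer[dp_rank * minibatch_size : (dp_rank + 1) * minibatch_size]
    # Keep the excessive prompts from the tail of this rank's slice, plus everything
    # beyond the region consumed by all ranks in this round.
    remaining = (
        buffer[(dp_rank + 1) * minibatch_size - num_excessive_prompts : (dp_rank + 1) * minibatch_size]
        + buffer[dp_size * minibatch_size :]
    )
    return batches, remaining

# Example: 2 DP ranks, minibatch_size=2, 6 buffered samples, 1 excessive prompt on rank 0.
print(consume_buffer(list(range(6)), dp_rank=0, dp_size=2, minibatch_size=2, num_excessive_prompts=1))
# ([0, 1], [1, 4, 5])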

View File

@@ -1,12 +1,9 @@
-import json
-import os
 import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

 import ray
 import torch
-import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -35,22 +32,23 @@ class GRPOConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
         num_generations=8,
         use_wandb=True,
         generate_config=None,
         grpo_config={},
         project_name=None,
+        save_interval: int = 100,
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != microbatch_size:
+            if batch_size != minibatch_size:
                 warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {microbatch_size}->{batch_size}",
+                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
                     UserWarning,
                 )
-                microbatch_size = batch_size
+                minibatch_size = batch_size
         super().__init__(
             num_producers,
             num_episodes,
@@ -63,7 +61,7 @@
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
             save_dir=save_dir,
         )
         path = model_config.pop("path")
@@ -166,10 +164,10 @@
                 },
             ...]
         Format:
-            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
+            [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
         data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
@@ -187,15 +185,15 @@
            format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
            ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

-           # [batch_size, num_generations]
+           # [minibatch_size, num_generations]
            group_reward = reward.view(-1, self.num_generations)
            reward_mean = group_reward.mean(dim=1)
-           # [batch_size x num_generations]
+           # [minibatch_size x num_generations]
            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-           # [batch_size x num_generations]
+           # [minibatch_size x num_generations]
            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
            group_ans_acc = (
@@ -522,125 +520,3 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
-
-
-@ray.remote
-class GRPOEvalConsumer(BaseConsumer):
-    def __init__(
-        self,
-        num_producers,
-        num_episodes,
-        rank,
-        world_size,
-        master_addr,
-        master_port,
-        num_update_per_episode,
-        num_recv_per_update,
-        batch_size,
-        model_config,
-        plugin_config,
-        microbatch_size=1,
-        num_generations=4,
-        use_wandb=True,
-        log_dir="./results",
-    ):
-        super().__init__(
-            num_producers,
-            num_episodes,
-            rank,
-            world_size,
-            master_addr,
-            master_port,
-            num_update_per_episode,
-            num_recv_per_update,
-            batch_size,
-            model_config,
-            plugin_config,
-            microbatch_size,
-        )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-        self.policy_model.train()
-        self.accum_reward = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
-        self.accum_count = torch.zeros(1, device=self.device)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.pad_token_id = self.tokenizer.pad_token_id
-        self.num_generations = num_generations
-
-        # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
-        )
-
-        self.log_dir = log_dir
-        if not os.path.exists(self.log_dir):
-            os.makedirs(self.log_dir)
-        else:
-            os.system(f"rm -rf {self.log_dir}/*")
-
-    def setup(self):
-        super().setup()
-        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
-
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
-        rank = dist.get_rank()
-        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
-        kwargs["input_ids"].size(0)
-        reward_group = self.reward_model(
-            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-        )
-        reward = [value[0].item() for value in reward_group]
-        format_acc = [value[1].item() for value in reward_group]
-        ans_acc = [value[2].item() for value in reward_group]
-        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
-
-        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
-        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
-            for i in range(len(response)):
-                f.write(
-                    json.dumps(
-                        {
-                            "response": response[i],
-                            "reward": reward[i],
-                            "format_acc": format_acc[i],
-                            "ans_acc": ans_acc[i],
-                            "response_length": response_length[i],
-                        },
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-
-        self.accum_reward += sum(reward)
-        self.accum_format_acc += sum(format_acc)
-        self.accum_ans_acc += sum(ans_acc)
-        self.accum_response_length += sum(response_length)
-        self.accum_count += len(reward)
-
-        # print results
-        total_count = all_reduce_mean(self.accum_count, self.plugin)
-        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
-        mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count
-        mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count
-        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
-        if rank == 0:
-            print(
-                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}"
-            )
-        return None
-
-    def state_dict(self):
-        self.policy_model._force_wait_all_gather()
-        model = self.policy_model.unwrap()
-        state_dict = model.state_dict()
-        return state_dict
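
For reference, the advantage computation whose shape comments are renamed above boils down to a group-wise normalization; a minimal sketch, assuming a flat reward tensor with minibatch_size * num_generations entries (illustrative, not a copy of the class):

import torch

def group_normalized_advantages(reward: torch.Tensor, num_generations: int) -> torch.Tensor:
    # [minibatch_size, num_generations]
    group_reward = reward.view(-1, num_generations)
    # [minibatch_size x num_generations]
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    # [minibatch_size x num_generations, 1]
    return ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

# Example: 2 prompts with 4 generations each.
advantages = group_normalized_advantages(torch.arange(8, dtype=torch.float32), num_generations=4)
print(advantages.shape)  # torch.Size([8, 1])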

View File

@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
 import ray

 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}


 def get_jsonl_size_fast(path: str) -> int:
@@ -49,6 +49,8 @@
     master_port: int = 29500,
     core_algo: str = "GRPO",
     project_name: Optional[str] = None,
+    save_interval: int = 100,
+    save_dir: str = "./model",
 ):

     if core_algo not in ALGO_MAP:
@@ -102,12 +104,13 @@
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_minibatch_size,
+            minibatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
-            save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
+            save_interval=save_interval,
+            save_dir=save_dir,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

View File

@@ -103,7 +103,14 @@ class BaseProducer:
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                    [
+                        (
+                            self.model.generate_config["temperature"]
+                            if isinstance(self.model.generate_config.temperature, dict)
+                            else self.model.generate_config.temperature
+                        )
+                    ]
+                    * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
@@ -136,9 +143,14 @@
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
-                        "temperature"
-                    ] + ratio * 0.9
+                    if isinstance(self.model.generate_config.temperature, dict):
+                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9
+                    else:
+                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9


 @ray.remote
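
The annealing logic that the second hunk makes dict-aware is a linear interpolation from the initial sampling temperature down to 0.9 over one episode; a minimal standalone sketch (assumed names, not the producer class itself):

def annealed_temperature(initial_temp: float, step: int, num_steps: int) -> float:
    # ratio goes from 0 at the first step to 1 at the last step of the episode.
    ratio = 1 - (num_steps - step) / num_steps
    return (1 - ratio) * initial_temp + ratio * 0.9

print(annealed_temperature(1.0, step=0, num_steps=100))    # 1.0 at the start of the episode
print(annealed_temperature(1.0, step=100, num_steps=100))  # 0.9 at the end of the episode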

View File

@@ -5,8 +5,7 @@ from .reward_utils import extract_solution, validate_response_structure

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    soft_over_length_punishment = kwargs["soft_over_length_punishment"]
-    format_score = 0.0
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
     format_acc = torch.tensor(0.0)
@@ -33,7 +32,6 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     # Check format accuracy
     if format_valid:
         format_acc += 1
-        reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if (
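
In effect, a valid format now only increments format_acc and no longer adds a separate format_score to the reward; a toy sketch of that scoring flow under this assumption (names simplified, not the full reward function):

import torch

def score(format_valid: bool, answer_correct: bool, acc_score: float = 10.0):
    reward = torch.tensor(0.0)
    format_acc = torch.tensor(0.0)
    if format_valid:
        format_acc += 1
    # Answer is only credited when both the format and the answer are correct.
    if format_valid and answer_correct:
        reward += acc_score
    return reward, format_acc

print(score(format_valid=True, answer_correct=False))  # (tensor(0.), tensor(1.))
print(score(format_valid=True, answer_correct=True))   # (tensor(10.), tensor(1.))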

View File

@@ -1,4 +1,5 @@
 import argparse
+import os

 import ray
 import torch
@@ -8,15 +9,17 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+
+    # Distributed training parameters
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
-    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
         "-ibs",
         "--inference-batch-size",
         type=int,
-        default=64,
+        default=None,
         help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
@@ -37,7 +40,7 @@ if __name__ == "__main__":
         "-tMbs",
         "--train-minibatch-size",
         type=int,
-        default=1,
+        default=None,
         help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
@@ -47,23 +50,61 @@
         default=2,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
-    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
     )
     parser.add_argument(
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
+    parser.add_argument(
+        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
+    )
+
+    # Sampling parameters
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
+    parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
+    parser.add_argument(
+        "-topk",
+        "--top-k",
+        type=int,
+        default=None,
+        help="Top k for sampling. Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-topp",
+        "--top-p",
+        type=float,
+        default=1.0,
+        help="Top p for sampling. Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
+    parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
+    parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+
+    # GRPO parameters
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
+    parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+
+    # Logging/Checkpointing parameters
+    parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
+    parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
+
     args = parser.parse_args()

-    assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
+    if args.train_minibatch_size is None:
+        # Default settings: Using train batch size as mini batch size
+        args.train_minibatch_size = args.train_batch_size
+    if args.inference_batch_size is None:
+        # Default settings: Using train batch size as inference batch size, sync every inference model every train step
+        args.inference_batch_size = args.train_batch_size
     assert (
         args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
-    assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"
+    assert (
+        args.train_minibatch_size <= args.train_batch_size
+    ), "Train mini batch size must be less than or equals to train batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -72,9 +113,15 @@
         # For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
         ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)

+    if args.top_k is None:
+        if args.backend == "transformers":
+            args.top_k = 50
+        elif args.backend == "vllm":
+            args.top_k = -1
+
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
-    generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
+    generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":
         inference_model_config.update(
@@ -85,7 +132,7 @@
         )
         generate_config.update(
             dict(
-                max_length=1024 + 512,
+                max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
@@ -98,54 +145,48 @@
                 gpu_memory_utilization=0.7,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
-                max_model_len=1024 * 4 + 510,
+                max_model_len=args.max_new_tokens + args.max_prompt_tokens,
                 tensor_parallel_size=1,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=1024 * 4,
+                max_tokens=args.max_new_tokens,  # max new tokens
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
             )
         )
     else:
-        inference_model_config.update(
-            dict(
-                mem_fraction_static=0.6,
-            )
-        )
-        generate_config.update(
-            dict(
-                max_new_tokens=256,
-                ignore_eos=True,
-            )
-        )
+        raise ValueError(f"Unsupported backend: {args.backend}")

-    # Default Settings
-    # grpo_config = {
-    #     "filter_range": [0.05, 9.0],
-    #     "lr": 1e-6,
-    #     "train_microbatch_size": train_microbatch_size,
-    # }
-
-    # DAPO variant settings
-    grpo_config = {
-        "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
-        "lr": 1e-6,
-        "train_microbatch_size": args.train_microbatch_size,
-        "dynamic_batching": True,
-        "clip_eps_low": 0.2,
-        "clip_eps_high": 0.28,
-        "skip_threshold": 20.0,
-        "beta": 0.0,  # no KL penalty
-        "loss_variation": "token_level",
-        "soft_over_length_punishment": True,
-        "max_length": 1024 * 4,
-        "cache_length": 512,
-        "filter_truncated_response": True,
-    }
+    if args.algo == "GRPO":
+        # Default Settings
+        grpo_config = {
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+        }
+    elif args.algo == "DAPO":
+        # DAPO variant settings
+        grpo_config = {
+            "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "dynamic_batching": True,
+            "clip_eps_low": 0.2,
+            "clip_eps_high": 0.28,
+            "skip_threshold": 20.0,
+            "beta": 0,  # no KL penalty for DAPO
+            "loss_variation": "token_level",
+            "soft_over_length_punishment": True,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "cache_length": min(1024, int(args.max_new_tokens / 4)),
+            "filter_truncated_response": True,
+        }
+    else:
+        raise ValueError(f"Unsupported algorithm: {args.algo}")

     launch_distributed(
         num_producers=args.num_inferencer,
@@ -157,7 +198,11 @@
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
+        dataset_config={
+            "path": args.dataset,
+            "max_length": args.max_prompt_tokens,
+            "system_prompt": args.system_prompt,
+        },
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
@@ -167,6 +212,7 @@
         plugin_config={
             "zero_stage": 2,
         },  # for zero
+        # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
         #     "microbatch_size": args.train_microbatch_size // 2,
@@ -175,7 +221,9 @@
         # },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29506,
+        master_port=args.master_port,
         core_algo=args.algo,
         project_name=args.project,
+        save_interval=args.save_interval,
+        save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
     )
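
Taken together, the new CLI arguments resolve their defaults roughly as follows when left unset (a minimal sketch mirroring the argument handling above; the function name is a stand-in, not part of the commit):

def resolve_defaults(train_batch_size, train_minibatch_size=None, inference_batch_size=None,
                     top_k=None, backend="transformers"):
    if train_minibatch_size is None:
        # Default: use the train batch size as the mini batch size.
        train_minibatch_size = train_batch_size
    if inference_batch_size is None:
        # Default: use the train batch size as the inference batch size.
        inference_batch_size = train_batch_size
    if top_k is None:
        # Backend-appropriate top-k default.
        top_k = 50 if backend == "transformers" else -1
    return train_minibatch_size, inference_batch_size, top_k

print(resolve_defaults(train_batch_size=32, backend="vllm"))  # (32, 32, -1)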