mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-13 13:45:51 +00:00
refactored
This commit is contained in:
parent
f4bdbf6b5d
commit
1f31f4f4fb
@ -33,7 +33,7 @@ class BaseConsumer:
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
model_config: Dict[str, Any],
|
model_config: Dict[str, Any],
|
||||||
plugin_config: Dict[str, Any],
|
plugin_config: Dict[str, Any],
|
||||||
microbatch_size: int = 1,
|
minibatch_size: int = 1,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir: str = "./model",
|
save_dir: str = "./model",
|
||||||
):
|
):
|
||||||
@ -46,11 +46,11 @@ class BaseConsumer:
|
|||||||
self.num_update_per_episode = num_update_per_episode
|
self.num_update_per_episode = num_update_per_episode
|
||||||
self.num_recv_per_update = num_recv_per_update
|
self.num_recv_per_update = num_recv_per_update
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.microbatch_size = microbatch_size
|
self.minibatch_size = minibatch_size
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
|
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||||
self.num_microbatches = batch_size // microbatch_size
|
self.num_microbatches = batch_size // minibatch_size
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.plugin_config = plugin_config
|
self.plugin_config = plugin_config
|
||||||
@ -67,7 +67,7 @@ class BaseConsumer:
|
|||||||
|
|
||||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||||
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
||||||
plugin_config["microbatch_size"] = self.microbatch_size
|
plugin_config["microbatch_size"] = self.minibatch_size
|
||||||
plugin_config.update(self.plugin_config)
|
plugin_config.update(self.plugin_config)
|
||||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||||
self.booster = Booster(plugin=self.plugin)
|
self.booster = Booster(plugin=self.plugin)
|
||||||
@ -105,22 +105,22 @@ class BaseConsumer:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
while len(self.buffer) >= self.dp_size * self.microbatch_size:
|
while len(self.buffer) >= self.dp_size * self.minibatch_size:
|
||||||
batches = self.buffer[
|
batches = self.buffer[
|
||||||
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
|
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
|
||||||
]
|
]
|
||||||
batch = pad_batch(
|
batch = pad_batch(
|
||||||
batches
|
batches
|
||||||
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
|
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
|
||||||
batch = bind_batch(batches)
|
batch = bind_batch(batches)
|
||||||
batch = post_recv(batch)
|
batch = post_recv(batch)
|
||||||
loss, num_excessive_rollouts = self.step(i, pbar, **batch)
|
loss, num_excessive_prompts = self.step(i, pbar, **batch)
|
||||||
self.buffer = (
|
self.buffer = (
|
||||||
self.buffer[
|
self.buffer[
|
||||||
(self.dp_rank + 1) * self.microbatch_size
|
(self.dp_rank + 1) * self.minibatch_size
|
||||||
- num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
|
- num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
|
||||||
]
|
]
|
||||||
+ self.buffer[self.dp_size * self.microbatch_size :]
|
+ self.buffer[self.dp_size * self.minibatch_size :]
|
||||||
)
|
)
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
pbar.set_postfix({"loss": loss})
|
pbar.set_postfix({"loss": loss})
|
||||||
@ -162,7 +162,8 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
batch_size,
|
batch_size,
|
||||||
model_config,
|
model_config,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
microbatch_size=1,
|
minibatch_size=1,
|
||||||
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -177,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
batch_size,
|
batch_size,
|
||||||
model_config,
|
model_config,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
microbatch_size,
|
minibatch_size,
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import wandb
|
import wandb
|
||||||
from coati.distributed.consumer import BaseConsumer
|
from coati.distributed.consumer import BaseConsumer
|
||||||
from coati.distributed.loss import PolicyLoss
|
from coati.distributed.loss import PolicyLoss
|
||||||
@ -35,22 +32,23 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
batch_size,
|
batch_size,
|
||||||
model_config,
|
model_config,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
microbatch_size=1,
|
minibatch_size=1,
|
||||||
num_generations=8,
|
num_generations=8,
|
||||||
use_wandb=True,
|
use_wandb=True,
|
||||||
generate_config=None,
|
generate_config=None,
|
||||||
grpo_config={},
|
grpo_config={},
|
||||||
project_name=None,
|
project_name=None,
|
||||||
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
):
|
):
|
||||||
print(f"Using GRPO config: {grpo_config}")
|
print(f"Using GRPO config: {grpo_config}")
|
||||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||||
if batch_size != microbatch_size:
|
if batch_size != minibatch_size:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {microbatch_size}->{batch_size}",
|
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
microbatch_size = batch_size
|
minibatch_size = batch_size
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_producers,
|
num_producers,
|
||||||
num_episodes,
|
num_episodes,
|
||||||
@ -63,7 +61,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
batch_size,
|
batch_size,
|
||||||
model_config,
|
model_config,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
microbatch_size,
|
minibatch_size,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
@ -166,10 +164,10 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
},
|
},
|
||||||
...]
|
...]
|
||||||
Format:
|
Format:
|
||||||
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
|
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
||||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
||||||
action_mask = data["action_mask"]
|
action_mask = data["action_mask"]
|
||||||
num_action = action_mask.shape[1]
|
num_action = action_mask.shape[1]
|
||||||
@ -187,15 +185,15 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
||||||
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
||||||
|
|
||||||
# [batch_size, num_generations]
|
# [minibatch_size, num_generations]
|
||||||
|
|
||||||
group_reward = reward.view(-1, self.num_generations)
|
group_reward = reward.view(-1, self.num_generations)
|
||||||
reward_mean = group_reward.mean(dim=1)
|
reward_mean = group_reward.mean(dim=1)
|
||||||
# [batch_size x num_generations]
|
# [minibatch_size x num_generations]
|
||||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||||
|
|
||||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||||
# [batch_size x num_generations]
|
# [minibatch_size x num_generations]
|
||||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||||
group_ans_acc = (
|
group_ans_acc = (
|
||||||
@ -522,125 +520,3 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
model = self.policy_model.unwrap()
|
model = self.policy_model.unwrap()
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
|
||||||
class GRPOEvalConsumer(BaseConsumer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_producers,
|
|
||||||
num_episodes,
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
master_addr,
|
|
||||||
master_port,
|
|
||||||
num_update_per_episode,
|
|
||||||
num_recv_per_update,
|
|
||||||
batch_size,
|
|
||||||
model_config,
|
|
||||||
plugin_config,
|
|
||||||
microbatch_size=1,
|
|
||||||
num_generations=4,
|
|
||||||
use_wandb=True,
|
|
||||||
log_dir="./results",
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
num_producers,
|
|
||||||
num_episodes,
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
master_addr,
|
|
||||||
master_port,
|
|
||||||
num_update_per_episode,
|
|
||||||
num_recv_per_update,
|
|
||||||
batch_size,
|
|
||||||
model_config,
|
|
||||||
plugin_config,
|
|
||||||
microbatch_size,
|
|
||||||
)
|
|
||||||
path = model_config.pop("path")
|
|
||||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
|
||||||
self.policy_model.train()
|
|
||||||
self.accum_reward = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_format_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_ans_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_count = torch.zeros(1, device=self.device)
|
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
|
||||||
self.pad_token_id = self.tokenizer.pad_token_id
|
|
||||||
self.num_generations = num_generations
|
|
||||||
|
|
||||||
# Initialize verifiable reward.
|
|
||||||
response_format_tags = {
|
|
||||||
"think_start": {"text": "<think>", "num_occur": 1},
|
|
||||||
"think_end": {"text": "</think>", "num_occur": 1},
|
|
||||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
|
||||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
|
||||||
}
|
|
||||||
self.reward_model = VerifiableReward(
|
|
||||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_dir = log_dir
|
|
||||||
if not os.path.exists(self.log_dir):
|
|
||||||
os.makedirs(self.log_dir)
|
|
||||||
else:
|
|
||||||
os.system(f"rm -rf {self.log_dir}/*")
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
super().setup()
|
|
||||||
self.policy_model, _, *_ = self.booster.boost(self.policy_model)
|
|
||||||
|
|
||||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
|
||||||
rank = dist.get_rank()
|
|
||||||
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
|
|
||||||
kwargs["input_ids"].size(0)
|
|
||||||
reward_group = self.reward_model(
|
|
||||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
|
||||||
)
|
|
||||||
reward = [value[0].item() for value in reward_group]
|
|
||||||
format_acc = [value[1].item() for value in reward_group]
|
|
||||||
ans_acc = [value[2].item() for value in reward_group]
|
|
||||||
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
|
|
||||||
|
|
||||||
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
|
|
||||||
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
|
|
||||||
for i in range(len(response)):
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"response": response[i],
|
|
||||||
"reward": reward[i],
|
|
||||||
"format_acc": format_acc[i],
|
|
||||||
"ans_acc": ans_acc[i],
|
|
||||||
"response_length": response_length[i],
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.accum_reward += sum(reward)
|
|
||||||
self.accum_format_acc += sum(format_acc)
|
|
||||||
self.accum_ans_acc += sum(ans_acc)
|
|
||||||
self.accum_response_length += sum(response_length)
|
|
||||||
self.accum_count += len(reward)
|
|
||||||
|
|
||||||
# print results
|
|
||||||
total_count = all_reduce_mean(self.accum_count, self.plugin)
|
|
||||||
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
|
|
||||||
mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count
|
|
||||||
mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count
|
|
||||||
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
|
|
||||||
if rank == 0:
|
|
||||||
print(
|
|
||||||
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
self.policy_model._force_wait_all_gather()
|
|
||||||
model = self.policy_model.unwrap()
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
return state_dict
|
|
||||||
|
@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
|
|||||||
import ray
|
import ray
|
||||||
|
|
||||||
from .consumer import SimpleConsumer
|
from .consumer import SimpleConsumer
|
||||||
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
|
from .grpo_consumer import GRPOConsumer
|
||||||
from .producer import SimpleProducer
|
from .producer import SimpleProducer
|
||||||
|
|
||||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
|
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
|
||||||
|
|
||||||
|
|
||||||
def get_jsonl_size_fast(path: str) -> int:
|
def get_jsonl_size_fast(path: str) -> int:
|
||||||
@ -49,6 +49,8 @@ def launch_distributed(
|
|||||||
master_port: int = 29500,
|
master_port: int = 29500,
|
||||||
core_algo: str = "GRPO",
|
core_algo: str = "GRPO",
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
|
save_interval: int = 100,
|
||||||
|
save_dir: str = "./model",
|
||||||
):
|
):
|
||||||
|
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
@ -102,12 +104,13 @@ def launch_distributed(
|
|||||||
batch_size=train_batch_size,
|
batch_size=train_batch_size,
|
||||||
model_config=train_model_config,
|
model_config=train_model_config,
|
||||||
plugin_config=plugin_config,
|
plugin_config=plugin_config,
|
||||||
microbatch_size=train_minibatch_size,
|
minibatch_size=train_minibatch_size,
|
||||||
generate_config=generate_config_consumer,
|
generate_config=generate_config_consumer,
|
||||||
grpo_config=grpo_config,
|
grpo_config=grpo_config,
|
||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
|
save_interval=save_interval,
|
||||||
|
save_dir=save_dir,
|
||||||
)
|
)
|
||||||
procs.append(consumer)
|
procs.append(consumer)
|
||||||
ray.get([p.setup.remote() for p in procs])
|
ray.get([p.setup.remote() for p in procs])
|
||||||
|
@ -103,7 +103,14 @@ class BaseProducer:
|
|||||||
|
|
||||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||||
outputs["temperature"] = torch.tensor(
|
outputs["temperature"] = torch.tensor(
|
||||||
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
|
[
|
||||||
|
(
|
||||||
|
self.model.generate_config["temperature"]
|
||||||
|
if isinstance(self.model.generate_config.temperature, dict)
|
||||||
|
else self.model.generate_config.temperature
|
||||||
|
)
|
||||||
|
]
|
||||||
|
* outputs["input_ids"].size(0)
|
||||||
).to(outputs["input_ids"].device)
|
).to(outputs["input_ids"].device)
|
||||||
outputs = pre_send(outputs)
|
outputs = pre_send(outputs)
|
||||||
ray_broadcast_tensor_dict(
|
ray_broadcast_tensor_dict(
|
||||||
@ -136,9 +143,14 @@ class BaseProducer:
|
|||||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||||
if episode <= 0:
|
if episode <= 0:
|
||||||
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
||||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
if isinstance(self.model.generate_config.temperature, dict):
|
||||||
"temperature"
|
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||||
] + ratio * 0.9
|
"temperature"
|
||||||
|
] + ratio * 0.9
|
||||||
|
else:
|
||||||
|
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
||||||
|
"temperature"
|
||||||
|
] + ratio * 0.9
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -5,8 +5,7 @@ from .reward_utils import extract_solution, validate_response_structure
|
|||||||
|
|
||||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||||
tokenizer = kwargs["tokenizer"]
|
tokenizer = kwargs["tokenizer"]
|
||||||
soft_over_length_punishment = kwargs["soft_over_length_punishment"]
|
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||||
format_score = 0.0
|
|
||||||
acc_score = 10.0
|
acc_score = 10.0
|
||||||
reward = torch.tensor(0.0)
|
reward = torch.tensor(0.0)
|
||||||
format_acc = torch.tensor(0.0)
|
format_acc = torch.tensor(0.0)
|
||||||
@ -33,7 +32,6 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
# Check format accuracy
|
# Check format accuracy
|
||||||
if format_valid:
|
if format_valid:
|
||||||
format_acc += 1
|
format_acc += 1
|
||||||
reward += format_score
|
|
||||||
|
|
||||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||||
if (
|
if (
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -8,15 +9,17 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||||
|
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||||
|
|
||||||
|
# Distributed training parameters
|
||||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||||
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-ibs",
|
"-ibs",
|
||||||
"--inference-batch-size",
|
"--inference-batch-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=None,
|
||||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -37,7 +40,7 @@ if __name__ == "__main__":
|
|||||||
"-tMbs",
|
"-tMbs",
|
||||||
"--train-minibatch-size",
|
"--train-minibatch-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=None,
|
||||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -47,23 +50,61 @@ if __name__ == "__main__":
|
|||||||
default=2,
|
default=2,
|
||||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||||
)
|
)
|
||||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
|
||||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sampling parameters
|
||||||
|
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||||
|
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-topk",
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-topp",
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
|
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
|
||||||
|
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
|
||||||
|
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
||||||
|
|
||||||
|
# GRPO parameters
|
||||||
|
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
|
||||||
|
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
|
||||||
|
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
|
||||||
|
|
||||||
|
# Logging/Checkpointing parameters
|
||||||
|
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||||
|
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
|
if args.train_minibatch_size is None:
|
||||||
|
# Default settings: Using train batch size as mini batch size
|
||||||
|
args.train_minibatch_size = args.train_batch_size
|
||||||
|
if args.inference_batch_size is None:
|
||||||
|
# Default settings: Using train batch size as inference batch size, sync every inference model every train step
|
||||||
|
args.inference_batch_size = args.train_batch_size
|
||||||
assert (
|
assert (
|
||||||
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
|
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
|
||||||
and args.train_microbatch_size > 0
|
and args.train_microbatch_size > 0
|
||||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||||
assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"
|
assert (
|
||||||
|
args.train_minibatch_size <= args.train_batch_size
|
||||||
|
), "Train mini batch size must be less than or equals to train batch size"
|
||||||
|
|
||||||
if args.master_address is None:
|
if args.master_address is None:
|
||||||
# Default settings: Using single machine
|
# Default settings: Using single machine
|
||||||
@ -72,9 +113,15 @@ if __name__ == "__main__":
|
|||||||
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
||||||
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
|
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
|
||||||
|
|
||||||
|
if args.top_k is None:
|
||||||
|
if args.backend == "transformers":
|
||||||
|
args.top_k = 50
|
||||||
|
elif args.backend == "vllm":
|
||||||
|
args.top_k = -1
|
||||||
|
|
||||||
inference_model_config = dict(path=args.model)
|
inference_model_config = dict(path=args.model)
|
||||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||||
generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
|
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||||
|
|
||||||
if args.backend == "transformers":
|
if args.backend == "transformers":
|
||||||
inference_model_config.update(
|
inference_model_config.update(
|
||||||
@ -85,7 +132,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
dict(
|
dict(
|
||||||
max_length=1024 + 512,
|
max_length=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
max_new_tokens=None,
|
max_new_tokens=None,
|
||||||
early_stopping=False,
|
early_stopping=False,
|
||||||
@ -98,54 +145,48 @@ if __name__ == "__main__":
|
|||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
max_model_len=1024 * 4 + 510,
|
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
dict(
|
dict(
|
||||||
max_tokens=1024 * 4,
|
max_tokens=args.max_new_tokens, # max new tokens
|
||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
include_stop_str_in_output=True,
|
include_stop_str_in_output=True,
|
||||||
stop=["</answer>"],
|
stop=["</answer>"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inference_model_config.update(
|
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||||
dict(
|
|
||||||
mem_fraction_static=0.6,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
generate_config.update(
|
|
||||||
dict(
|
|
||||||
max_new_tokens=256,
|
|
||||||
ignore_eos=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Default Settings
|
if args.algo == "GRPO":
|
||||||
# grpo_config = {
|
# Default Settings
|
||||||
# "filter_range": [0.05, 9.0],
|
grpo_config = {
|
||||||
# "lr": 1e-6,
|
"lr": args.learning_rate,
|
||||||
# "train_microbatch_size": train_microbatch_size,
|
"train_microbatch_size": args.train_microbatch_size,
|
||||||
# }
|
"beta": args.kl_coeff, # KL penalty coefficient
|
||||||
|
"loss_variation": "sample_level",
|
||||||
# DAPO variant settings
|
}
|
||||||
grpo_config = {
|
elif args.algo == "DAPO":
|
||||||
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
|
# DAPO variant settings
|
||||||
"lr": 1e-6,
|
grpo_config = {
|
||||||
"train_microbatch_size": args.train_microbatch_size,
|
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
|
||||||
"dynamic_batching": True,
|
"lr": args.learning_rate,
|
||||||
"clip_eps_low": 0.2,
|
"train_microbatch_size": args.train_microbatch_size,
|
||||||
"clip_eps_high": 0.28,
|
"dynamic_batching": True,
|
||||||
"skip_threshold": 20.0,
|
"clip_eps_low": 0.2,
|
||||||
"beta": 0.0, # no KL penalty
|
"clip_eps_high": 0.28,
|
||||||
"loss_variation": "token_level",
|
"skip_threshold": 20.0,
|
||||||
"soft_over_length_punishment": True,
|
"beta": 0, # no KL penalty for DAPO
|
||||||
"max_length": 1024 * 4,
|
"loss_variation": "token_level",
|
||||||
"cache_length": 512,
|
"soft_over_length_punishment": True,
|
||||||
"filter_truncated_response": True,
|
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||||
}
|
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||||
|
"filter_truncated_response": True,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||||
|
|
||||||
launch_distributed(
|
launch_distributed(
|
||||||
num_producers=args.num_inferencer,
|
num_producers=args.num_inferencer,
|
||||||
@ -157,7 +198,11 @@ if __name__ == "__main__":
|
|||||||
train_batch_size=args.train_batch_size,
|
train_batch_size=args.train_batch_size,
|
||||||
train_minibatch_size=args.train_minibatch_size,
|
train_minibatch_size=args.train_minibatch_size,
|
||||||
train_microbatch_size=args.train_microbatch_size,
|
train_microbatch_size=args.train_microbatch_size,
|
||||||
dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
|
dataset_config={
|
||||||
|
"path": args.dataset,
|
||||||
|
"max_length": args.max_prompt_tokens,
|
||||||
|
"system_prompt": args.system_prompt,
|
||||||
|
},
|
||||||
dataloaders_config={},
|
dataloaders_config={},
|
||||||
inference_model_config=inference_model_config,
|
inference_model_config=inference_model_config,
|
||||||
generate_config=generate_config,
|
generate_config=generate_config,
|
||||||
@ -167,6 +212,7 @@ if __name__ == "__main__":
|
|||||||
plugin_config={
|
plugin_config={
|
||||||
"zero_stage": 2,
|
"zero_stage": 2,
|
||||||
}, # for zero
|
}, # for zero
|
||||||
|
# currently not support tp/pp
|
||||||
# plugin_config={
|
# plugin_config={
|
||||||
# "tp_size": 2,
|
# "tp_size": 2,
|
||||||
# "microbatch_size": args.train_microbatch_size // 2,
|
# "microbatch_size": args.train_microbatch_size // 2,
|
||||||
@ -175,7 +221,9 @@ if __name__ == "__main__":
|
|||||||
# }, # for pp
|
# }, # for pp
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr="localhost",
|
master_addr="localhost",
|
||||||
master_port=29506,
|
master_port=args.master_port,
|
||||||
core_algo=args.algo,
|
core_algo=args.algo,
|
||||||
project_name=args.project,
|
project_name=args.project,
|
||||||
|
save_interval=args.save_interval,
|
||||||
|
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user