mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 04:32:47 +00:00
[feat] Support prompt level dynamic (#6300)
* adjust to dynamic prompt bs * remove debug * update pad seq (#6303) Co-authored-by: Tong Li <tong.li35271158@gmail.com> * adjust to dynamic prompt bs * remove debug * fix dp issue * fix * fix default settings --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
parent
b920af427b
commit
aca547623f
@ -107,9 +107,14 @@ class BaseConsumer:
|
||||
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
|
||||
)
|
||||
for episode in range(self.num_episodes):
|
||||
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
||||
with tqdm(
|
||||
range(self.num_update_per_episode),
|
||||
desc=f"Episode {episode} with rollout step(s)",
|
||||
disable=self.rank != 0,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
i = 0
|
||||
allow_sync_model = False
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
for r in range(self.num_producers):
|
||||
@ -127,15 +132,15 @@ class BaseConsumer:
|
||||
]
|
||||
batch = bind_batch(batches)
|
||||
batch = post_recv(batch)
|
||||
loss, num_excessive_prompts = self.step(i, pbar, **batch)
|
||||
self.buffer = (
|
||||
self.buffer[
|
||||
(self.dp_rank + 1) * self.minibatch_size
|
||||
- num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
|
||||
]
|
||||
+ self.buffer[self.dp_size * self.minibatch_size :]
|
||||
)
|
||||
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
|
||||
|
||||
if excessive_prompts_idx is not None:
|
||||
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
|
||||
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
|
||||
else:
|
||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
||||
if loss is not None:
|
||||
allow_sync_model = True
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
if self.lr_scheduler is not None:
|
||||
@ -149,29 +154,31 @@ class BaseConsumer:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict,
|
||||
src=self.num_producers,
|
||||
device=self.device,
|
||||
group_name=f"sync_model_{self.pp_rank}",
|
||||
if allow_sync_model:
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
|
||||
)
|
||||
else:
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict,
|
||||
src=self.num_producers,
|
||||
device=self.device,
|
||||
group_name=f"sync_model_{self.pp_rank}",
|
||||
)
|
||||
else:
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
allow_sync_model = False
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
@ -1,4 +1,3 @@
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
save_dir="./model",
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
if batch_size != minibatch_size:
|
||||
warnings.warn(
|
||||
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
|
||||
UserWarning,
|
||||
)
|
||||
minibatch_size = batch_size
|
||||
if (
|
||||
plugin_config.get("pp_size", 1) > 1
|
||||
and "num_microbatches" not in plugin_config
|
||||
@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.grpo_config = grpo_config
|
||||
self.project_name = project_name
|
||||
self.effective_sample_count = 0
|
||||
self.effective_prompt_count = 0
|
||||
self.total_sample_count = 0
|
||||
|
||||
self.policy_loss_fn = PolicyLoss(
|
||||
@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
|
||||
group_ans_acc = (
|
||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
# [minibatch_size x num_of_generation]
|
||||
loss_mask = (
|
||||
torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
||||
if self.filter_range is None
|
||||
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
|
||||
)
|
||||
|
||||
# filter out overlength samples
|
||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
|
||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
||||
|
||||
# [minibatch_size] -> calculate the number of effective prompts
|
||||
effective_prompts_mask = prompt_level_mask.any(dim=1)
|
||||
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
|
||||
self.effective_prompt_count += effective_prompts.item()
|
||||
excessive_prompts_idx = None
|
||||
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
if self.grpo_config.get("dynamic_batching", True):
|
||||
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
|
||||
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
|
||||
|
||||
if excessive_prompts > 0:
|
||||
excessive_prompts_per_rank = excessive_prompts // self.dp_size
|
||||
# Only count excessive prompts if they are greater than 1 per rank.
|
||||
# TODO: customize excessive prompts calculation.
|
||||
if excessive_prompts_per_rank != 0:
|
||||
# Mask excessive prompts to False
|
||||
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
|
||||
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
|
||||
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
|
||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||
self.effective_sample_count += effective_samples.item()
|
||||
self.total_sample_count += total_samples.item()
|
||||
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
if self.grpo_config.get("dynamic_batching", True):
|
||||
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
|
||||
# to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
|
||||
num_excessive_samples = (
|
||||
int(
|
||||
(self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
|
||||
/ self.num_generations
|
||||
/ self.dp_size
|
||||
)
|
||||
* self.num_generations
|
||||
)
|
||||
if num_excessive_samples > 0:
|
||||
data = {
|
||||
k: (
|
||||
v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
|
||||
if k
|
||||
in [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"action_log_probs",
|
||||
"action_mask",
|
||||
"response_idx",
|
||||
"gt_answer",
|
||||
]
|
||||
else v
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
action_mask = action_mask[
|
||||
: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
|
||||
]
|
||||
loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
|
||||
advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
|
||||
else:
|
||||
num_excessive_samples = 0
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
num_excessive_samples = 0
|
||||
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"Step": self.global_step + 1,
|
||||
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
||||
"Global Step": self.global_step,
|
||||
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
|
||||
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
||||
}
|
||||
)
|
||||
|
||||
@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl.append(appox_kl.mean())
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl.append(0.0)
|
||||
kl.append(torch.tensor(0.0))
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.optimizer.zero_grad()
|
||||
self.global_step += 1
|
||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||
self.effective_prompt_count = 0
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
loss_scalar = self.accum_loss.item()
|
||||
@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
|
||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
||||
f"Sample_utilization: {sample_utilization:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
self.accum_count = 0
|
||||
return loss_scalar, num_excessive_samples // self.num_generations
|
||||
|
||||
if excessive_prompts_idx is not None:
|
||||
# All gather excessive prompts index across DP ranks.
|
||||
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
|
||||
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
|
||||
|
||||
return loss_scalar, excessive_prompts_idx
|
||||
else:
|
||||
return None, num_excessive_samples // self.num_generations
|
||||
return None, excessive_prompts_idx
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
|
@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||
else:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
return tensor
|
||||
|
||||
|
||||
def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||
"""
|
||||
Gathers tensors from all processes and concatenates them along the first dimension.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The input tensor to be gathered.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The gathered tensor.
|
||||
"""
|
||||
# Gather tensors across DP group
|
||||
if plugin is not None:
|
||||
all_tensor_lists = [None] * plugin.dp_size
|
||||
dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
|
||||
gathered_tensor_list = []
|
||||
for tensors in all_tensor_lists:
|
||||
gathered_tensor_list.extend(tensors)
|
||||
else:
|
||||
all_tensor_lists = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(all_tensor_lists, local_tensor_list)
|
||||
gathered_tensor_list = []
|
||||
for tensors in all_tensor_lists:
|
||||
gathered_tensor_list.extend(tensors)
|
||||
return gathered_tensor_list
|
||||
|
@ -9,7 +9,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
|
||||
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
|
||||
|
||||
# Distributed training parameters
|
||||
@ -20,7 +20,7 @@ if __name__ == "__main__":
|
||||
"-ibs",
|
||||
"--inference-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
default=64,
|
||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -41,7 +41,7 @@ if __name__ == "__main__":
|
||||
"-tMbs",
|
||||
"--train-minibatch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
default=8,
|
||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -58,7 +58,7 @@ if __name__ == "__main__":
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
|
||||
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
|
||||
)
|
||||
|
||||
# Sampling parameters
|
||||
@ -223,7 +223,7 @@ if __name__ == "__main__":
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# plugin_config={
|
||||
# "tp_size": 1,
|
||||
# "tp_size": 2,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(
|
||||
# 1, args.train_microbatch_size // 2
|
||||
|
Loading…
Reference in New Issue
Block a user