mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-19 00:17:18 +00:00
fix missing or wrong file during rebase
This commit is contained in:
parent
118a66fd46
commit
3746f73854
@ -217,6 +217,7 @@ class BaseConsumer:
|
|||||||
effective_group_mask = None
|
effective_group_mask = None
|
||||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
||||||
# filter the group based on the reward and accuracy
|
# filter the group based on the reward and accuracy
|
||||||
|
group_ans_acc_mean = ans_acc.mean(dim=1)
|
||||||
effective_group_mask = torch.logical_and(
|
effective_group_mask = torch.logical_and(
|
||||||
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
||||||
)
|
)
|
||||||
|
@ -91,9 +91,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.total_sample_count = 0
|
|
||||||
self.overlength_samples = 0
|
|
||||||
self.total_overlength_samples = 0
|
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.run_name = run_name
|
self.run_name = run_name
|
||||||
self.wandb_group_name = wandb_group_name
|
self.wandb_group_name = wandb_group_name
|
||||||
@ -207,7 +204,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
|
|
||||||
# filter out overlength samples
|
# filter out overlength samples
|
||||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||||
old_loss_mask = loss_mask.clone()
|
|
||||||
loss_mask = torch.logical_and(
|
loss_mask = torch.logical_and(
|
||||||
loss_mask,
|
loss_mask,
|
||||||
action_mask[:, -1] == False,
|
action_mask[:, -1] == False,
|
||||||
@ -225,15 +221,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
group_ans_acc_mean < self.filter_range[1],
|
group_ans_acc_mean < self.filter_range[1],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.total_overlength_samples += self.overlength_samples.item()
|
self.effective_prompt_count += group_reward.size(0) * self.dp_size
|
||||||
|
|
||||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
|
||||||
|
|
||||||
# [minibatch_size] -> calculate the number of effective prompts
|
|
||||||
effective_prompts_mask = prompt_level_mask.any(dim=1)
|
|
||||||
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
|
|
||||||
self.effective_prompt_count += effective_prompts.item()
|
|
||||||
excessive_prompts_idx = None
|
|
||||||
|
|
||||||
mean_kl, mean_loss = [], []
|
mean_kl, mean_loss = [], []
|
||||||
|
|
||||||
@ -250,8 +238,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
pbar.set_postfix(
|
pbar.set_postfix(
|
||||||
{
|
{
|
||||||
"Global Step": self.global_step,
|
"Global Step": self.global_step,
|
||||||
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
|
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
|
||||||
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -477,12 +464,10 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
|
||||||
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
|
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.total_sample_count = 0
|
|
||||||
self.total_overlength_samples = 0
|
|
||||||
loss_scalar = self.accum_loss.item()
|
loss_scalar = self.accum_loss.item()
|
||||||
if not self.plugin.pp_size > 1 or (
|
if not self.plugin.pp_size > 1 or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user