Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-17 15:36:53 +00:00)
fix missing or wrong file during rebase

parent: 118a66fd46
commit: 3746f73854
@@ -217,6 +217,7 @@ class BaseConsumer:
            effective_group_mask = None
            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
                # filter the group based on the reward and accuracy
                group_ans_acc_mean = ans_acc.mean(dim=1)
                effective_group_mask = torch.logical_and(
                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                )
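For context, this group filter keeps only prompt groups whose mean answer accuracy falls strictly inside filter_range; groups that are all correct or all wrong carry no group-relative learning signal and are dropped when dynamic batching is enabled. A minimal standalone sketch of the same check, assuming ans_acc has shape [num_groups, num_generations] and filter_range is a (low, high) pair such as (0.01, 0.99) (the concrete bounds and the function name are made up for illustration):

import torch

def build_effective_group_mask(ans_acc: torch.Tensor, filter_range=(0.01, 0.99)) -> torch.Tensor:
    """Boolean mask over groups whose mean answer accuracy lies inside filter_range.

    ans_acc: [num_groups, num_generations], per-response accuracy (0.0 or 1.0).
    """
    group_ans_acc_mean = ans_acc.mean(dim=1)  # [num_groups]
    return torch.logical_and(
        group_ans_acc_mean > filter_range[0],
        group_ans_acc_mean < filter_range[1],
    )

# Example: 3 groups of 4 generations each.
acc = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # all correct -> filtered out
                    [1.0, 0.0, 1.0, 0.0],   # mixed       -> kept
                    [0.0, 0.0, 0.0, 0.0]])  # all wrong   -> filtered out
print(build_effective_group_mask(acc))  # tensor([False,  True, False])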
@@ -91,9 +91,6 @@ class GRPOConsumer(BaseConsumer):
        self.project_name = project_name
        self.effective_sample_count = 0
        self.effective_prompt_count = 0
        self.total_sample_count = 0
        self.overlength_samples = 0
        self.total_overlength_samples = 0
        self.project_name = project_name
        self.run_name = run_name
        self.wandb_group_name = wandb_group_name
@@ -207,7 +204,6 @@ class GRPOConsumer(BaseConsumer):

            # filter out overlength samples
            if self.filter_truncated_response and action_mask.size(1) == self.max_length:
                old_loss_mask = loss_mask.clone()
                loss_mask = torch.logical_and(
                    loss_mask,
                    action_mask[:, -1] == False,
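A response that hit the generation limit never reached an end-of-sequence token, so its final action-mask position is still set; the filter above uses that to drop truncated responses from the loss. A self-contained sketch of the same idea (the function name and the 0/1 integer action_mask layout are assumptions for the example; the diff compares against False, which is equivalent to comparing against 0):

import torch

def filter_truncated(loss_mask: torch.Tensor, action_mask: torch.Tensor, max_length: int) -> torch.Tensor:
    """Exclude truncated responses: if generation filled all max_length positions,
    the last action-mask entry is still 1, i.e. no EOS was produced."""
    if action_mask.size(1) == max_length:
        # keep a sample only if it was already kept AND its last token slot is unused
        loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == 0)
    return loss_mask

# Example: two responses; the first stops early, the second runs to the limit.
action_mask = torch.tensor([[1, 1, 0, 0],
                            [1, 1, 1, 1]])
loss_mask = torch.tensor([True, True])
print(filter_truncated(loss_mask, action_mask, max_length=4))  # tensor([ True, False])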
@@ -225,15 +221,7 @@ class GRPOConsumer(BaseConsumer):
                        group_ans_acc_mean < self.filter_range[1],
                    ),
                )
                self.total_overlength_samples += self.overlength_samples.item()

            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)

            # [minibatch_size] -> calculate the number of effective prompts
            effective_prompts_mask = prompt_level_mask.any(dim=1)
            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
            self.effective_prompt_count += effective_prompts.item()
            excessive_prompts_idx = None
            self.effective_prompt_count += group_reward.size(0) * self.dp_size

            mean_kl, mean_loss = [], []
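The bookkeeping above counts a prompt as effective when at least one of its num_generations responses survived filtering: loss_mask is reshaped to [minibatch_size, num_generations], reduced with any along the generation axis, and the count is summed across data-parallel ranks via all_reduce_sum. A single-rank sketch with the cross-rank reduction omitted (the helper name below is invented for the example):

import torch

def count_effective_prompts(loss_mask: torch.Tensor, minibatch_size: int, num_generations: int) -> int:
    """A prompt is 'effective' if any of its num_generations responses kept a nonzero loss mask."""
    prompt_level_mask = loss_mask.view(minibatch_size, num_generations)  # [minibatch, num_gen]
    effective_prompts_mask = prompt_level_mask.any(dim=1)                # [minibatch]
    return int(effective_prompts_mask.sum())  # the real code all-reduces this across DP ranks

# Example: 2 prompts x 3 generations; the second prompt lost all of its samples.
mask = torch.tensor([True, False, True, False, False, False])
print(count_effective_prompts(mask, minibatch_size=2, num_generations=3))  # 1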
@@ -250,8 +238,7 @@ class GRPOConsumer(BaseConsumer):
            pbar.set_postfix(
                {
                    "Global Step": self.global_step,
                    "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
                    "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
                    "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
                }
            )
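The progress-bar postfix reports how far accumulation has progressed toward a full global batch: batch_size * dp_size effective prompts and batch_size * dp_size * num_generations effective samples. A tiny numeric illustration with made-up settings:

# Assumed example values only; not taken from the commit.
batch_size, dp_size, num_generations = 8, 4, 8
effective_prompt_count, effective_sample_count = 21, 150  # running counters in the consumer

print(f"Effective prompts: {effective_prompt_count}/{batch_size * dp_size}")                     # 21/32
print(f"Effective samples: {effective_sample_count}/{batch_size * dp_size * num_generations}")  # 150/256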
@@ -477,12 +464,10 @@ class GRPOConsumer(BaseConsumer):
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.global_step += 1
            sample_utilization = self.effective_sample_count / self.total_sample_count
            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
            # no need to run all reduce as raw_train_batch_* are not splited across dp rank
            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
            self.effective_prompt_count = 0
            self.effective_sample_count = 0
            self.total_sample_count = 0
            self.total_overlength_samples = 0
            loss_scalar = self.accum_loss.item()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
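The two sample_utilization lines show the metric before and after the change: the old form divided by an accumulated total_sample_count, while the new form derives the denominator locally as len(raw_train_batch_reward) * num_generations, relying on raw_train_batch_* not being split across data-parallel ranks (per the in-line comment). A hedged sketch of the new formula (the function wrapper is illustrative, not part of the commit):

# Minimal sketch, assuming raw_train_batch_reward holds one entry per prompt in the
# un-split training batch and effective_sample_count counts responses kept after filtering.
def compute_sample_utilization(effective_sample_count: int,
                               num_raw_prompts: int,
                               num_generations: int) -> float:
    total_samples = num_raw_prompts * num_generations
    return effective_sample_count / total_samples

# Example: 32 prompts x 8 generations = 256 rollouts, 180 kept after filtering.
print(compute_sample_utilization(180, num_raw_prompts=32, num_generations=8))  # 0.703125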
@@ -545,4 +530,4 @@ class GRPOConsumer(BaseConsumer):
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
        return state_dict
        return state_dict
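The last hunk stores the consumer's global step inside the exported state dict so it travels with the policy weights. A minimal sketch of writing that entry and reading it back (the reading side is illustrative, not part of the commit):

import torch

# Writing side (as in the diff): stash the step counter next to the model weights.
state_dict = {"consumer_global_step": torch.tensor([1234])}

# Reading side (illustrative): pop the bookkeeping entry before loading the weights.
consumer_global_step = int(state_dict.pop("consumer_global_step").item())
print(consumer_global_step)  # 1234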