mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-11 12:51:55 +00:00
move prompt-level-filtering to buffer side
This commit is contained in:
parent
55eee129d2
commit
a528921944
@ -254,7 +254,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||||
self.effective_sample_count += effective_samples.item()
|
self.effective_sample_count += effective_samples.item()
|
||||||
self.total_sample_count += total_samples.item()
|
|
||||||
pbar.set_postfix(
|
pbar.set_postfix(
|
||||||
{
|
{
|
||||||
"Global Step": self.global_step,
|
"Global Step": self.global_step,
|
||||||
@ -461,6 +460,9 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
self.total_sample_count = all_reduce_sum(
|
||||||
|
torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
|
||||||
|
).item()
|
||||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
@ -564,6 +566,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
"format_acc": torch.Tensor, [num_of_generation]
|
"format_acc": torch.Tensor, [num_of_generation]
|
||||||
"ans_acc": torch.Tensor, [num_of_generation]
|
"ans_acc": torch.Tensor, [num_of_generation]
|
||||||
"""
|
"""
|
||||||
|
self.total_sample_count += rollout_group["input_ids"].size(0)
|
||||||
if self.filter_range is not None:
|
if self.filter_range is not None:
|
||||||
# filter prompt whoes accuracy is too high or too low (out of range)
|
# filter prompt whoes accuracy is too high or too low (out of range)
|
||||||
group_ans_acc = torch.mean(rollout_group["ans_acc"])
|
group_ans_acc = torch.mean(rollout_group["ans_acc"])
|
||||||
|
Loading…
Reference in New Issue
Block a user