mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 23:11:55 +00:00
merge grpo-latest
This commit is contained in:
commit
7b921acc8a
@ -124,17 +124,17 @@ class BaseConsumer:
|
|||||||
raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k!='temperature' else v for k, v in raw_batch.items()})
|
raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k!='temperature' else v for k, v in raw_batch.items()})
|
||||||
raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k!='temperature' else v for k, v in raw_batch_with_reward.items()}
|
raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k!='temperature' else v for k, v in raw_batch_with_reward.items()}
|
||||||
# [batch_size, num_generations] -> [batch_size]
|
# [batch_size, num_generations] -> [batch_size]
|
||||||
group_reward_mean = raw_batch_with_reward["reward"][:,:,0].mean(dim=-1)
|
reward = raw_batch_with_reward["reward"][:,:,0]
|
||||||
group_format_acc_mean = raw_batch_with_reward["format_acc"][:,:,0].mean(dim=-1)
|
format_acc = raw_batch_with_reward["format_acc"][:,:,0]
|
||||||
group_ans_acc_mean = raw_batch_with_reward["ans_acc"][:,:,0].mean(dim=-1)
|
ans_acc = raw_batch_with_reward["ans_acc"][:,:,0]
|
||||||
group_response_len = (
|
response_len = (
|
||||||
(raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
|
(raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
|
||||||
.type(torch.float32)
|
.type(torch.float32)
|
||||||
.mean(dim=-1)
|
|
||||||
)
|
)
|
||||||
effective_group_mask = None
|
effective_group_mask = None
|
||||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
||||||
# filter the group based on the reward and accuracy
|
# filter the group based on the reward and accuracy
|
||||||
|
group_ans_acc_mean = ans_acc.mean(dim=1)
|
||||||
effective_group_mask = torch.logical_and(
|
effective_group_mask = torch.logical_and(
|
||||||
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
||||||
)
|
)
|
||||||
@ -143,15 +143,15 @@ class BaseConsumer:
|
|||||||
self.buffer.append(
|
self.buffer.append(
|
||||||
[
|
[
|
||||||
group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
|
group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
|
||||||
group_reward_mean[group_idx],
|
reward[group_idx],
|
||||||
group_format_acc_mean[group_idx],
|
format_acc[group_idx],
|
||||||
group_ans_acc_mean[group_idx],
|
ans_acc[group_idx],
|
||||||
group_response_len[group_idx],
|
response_len[group_idx],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if effective_group_mask is not None:
|
if effective_group_mask is not None:
|
||||||
print(
|
print(
|
||||||
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
||||||
)
|
)
|
||||||
# mapping the effective group to the raw group for indexing
|
# mapping the effective group to the raw group for indexing
|
||||||
effective_group_to_raw_group_mapping = {}
|
effective_group_to_raw_group_mapping = {}
|
||||||
@ -160,7 +160,7 @@ class BaseConsumer:
|
|||||||
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||||
buffer_idx
|
buffer_idx
|
||||||
)
|
)
|
||||||
pbar.set_postfix({"Collect Effective Prompt": f"{len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"})
|
print(f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}")
|
||||||
|
|
||||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||||
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||||
|
@ -211,6 +211,17 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
loss_mask,
|
loss_mask,
|
||||||
action_mask[:, -1] == False,
|
action_mask[:, -1] == False,
|
||||||
)
|
)
|
||||||
|
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False)==False:
|
||||||
|
# filter out samples with reward outside the range
|
||||||
|
# if dynamic batching is enabled, we filter out out of range groups before training
|
||||||
|
group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
|
||||||
|
loss_mask = torch.logical_and(
|
||||||
|
loss_mask,
|
||||||
|
torch.logical_and(
|
||||||
|
group_ans_acc_mean > self.filter_range[0],
|
||||||
|
group_ans_acc_mean < self.filter_range[1],
|
||||||
|
),
|
||||||
|
)
|
||||||
self.effective_prompt_count += group_reward.size(0) * self.dp_size
|
self.effective_prompt_count += group_reward.size(0) * self.dp_size
|
||||||
|
|
||||||
mean_kl, mean_loss = [], []
|
mean_kl, mean_loss = [], []
|
||||||
@ -229,8 +240,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
pbar.set_postfix(
|
pbar.set_postfix(
|
||||||
{
|
{
|
||||||
"Global Step": self.global_step,
|
"Global Step": self.global_step,
|
||||||
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
|
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
|
||||||
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -428,6 +438,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
|
||||||
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
|
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
@ -438,14 +449,12 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
):
|
):
|
||||||
raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
|
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
|
||||||
raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
|
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
|
||||||
self.raw_train_batch_format_acc
|
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
|
||||||
)
|
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
|
||||||
raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
|
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
|
||||||
raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
|
overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate
|
||||||
self.raw_train_batch_response_len
|
|
||||||
)
|
|
||||||
self.raw_train_batch_reward = []
|
self.raw_train_batch_reward = []
|
||||||
self.raw_train_batch_format_acc = []
|
self.raw_train_batch_format_acc = []
|
||||||
self.raw_train_batch_ans_acc = []
|
self.raw_train_batch_ans_acc = []
|
||||||
@ -458,6 +467,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||||
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||||
f"Sample_utilization: {sample_utilization:.4f}",
|
f"Sample_utilization: {sample_utilization:.4f}",
|
||||||
|
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
|
||||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||||
print("\n".join(to_log_msg))
|
print("\n".join(to_log_msg))
|
||||||
metrics = {
|
metrics = {
|
||||||
@ -469,6 +479,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||||
"train/sample_utilization": sample_utilization,
|
"train/sample_utilization": sample_utilization,
|
||||||
|
"train/overlength_samples_ratio": overlength_samples_ratio,
|
||||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||||
}
|
}
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user