[feat] add microbatch forwarding (#6251)

* add microbatch forwarding

* fix forward microbatch

* fix producer OOM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change project name

* fix temperature annealing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-03-28 10:24:58 +08:00
Committed by: GitHub
Parent: 489f215ad9
Commit: 50153005b4
5 changed files with 112 additions and 72 deletions
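Note: the hunks below change GRPOConsumer so that, instead of a single forward pass over the whole rollout batch, the policy and reference models are run over fixed-size forward micro-batches whose losses and KLs are accumulated, keeping peak activation memory bounded by the micro-batch size. A minimal sketch of that pattern, assuming a HuggingFace-style causal LM; the function and variable names here are illustrative, not the repository's API:

import torch

def forward_in_micro_batches(model, input_ids, attention_mask, micro_batch_size):
    """Run the forward pass slice by slice to bound peak activation memory."""
    per_slice_loss = []
    for start in range(0, input_ids.size(0), micro_batch_size):
        end = start + micro_batch_size
        logits = model(
            input_ids=input_ids[start:end],
            attention_mask=attention_mask[start:end],
        ).logits
        # Placeholder objective; the real consumer computes a policy-gradient loss
        # with per-token KL regularization, as shown in the diff below.
        per_slice_loss.append(logits.float().mean())
    return torch.stack(per_slice_loss).mean()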


@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
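The stored training_config is what the training step later queries for the forward micro-batch size. A small sketch of that lookup, assuming training_config is a plain dict holding the key used in the next hunk; the numbers are illustrative:

training_config = {"train_microbatch_size": 4}  # illustrative config
rollout_batch_size = 16                         # illustrative rollout batch size

# Fall back to the full batch when no micro-batch size is configured,
# mirroring training_config.get("train_microbatch_size", ...) in the next hunk.
forward_batch_size = training_config.get("train_microbatch_size", rollout_batch_size)
num_forward_micro_batches = (rollout_batch_size + forward_batch_size - 1) // forward_batch_size
print(forward_batch_size, num_forward_micro_batches)  # 4 4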
@@ -131,40 +132,16 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
         need_update = (step_idx + 1) % self.num_microbatches == 0
-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
-            action_log_probs = calc_action_log_probs(
-                policy_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
-            reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-            per_token_kl = (
-                torch.exp(reference_action_log_probs - action_log_probs)
-                - (reference_action_log_probs - action_log_probs)
-                - 1
-            )
-            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
             reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
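The rewritten ctx above defers gradient all-reduce while micro-batch losses are being accumulated and only lets it run on the step that performs the optimizer update, or always under ZeRO stage 2, where the reduction cannot be deferred (see the linked hybrid_parallel_plugin.py line). A minimal sketch of that decision, assuming a booster object that exposes no_sync(model, optimizer) as called in the diff:

from contextlib import nullcontext

def grad_sync_ctx(booster, model, optimizer, need_update, zero_stage):
    # zero_stage mirrors booster.plugin.zero_stage in the hunk above.
    if need_update or zero_stage == 2:
        # Synchronize gradients as usual on the update step (or under ZeRO-2).
        return nullcontext()
    # Otherwise skip gradient synchronization while accumulating micro-batches.
    return booster.no_sync(model, optimizer)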
@@ -177,6 +154,11 @@ class GRPOConsumer(BaseConsumer):
             group_reward = reward.view(-1, self.num_generations)
             reward_mean = group_reward.mean(dim=1)
+            # [batch_size x num_generations]
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
@@ -185,35 +167,82 @@ class GRPOConsumer(BaseConsumer):
                     reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
                 ).repeat_interleave(self.num_generations, dim=0)
             )
+            mean_kl, mean_loss = [], []
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids_forward_micro_batch,
+                    attention_mask=attention_mask_forward_micro_batch,
+                ).logits
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
-            # [batch_size x num_generations]
-            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
-            advantages = (reward - reward_mean) / (reward_std + 1e-4)
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                reference_action_log_probs = calc_action_log_probs(
+                    reference_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
-            loss, skip_update, _ = self.policy_loss_fn(
-                action_log_probs,
-                old_action_log_probs,
-                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
-                per_token_kl,
-                action_mask,
-                loss_mask=loss_mask,
-            )
+                per_token_kl = (
+                    torch.exp(reference_action_log_probs - action_log_probs)
+                    - (reference_action_log_probs - action_log_probs)
+                    - 1
+                )
+                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                    action_mask_forward_micro_batch, dim=-1
+                )
+                loss, skip_update, _ = self.policy_loss_fn(
+                    action_log_probs,
+                    old_action_log_probs,
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    per_token_kl,
+                    action_mask_forward_micro_batch,
+                    loss_mask=loss_mask_forward_micro_batch,
+                )
+                if not skip_update:
+                    self.booster.backward(loss, self.optimizer)
+                loss = all_reduce_mean(loss, self.plugin)
+                kl = all_reduce_mean(kl.mean(), self.plugin)
+                # Calculate accumulate value.
+                mean_kl.append(kl.data)
+                mean_loss.append(loss.data)
-            if not skip_update:
-                self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
-            kl = all_reduce_mean(kl.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
             # Calculate accumulate value.
-            self.accum_loss.add_(loss.data)
+            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
             self.accum_reward.add_(reward.data)
-            self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
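
For reference, a short sketch of the group-relative advantage normalization that the diff hoists out of the micro-batch loop so each slice can index advantages directly; shapes follow the diff, and the batch size and num_generations values here are illustrative:

import torch

num_generations = 4                               # illustrative group size
reward = torch.rand(2 * num_generations)          # [batch_size * num_generations]
group_reward = reward.view(-1, num_generations)   # one row of rewards per prompt
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
# Normalize each completion against its own prompt group; the epsilon guards
# against zero std when every completion in a group receives the same score.
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
print(advantages.shape)  # torch.Size([8, 1])

Per-micro-batch loss and KL values are appended to mean_loss and mean_kl and averaged (sum(...) / len(...)) before being added to the accumulators, so the logged metrics stay comparable to the previous full-batch implementation.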