mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[feat] add microbatch forwarding (#6251)
* add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = 0
|
||||
self.generate_config = generate_config
|
||||
self.training_config = training_config
|
||||
|
||||
# Reference model is initialized from policy model.
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
@@ -131,40 +132,16 @@ class GRPOConsumer(BaseConsumer):
|
||||
num_action = action_mask.shape[1]
|
||||
old_action_log_probs = data["action_log_probs"]
|
||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if need_update or self.booster.plugin.zero_stage == 2
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
data["input_ids"],
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
data["input_ids"],
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
|
||||
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
@@ -177,6 +154,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
reward_mean = group_reward.mean(dim=1)
|
||||
# [batch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [batch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||
loss_mask = (
|
||||
None
|
||||
@@ -185,35 +167,82 @@ class GRPOConsumer(BaseConsumer):
|
||||
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
|
||||
).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
mean_kl, mean_loss = [], []
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
# [batch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [batch_size x num_generations]
|
||||
advantages = (reward - reward_mean) / (reward_std + 1e-4)
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask,
|
||||
loss_mask=loss_mask,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
self.accum_loss.add_(loss.data)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_reward.add_(reward.data)
|
||||
self.accum_kl.add_(kl.data)
|
||||
self.accum_format_reward.add_(format_reward.data)
|
||||
self.accum_acc_reward.add_(acc_reward.data)
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
|
Reference in New Issue
Block a user