Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 12:47:21 +00:00
[Distributed RLHF] Integration of PP (#6257)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
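For context, the pipeline-parallel code paths introduced in this commit are driven by the plugin behind self.plugin and self.booster: whenever that plugin is built with pp_size > 1, GRPOConsumer takes the PP branches shown in the diff below. A minimal sketch of such a configuration (the values and the surrounding launch script are illustrative assumptions, not part of this commit):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Assumes a distributed context has already been launched (e.g. torchrun + colossalai launch).
    # pp_size > 1 splits the model into pipeline stages; GRPOConsumer then routes
    # forward/backward through booster.execute_pipeline instead of a plain forward call.
    plugin = HybridParallelPlugin(tp_size=1, pp_size=2, precision="bf16", zero_stage=1)
    booster = Booster(plugin=plugin)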
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
        use_wandb=True,
        generate_config=None,
        training_config={},
        project_name=None,
    ):
        super().__init__(
            num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
        self.accum_count = 0
        self.generate_config = generate_config
        self.training_config = training_config
        self.project_name = project_name

        # Reference model is initialized from policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -94,9 +96,7 @@ class GRPOConsumer(BaseConsumer):

        self.policy_loss_fn = PolicyLoss()
        self.global_step = 0
        if use_wandb and self.rank == 0:
            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
        self.use_wandb = use_wandb

        self.lr_scheduler = CosineAnnealingWarmupLR(
            optimizer=self.optimizer,
@@ -107,10 +107,19 @@ class GRPOConsumer(BaseConsumer):

    def setup(self):
        super().setup()
        if self.use_wandb and (
            (not self.plugin.pp_size > 1 and self.rank == 0)
            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
        ):
            # Initialize wandb.
            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)

        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
        )
        self.reference_model, *_ = self.booster.boost(self.reference_model)
        self.plugin.logger.set_level("ERROR")

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
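The condition added in setup() initializes wandb on exactly one process: rank 0 when pipeline parallelism is off, and the last pipeline stage when it is on, since only that stage holds the final logits and loss. A standalone sketch of the same gating rule (the helper name is illustrative, not part of this commit):

    def should_init_logger(pp_size: int, rank: int, is_last_stage: bool) -> bool:
        # Mirrors the gate used in GRPOConsumer.setup(): log from exactly one process.
        if pp_size > 1:
            return is_last_stage
        return rank == 0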
@@ -168,6 +177,7 @@ class GRPOConsumer(BaseConsumer):
                ).repeat_interleave(self.num_generations, dim=0)
            )
            mean_kl, mean_loss = [], []

            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
                input_ids_forward_micro_batch = data["input_ids"][
                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
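The hunk below walks the rollout batch in forward micro-batches, slicing every tensor with the same offsets before running the policy and reference models. A minimal standalone sketch of that slicing pattern (the names are illustrative, not part of this commit):

    import torch

    def iter_forward_micro_batches(batch: dict, forward_batch_size: int):
        # Yield aligned slices of every tensor, mirroring the forward_micro_batch loop below.
        total = batch["input_ids"].size(0)
        for start in range(0, total, forward_batch_size):
            yield {k: v[start : start + forward_batch_size] for k, v in batch.items()}

    # Example: a toy batch of 8 sequences processed in micro-batches of 2.
    toy_batch = {
        "input_ids": torch.zeros(8, 16, dtype=torch.long),
        "attention_mask": torch.ones(8, 16, dtype=torch.long),
    }
    for micro_batch in iter_forward_micro_batches(toy_batch, 2):
        assert micro_batch["input_ids"].size(0) == 2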
@@ -186,112 +196,210 @@ class GRPOConsumer(BaseConsumer):
                advantages_forward_micro_batch = advantages[
                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
                ]
                policy_model_logits = self.policy_model(
                    input_ids=input_ids_forward_micro_batch,
                    attention_mask=attention_mask_forward_micro_batch,
                ).logits
                action_log_probs = calc_action_log_probs(
                    policy_model_logits / self.generate_config["temperature"],
                    input_ids_forward_micro_batch,
                    num_action,
                    self.plugin.shard_config,
                )

                with torch.no_grad():
                    reference_model_logits = self.reference_model(
                if self.plugin.pp_size > 1:
                    # Support training with PP.

                    with torch.no_grad():
                        reference_model_outputs = self.booster.execute_pipeline(
                            iter(
                                [
                                    {
                                        "input_ids": input_ids_forward_micro_batch,
                                        "attention_mask": attention_mask_forward_micro_batch,
                                    }
                                ]
                            ),
                            self.reference_model,
                            criterion=lambda outputs, inputs: torch.tensor(
                                [0.0], device=action_mask.device
                            ),  # dummy criterion
                            optimizer=None,
                            return_loss=False,
                            return_outputs=True,
                        )

                    if self.booster.plugin.stage_manager.is_last_stage():
                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
                        reference_action_log_probs = calc_action_log_probs(
                            reference_model_logits / self.generate_config["temperature"],
                            input_ids_forward_micro_batch,
                            num_action,
                            self.plugin.shard_config,
                        )
                    else:
                        # Dummy reference logprobs for data iterator.
                        reference_action_log_probs = None

                    data_policy_forward = {
                        "input_ids": input_ids_forward_micro_batch,
                        "attention_mask": attention_mask_forward_micro_batch,
                        "action_mask": action_mask_forward_micro_batch,
                        "reference_action_log_probs": reference_action_log_probs,
                        "advantages": advantages_forward_micro_batch,
                        "loss_mask": loss_mask_forward_micro_batch,
                        "source": self.rank,
                    }

                    kl = []

                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
                        action_log_probs = calc_action_log_probs(
                            action_logits / self.generate_config["temperature"],
                            inputs["input_ids"],
                            num_action,
                            self.plugin.shard_config,
                        )
                        per_token_kl = (
                            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
                            - (inputs["reference_action_log_probs"] - action_log_probs)
                            - 1
                        )
                        appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
                            inputs["action_mask"], dim=-1
                        )
                        kl.append(appox_kl.mean())
                        loss, skip_update, _ = self.policy_loss_fn(
                            action_log_probs,
                            action_log_probs,
                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                            per_token_kl,
                            inputs["action_mask"],
                            loss_mask=inputs["loss_mask"],
                        )
                        return loss

                    policy_model_outputs = self.booster.execute_pipeline(
                        iter([data_policy_forward]),
                        self.policy_model,
                        criterion=_criterion,
                        optimizer=self.optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    loss = policy_model_outputs["loss"]

                    if self.booster.plugin.stage_manager.is_last_stage():
                        if len(kl) > 0:
                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
                            mean_kl.append(kl)
                        loss = all_reduce_mean(loss, self.plugin)
                        mean_loss.append(loss.data)
                else:

                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
                    reference_action_log_probs = calc_action_log_probs(
                        reference_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
                        self.plugin.shard_config,
                    )
                    action_log_probs = calc_action_log_probs(
                        policy_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
                        self.plugin.shard_config,
                    )

                    per_token_kl = (
                        torch.exp(reference_action_log_probs - action_log_probs)
                        - (reference_action_log_probs - action_log_probs)
                        - 1
                    )
                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
                        action_mask_forward_micro_batch, dim=-1
                    )
                    with torch.no_grad():
                        reference_model_logits = self.reference_model(
                            input_ids=input_ids_forward_micro_batch,
                            attention_mask=attention_mask_forward_micro_batch,
                        ).logits
                    reference_action_log_probs = calc_action_log_probs(
                        reference_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
                        self.plugin.shard_config,
                    )
                    per_token_kl = (
                        torch.exp(reference_action_log_probs - action_log_probs)
                        - (reference_action_log_probs - action_log_probs)
                        - 1
                    )
                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
                        action_mask_forward_micro_batch, dim=-1
                    )

                    loss, skip_update, _ = self.policy_loss_fn(
                        action_log_probs,
                        old_action_log_probs,
                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                        per_token_kl,
                        action_mask_forward_micro_batch,
                        loss_mask=loss_mask_forward_micro_batch,
                    )
                    loss, skip_update, _ = self.policy_loss_fn(
                        action_log_probs,
                        old_action_log_probs,
                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                        per_token_kl,
                        action_mask_forward_micro_batch,
                        loss_mask=loss_mask_forward_micro_batch,
                    )

                    if not skip_update:
                        self.booster.backward(loss, self.optimizer)
                    loss = all_reduce_mean(loss, self.plugin)
                    kl = all_reduce_mean(kl.mean(), self.plugin)
                    # Calculate accumulate value.
                    mean_kl.append(kl.data)
                    mean_loss.append(loss.data)

            reward = all_reduce_mean(reward.mean(), self.plugin)
            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
            advantages = all_reduce_mean(advantages.mean(), self.plugin)
            response_length = all_reduce_mean(response_length.mean(), self.plugin)
            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
            self.accum_reward.add_(reward.data)
            self.accum_format_reward.add_(format_reward.data)
            self.accum_acc_reward.add_(acc_reward.data)
            self.accum_advantages.add_(advantages.data)
            self.accum_response_length.add_(response_length.data)
            self.accum_count += 1
                if not skip_update:
                    self.booster.backward(loss, self.optimizer)
                loss = all_reduce_mean(loss, self.plugin)
                kl = all_reduce_mean(kl.mean(), self.plugin)
                # Calculate accumulate value.
                mean_kl.append(kl.data)
                mean_loss.append(loss.data)
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
            ):
                reward = all_reduce_mean(reward.mean(), self.plugin)
                format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
                acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
                advantages = all_reduce_mean(advantages.mean(), self.plugin)
                response_length = all_reduce_mean(response_length.mean(), self.plugin)
                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                self.accum_reward.add_(reward.data)
                self.accum_format_reward.add_(format_reward.data)
                self.accum_acc_reward.add_(acc_reward.data)
                self.accum_advantages.add_(advantages.data)
                self.accum_response_length.add_(response_length.data)
                self.accum_count += 1
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            if self.rank == 0:
                print(
                    "Loss:",
                    self.accum_loss.item() / self.accum_count,
                    "\nReward:",
                    self.accum_reward.item() / self.accum_count,
                    "\nFormat Reward:",
                    self.accum_format_reward.item() / self.accum_count,
                    "\nAcc Reward:",
                    self.accum_acc_reward.item() / self.accum_count,
                    "\nKL:",
                    self.accum_kl.item() / self.accum_count,
                    "\nAdvantages:",
                    self.accum_advantages.item() / self.accum_count,
                    "\nResponse Length:",
                    self.accum_response_length.item() / self.accum_count,
                )
                self.wandb_run.log(
                    {
                        "metrics/reward": self.accum_reward.item() / self.accum_count,
                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/kl": self.accum_kl.item() / self.accum_count,
                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                    }
                )
            self.accum_loss.zero_()
            self.accum_reward.zero_()
            self.accum_acc_reward.zero_()
            self.accum_format_reward.zero_()
            self.accum_kl.zero_()
            self.accum_advantages.zero_()
            self.accum_response_length.zero_()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
            ):
                loss_scalar = self.accum_loss.item()
                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
                ):
                    print(
                        "Loss:",
                        self.accum_loss.item() / self.accum_count,
                        "\nReward:",
                        self.accum_reward.item() / self.accum_count,
                        "\nFormat Reward:",
                        self.accum_format_reward.item() / self.accum_count,
                        "\nAcc Reward:",
                        self.accum_acc_reward.item() / self.accum_count,
                        "\nKL:",
                        self.accum_kl.item() / self.accum_count,
                        "\nAdvantages:",
                        self.accum_advantages.item() / self.accum_count,
                        "\nResponse Length:",
                        self.accum_response_length.item() / self.accum_count,
                    )
                    self.wandb_run.log(
                        {
                            "metrics/reward": self.accum_reward.item() / self.accum_count,
                            "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
                            "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                            "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                            "train/loss": self.accum_loss.item() / self.accum_count,
                            "train/kl": self.accum_kl.item() / self.accum_count,
                            "train/advantages": self.accum_advantages.item() / self.accum_count,
                            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                            "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                        }
                    )
                self.accum_loss.zero_()
                self.accum_reward.zero_()
                self.accum_acc_reward.zero_()
                self.accum_format_reward.zero_()
                self.accum_kl.zero_()
                self.accum_advantages.zero_()
                self.accum_response_length.zero_()

            self.accum_count = 0
            return loss_scalar
                self.accum_count = 0
                return loss_scalar

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
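A note on the per_token_kl expression that appears throughout the training step above: exp(ref - logp) - (ref - logp) - 1 is the usual low-variance per-token estimator of the KL divergence between the policy and the reference model (the "k3" form), evaluated from the two log-probabilities the code already has. A small sketch of the same formula in isolation (the function name is illustrative, not part of this commit):

    import torch

    def per_token_kl_estimate(action_log_probs: torch.Tensor, reference_action_log_probs: torch.Tensor) -> torch.Tensor:
        # k3-style estimator: exp(log_ratio) - log_ratio - 1 with log_ratio = log pi_ref - log pi,
        # matching the per_token_kl computed inside GRPOConsumer.step above.
        log_ratio = reference_action_log_probs - action_log_probs
        return torch.exp(log_ratio) - log_ratio - 1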