add pp support

This commit is contained in:
YeAnbang 2025-04-04 10:05:16 +08:00
parent d961a5f725
commit 09a3173a49
4 changed files with 198 additions and 120 deletions

1
.gitignore vendored
View File

@ -164,3 +164,4 @@ coverage.xml
applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model

View File

@ -94,7 +94,6 @@ class BaseConsumer:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(

View File

@ -94,9 +94,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
self.use_wandb = use_wandb
self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
@ -107,10 +105,19 @@ class GRPOConsumer(BaseConsumer):
def setup(self):
super().setup()
if self.use_wandb and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
@ -168,72 +175,130 @@ class GRPOConsumer(BaseConsumer):
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
if self.plugin.pp_size > 1:
# Support training with PP.
data_iter = iter([data])
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
data_iter,
self.reference_model,
criterion=lambda outputs, inputs: outputs.logits.mean(), # dummy criterion
optimizer=None,
return_loss=False,
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
if self.plugin.pp_size > 1:
# Support training with PP.
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
data_policy_forward = {
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch,
"reference_action_log_probs": reference_action_log_probs,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"source": self.rank,
}
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
decode_tokens_100 = self.tokenizer.batch_decode(
input_ids_forward_micro_batch[:, -num_action:],
skip_special_tokens=False,
)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
)
return loss
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
return_outputs=True,
)
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
if self.booster.plugin.stage_manager.is_last_stage():
# calculate kl
action_logits = policy_model_outputs["outputs"]["logits"]
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
kl = all_reduce_mean(kl.mean(), self.plugin)
loss = all_reduce_mean(loss, self.plugin)
mean_loss.append(loss.data)
mean_kl.append(kl)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = torch.zeros(
(old_action_log_probs.size(0), old_action_log_probs.size(1))
)
data["reference_action_log_probs"] = reference_action_log_probs
data_iter = iter([data])
def _criterion(outputs, inputs):
pass
outputs = self.booster.execute_pipeline(
data_iter,
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
loss = all_reduce_mean(loss, self.plugin)
mean_loss.append(loss.data)
else:
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
@ -256,7 +321,6 @@ class GRPOConsumer(BaseConsumer):
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
@ -282,64 +346,71 @@ class GRPOConsumer(BaseConsumer):
# Calculate accumulate value.
mean_kl.append(kl.data)
mean_loss.append(loss.data)
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
if self.rank == 0:
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"\nReward:",
self.accum_reward.item() / self.accum_count,
"\nFormat Reward:",
self.accum_format_reward.item() / self.accum_count,
"\nAcc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"\nKL:",
self.accum_kl.item() / self.accum_count,
"\nAdvantages:",
self.accum_advantages.item() / self.accum_count,
"\nResponse Length:",
self.accum_response_length.item() / self.accum_count,
)
self.wandb_run.log(
{
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
):
loss_scalar = self.accum_loss.item()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
):
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"\nReward:",
self.accum_reward.item() / self.accum_count,
"\nFormat Reward:",
self.accum_format_reward.item() / self.accum_count,
"\nAcc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"\nKL:",
self.accum_kl.item() / self.accum_count,
"\nAdvantages:",
self.accum_advantages.item() / self.accum_count,
"\nResponse Length:",
self.accum_response_length.item() / self.accum_count,
)
self.wandb_run.log(
{
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
self.accum_count = 0
return loss_scalar
def state_dict(self):
self.policy_model._force_wait_all_gather()

View File

@ -109,7 +109,14 @@ if __name__ == "__main__":
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
# plugin_config={}, # for zero
plugin_config={
"pp_size": 2,
"tp_size": 1,
"microbatch_size": args.train_microbatch_size // 2,
"zero_stage": 0,
"max_norm": 1.0,
}, # for pp
inference_backend=args.backend,
master_addr="localhost",
master_port=29505,