mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-30 23:56:42 +00:00
support PP training
This commit is contained in:
@@ -54,7 +54,6 @@ class BaseConsumer:
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
|
||||
|
||||
self.device = get_current_device()
|
||||
self.lr_scheduler = None
|
||||
|
@@ -96,7 +96,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.global_step = 0
|
||||
if use_wandb and self.rank == 0:
|
||||
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
||||
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
|
||||
self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
|
||||
|
||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=self.optimizer,
|
||||
@@ -168,72 +168,120 @@ class GRPOConsumer(BaseConsumer):
|
||||
).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
mean_kl, mean_loss = [], []
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
if self.plugin.pp_size > 1:
|
||||
# Support training with PP.
|
||||
data_iter = iter([data])
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
reference_model_outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.reference_model,
|
||||
criterion=lambda outputs, inputs: outputs.logits.mean(), # dummy criterion
|
||||
optimizer=None,
|
||||
return_loss=False,
|
||||
return_outputs=True,
|
||||
)
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_model_logits = reference_model_outputs["outputs"]["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
data["input_ids"],
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
else:
|
||||
# Dummy reference logprobs for data iterator.
|
||||
reference_action_log_probs = torch.zeros(
|
||||
(old_action_log_probs.size(0), old_action_log_probs.size(1))
|
||||
)
|
||||
|
||||
data["reference_action_log_probs"] = reference_action_log_probs
|
||||
|
||||
data_iter = iter([data])
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
pass
|
||||
|
||||
outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.policy_model,
|
||||
criterion=_criterion,
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
mean_loss.append(loss.data)
|
||||
else:
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
|
||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||
|
@@ -31,7 +31,13 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
|
||||
)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
|
||||
parser.add_argument(
|
||||
"-tmbs",
|
||||
"--train-microbatch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of samples per device. PP micro batchsize when PP is activated.",
|
||||
)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
args = parser.parse_args()
|
||||
@@ -45,11 +51,7 @@ if __name__ == "__main__":
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(
|
||||
path=args.model,
|
||||
# use_flash_attention_2=True,
|
||||
# use_cache=False
|
||||
)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
||||
|
||||
if args.backend == "transformers":
|
||||
@@ -107,7 +109,7 @@ if __name__ == "__main__":
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
plugin_config={},
|
||||
plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29505,
|
||||
|
@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
)
|
||||
|
||||
# run with gradients accumulation
|
||||
if model.require_grad_sync == False or (
|
||||
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
||||
if (
|
||||
not torch.is_grad_enabled()
|
||||
or model.require_grad_sync == False
|
||||
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
|
||||
):
|
||||
return outputs
|
||||
|
||||
|
@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user