optimize pp log_softmax OOM

YeAnbang 2025-06-13 18:21:54 +08:00
parent 0e69b98c28
commit 30a6859f77
2 changed files with 28 additions and 11 deletions

.gitignore vendored

@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png


@@ -280,13 +280,21 @@ class GRPOConsumer(BaseConsumer):
                     )
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
-                        )
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
+                        )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # The log_softmax activation is too large when the vocab size and sequence length are large:
+                            # e.g., with a 152064-token vocab, a 32K sequence length, and a micro batch size of 4 (e.g. for pp=4),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                     else:
                         # Dummy reference logprobs for the data iterator.
                         reference_action_log_probs = None
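
The change here is simple: instead of one calc_action_log_probs call over the whole micro batch, the log-probs buffer is preallocated and filled one sample at a time, so the log_softmax temporary shrinks from (B, S, V) to (1, S, V). (The second hunk below applies the same pattern inside _criterion.) A minimal plain-PyTorch sketch of the pattern, assuming no tensor sharding; per_sample_action_log_probs is a hypothetical stand-in for the repo's calc_action_log_probs, which additionally takes a shard_config:

import torch
import torch.nn.functional as F

def per_sample_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_action: int) -> torch.Tensor:
    """Log-probs of the last `num_action` tokens, computed one sample at a time.

    log_softmax over the full (B, S, V) logits materializes a temporary of the
    same size; looping over the batch dimension caps it at (1, S, V).
    """
    out = torch.zeros(input_ids.size(0), num_action, device=logits.device)
    for i in range(input_ids.size(0)):
        log_probs = F.log_softmax(logits[i : i + 1], dim=-1)                # (1, S, V) temporary
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)                     # logits at t predict token t+1
        token_log_probs = log_probs[:, :-1].gather(-1, labels).squeeze(-1)  # (1, S-1)
        out[i] = token_log_probs[0, -num_action:]
    return out

# Shape check: a (4, 128)-token micro batch, 1000-token vocab, 16 action tokens.
logits = torch.randn(4, 128, 1000)
input_ids = torch.randint(0, 1000, (4, 128))
print(per_sample_action_log_probs(logits, input_ids, num_action=16).shape)  # torch.Size([4, 16])
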
@@ -308,12 +316,19 @@ class GRPOConsumer(BaseConsumer):
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
-                    )
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
+                    )
+                    for i in range(action_log_probs.size(0)):
+                        # The log_softmax activation is too large when the vocab size and sequence length are large:
+                        # e.g., with a 152064-token vocab, a 32K sequence length, and a micro batch size of 4 (e.g. for pp=4),
+                        # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
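
The figure quoted in both comments checks out; a quick sanity check of the arithmetic, using the fp32 sizes from the comment:

# Peak log_softmax activation in fp32, with the sizes quoted in the comment.
vocab, seq_len, micro_bs, bytes_fp32 = 152064, 32000, 4, 4
full_batch = vocab * seq_len * micro_bs * bytes_fp32 / 1024**3  # (B, S, V) at once
per_sample = full_batch / micro_bs                              # (1, S, V) in the loop
print(f"full batch: {full_batch:.1f} GB, per-sample loop: {per_sample:.1f} GB")
# full batch: 72.5 GB, per-sample loop: 18.1 GB

The loop trades one large kernel for B smaller launches, but it caps the peak log_softmax activation at 1/B of the full-batch figure, which is presumably what matters on the last pipeline stage where this OOM occurs.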