Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-22 11:13:13 +00:00)
optimize pp log_softmax OOM
parent 0e69b98c28 · commit 30a6859f77
.gitignore (vendored, 2 lines changed)
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
@@ -280,13 +280,21 @@ class GRPOConsumer(BaseConsumer):
                     )

                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
                         )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # The log_softmax activation is too large if the vocab size and sequence length are large.
+                            # E.g., with a 152064 vocab size, a 32K sequence length and a micro batch size of 4 (pp=4),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5 GB.
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                     else:
                         # Dummy reference logprobs for data iterator.
                         reference_action_log_probs = None
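The change above trades one large log_softmax over the whole micro-batch for a per-sample loop, so the peak activation scales with a single sequence rather than with the micro batch size. Below is a minimal, self-contained sketch of that idea; the helper name per_sample_action_log_probs and its signature are illustrative and are not the repository's calc_action_log_probs.

import torch
import torch.nn.functional as F


def per_sample_action_log_probs(
    logits: torch.Tensor,     # [B, S, V] raw logits (already divided by temperature)
    input_ids: torch.Tensor,  # [B, S] token ids
    num_actions: int,         # number of generated (action) tokens at the end of each sequence
) -> torch.Tensor:
    """Return [B, num_actions] log-probs of the generated tokens, computed one sample at a time."""
    batch_size = input_ids.size(0)
    out = torch.zeros(batch_size, num_actions, device=logits.device)
    for i in range(batch_size):
        # Peak temporary here is [S-1, V] instead of [B, S-1, V].
        row_log_probs = F.log_softmax(logits[i, :-1, :].float(), dim=-1)
        # Each position predicts the *next* token, so gather the shifted ids.
        targets = input_ids[i, 1:].unsqueeze(-1)                          # [S-1, 1]
        token_log_probs = row_log_probs.gather(-1, targets).squeeze(-1)   # [S-1]
        out[i] = token_log_probs[-num_actions:]
    return out

The same per-sample loop is applied inside _criterion in the next hunk, so both the reference and the policy log-probs avoid materializing the full-batch log_softmax output.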
@@ -308,12 +316,19 @@ class GRPOConsumer(BaseConsumer):

                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
                     )
+                    for i in range(action_log_probs.size(0)):
+                        # The log_softmax activation is too large if the vocab size and sequence length are large.
+                        # E.g., with a 152064 vocab size, a 32K sequence length and a micro batch size of 4 (pp=4),
+                        # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5 GB.
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
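For reference, a quick back-of-the-envelope check of the 72.5 GB figure quoted in the code comments, and of the per-sample peak after this change. It is a rough upper bound on the fp32 log_softmax output alone, ignoring the logits themselves and other buffers; the concrete numbers (vocab size, sequence length, micro batch size) are the ones assumed in the comments.

vocab_size = 152064       # e.g., a Qwen2.5-style vocabulary
seq_len = 32000           # ~32K tokens
micro_batch_size = 4      # per pipeline stage with pp=4
bytes_per_elem = 4        # fp32 log_softmax output

full_batch_gb = vocab_size * seq_len * micro_batch_size * bytes_per_elem / 1024**3
per_sample_gb = vocab_size * seq_len * 1 * bytes_per_elem / 1024**3

print(f"full micro-batch log_softmax activation: {full_batch_gb:.1f} GB")  # ~72.5 GB
print(f"per-sample log_softmax activation:       {per_sample_gb:.1f} GB")  # ~18.1 GB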