mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-23 11:44:15 +00:00

implement memory efficient logprob

This commit is contained in:
parent 30a6859f77
commit e3d56cbd86
@@ -6,7 +6,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -280,21 +280,12 @@ class GRPOConsumer(BaseConsumer):
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_action_log_probs = torch.zeros(
-                                (input_ids_forward_micro_batch.size(0), num_action),
-                                device=input_ids_forward_micro_batch.device,
-                            )
-                            for i in range(reference_action_log_probs.size(0)):
-                                # activation for log_softmax is too large if vocab size and sequence length are large
-                                # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                                # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                                reference_action_log_probs[i, :] += calc_action_log_probs(
-                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
-                                    / self.generate_config["temperature"],
-                                    input_ids_forward_micro_batch[i : i + 1],
-                                    num_action,
-                                    self.plugin.shard_config,
-                                )[0]
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
+                                input_ids_forward_micro_batch,
+                                num_action,
+                                shard_config=self.plugin.shard_config,
+                            )
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
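The removed comments above quote a 72.5 GB activation for a full log_softmax over the logits. That figure can be reproduced from the numbers in the comment; the sketch below is illustrative only and assumes float32 logits (4 bytes per element):

    # back-of-the-envelope reproduction of the figure in the removed comment
    vocab_size = 152064      # vocabulary size quoted in the comment
    seq_len = 32000          # ~32K sequence length
    micro_batch = 4          # micro batch size (pp=4 example)
    bytes_per_elem = 4       # assumes float32 log_softmax output
    activation_gib = vocab_size * seq_len * micro_batch * bytes_per_elem / 1024**3
    print(f"{activation_gib:.1f} GiB")  # ~72.5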
@@ -316,19 +307,12 @@ class GRPOConsumer(BaseConsumer):
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = torch.zeros(
-                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
+                            inputs["input_ids"],
+                            num_action,
+                            shard_config=self.plugin.shard_config,
                         )
-                        for i in range(action_log_probs.size(0)):
-                            # activation for log_softmax is too large if vocab size and sequence length are large
-                            # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                            # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                            action_log_probs[i, :] += calc_action_log_probs(
-                                action_logits[i : i + 1] / self.generate_config["temperature"],
-                                inputs["input_ids"][i : i + 1],
-                                num_action,
-                                self.plugin.shard_config,
-                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -370,16 +354,15 @@ class GRPOConsumer(BaseConsumer):
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
             else:
-
                 policy_model_logits = self.policy_model(
                     input_ids=input_ids_forward_micro_batch,
                     attention_mask=attention_mask_forward_micro_batch,
                 ).logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                     policy_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
 
                 if self.policy_loss_fn.beta > 0:
@@ -388,7 +371,7 @@ class GRPOConsumer(BaseConsumer):
                                 input_ids=input_ids_forward_micro_batch,
                                 attention_mask=attention_mask_forward_micro_batch,
                             ).logits
-                            reference_action_log_probs = calc_action_log_probs(
+                            reference_action_log_probs = memory_efficient_logprob(
                                 reference_model_logits / self.generate_config["temperature"],
                                 input_ids_forward_micro_batch,
                                 num_action,
@@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return per_label_logps.squeeze(-1)
 
 
-def calc_action_log_probs(
+def memory_efficient_logprob(
     logits: torch.Tensor,
-    sequences: torch.LongTensor,
-    num_actions: int,
-    shard_config,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
     vocab_size: int = None,
 ) -> torch.Tensor:
-    """Calculate action log probs.
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
 
     Args:
         logits (torch.Tensor): Output tensor of Actor.forward.logits.
-        sequences (torch.LongTensor): Input sequences.
-        num_actions (int): Number of actions.
-        shard_config
-        vocab_size
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
 
     Returns:
         torch.Tensor: Action log probs.
     """
-    # labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
-    # logits: torch.Tensor,  # [B, S, Vocab_size]
-    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
-    log_probs = log_probs.squeeze(-1)
-    return log_probs[:, -num_actions:]
+    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+    context_length = logits.size(1) - num_action
+    for i in range(action_log_probs.size(0)):
+        # loop over each sample in the micro-batch
+        for start in range(context_length, logits.size(1), chunk_size):
+            end = min(start + chunk_size, logits.size(1))
+            # calculate log probs in chunks to save memory
+            log_probs = dist_log_prob(
+                inputs[i : i + 1, start - 1 : end],
+                logits[i : i + 1, start - 1 : end],
+                shard_config,
+                vocab_size,
+                logits.dtype,
+            )  # [1, chunk_size, 1]
+            log_probs = log_probs.squeeze(-1)
+            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+    return action_log_probs
 
 
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
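For readers skimming the new utils hunk, the following is a minimal, self-contained sketch of the chunking idea on a single device. It mirrors the loop structure of memory_efficient_logprob but replaces dist_log_prob with a plain log_softmax gather and makes the one-token shift between logits and labels explicit; the helper name chunked_action_logprob and these simplifications are illustrative assumptions, not part of the commit:

    import torch
    import torch.nn.functional as F


    def chunked_action_logprob(
        logits: torch.Tensor,      # [B, S, V]
        input_ids: torch.Tensor,   # [B, S]
        num_action: int,
        chunk_size: int = 2048,
    ) -> torch.Tensor:
        """Single-device analogue of the chunked log-prob loop; returns [B, num_action]."""
        batch_size, seq_len, _ = logits.shape
        context_length = seq_len - num_action
        out = torch.zeros((batch_size, num_action), device=logits.device, dtype=logits.dtype)
        for i in range(batch_size):
            # only a [chunk_size, V] slice of the log_softmax activation is alive at a time
            for start in range(context_length, seq_len, chunk_size):
                end = min(start + chunk_size, seq_len)
                # logits at position t-1 score the token observed at position t
                chunk_logits = logits[i, start - 1 : end - 1]    # [chunk, V]
                chunk_labels = input_ids[i, start:end]           # [chunk]
                chunk_logps = F.log_softmax(chunk_logits.float(), dim=-1)
                token_logps = chunk_logps.gather(-1, chunk_labels.unsqueeze(-1)).squeeze(-1)
                out[i, start - context_length : end - context_length] = token_logps.to(out.dtype)
        return out


    # toy usage: batch of 2, sequence length 16, vocab 32, last 8 tokens are "actions"
    logits = torch.randn(2, 16, 32)
    input_ids = torch.randint(0, 32, (2, 16))
    print(chunked_action_logprob(logits, input_ids, num_action=8, chunk_size=4).shape)  # torch.Size([2, 8])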