diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 040ca61a4..5e8f329eb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -6,7 +6,7 @@
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -280,21 +280,12 @@ class GRPOConsumer(BaseConsumer):
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_action_log_probs = torch.zeros(
-                                (input_ids_forward_micro_batch.size(0), num_action),
-                                device=input_ids_forward_micro_batch.device,
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
+                                input_ids_forward_micro_batch,
+                                num_action,
+                                shard_config=self.plugin.shard_config,
                             )
-                            for i in range(reference_action_log_probs.size(0)):
-                                # activation for log_softmax is too large if vocab size and sequence length are large
-                                # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                                # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                                reference_action_log_probs[i, :] += calc_action_log_probs(
-                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
-                                    / self.generate_config["temperature"],
-                                    input_ids_forward_micro_batch[i : i + 1],
-                                    num_action,
-                                    self.plugin.shard_config,
-                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
@@ -316,19 +307,12 @@ class GRPOConsumer(BaseConsumer):
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = torch.zeros(
-                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
+                            inputs["input_ids"],
+                            num_action,
+                            shard_config=self.plugin.shard_config,
                         )
-                        for i in range(action_log_probs.size(0)):
-                            # activation for log_softmax is too large if vocab size and sequence length are large
-                            # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                            # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                            action_log_probs[i, :] += calc_action_log_probs(
-                                action_logits[i : i + 1] / self.generate_config["temperature"],
-                                inputs["input_ids"][i : i + 1],
-                                num_action,
-                                self.plugin.shard_config,
-                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -370,16 +354,15 @@ class GRPOConsumer(BaseConsumer):
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                 else:
-
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                         policy_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
 
                     if self.policy_loss_fn.beta > 0:
@@ -388,7 +371,7 @@ class GRPOConsumer(BaseConsumer):
                                 input_ids=input_ids_forward_micro_batch,
                                 attention_mask=attention_mask_forward_micro_batch,
                             ).logits
-                        reference_action_log_probs = calc_action_log_probs(
+                        reference_action_log_probs = memory_efficient_logprob(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index a40ebbcfb..d46243114 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return per_label_logps.squeeze(-1)
 
 
-def calc_action_log_probs(
+def memory_efficient_logprob(
     logits: torch.Tensor,
-    sequences: torch.LongTensor,
-    num_actions: int,
-    shard_config,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
     vocab_size: int = None,
 ) -> torch.Tensor:
-    """Calculate action log probs.
-
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
     Args:
         logits (torch.Tensor): Output tensor of Actor.forward.logits.
-        sequences (torch.LongTensor): Input sequences.
-        num_actions (int): Number of actions.
-        shard_config
-        vocab_size
-
-
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
     Returns:
         torch.Tensor: Action log probs.
     """
-    # labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
-    # logits: torch.Tensor,  # [B, S, Vocab_size]
-    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
-    log_probs = log_probs.squeeze(-1)
-    return log_probs[:, -num_actions:]
+    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+    context_length = logits.size(1) - num_action
+    for i in range(action_log_probs.size(0)):
+        # loop over each sample in the micro-batch
+        for start in range(context_length, logits.size(1), chunk_size):
+            end = min(start + chunk_size, logits.size(1))
+            # calculate log probs in chunks to save memory
+            log_probs = dist_log_prob(
+                inputs[i : i + 1, start - 1 : end],
+                logits[i : i + 1, start - 1 : end],
+                shard_config,
+                vocab_size,
+                logits.dtype,
+            )  # [1, chunk_size, 1]
+            log_probs = log_probs.squeeze(-1)
+            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+    return action_log_probs
 
 
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
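The removed inline comments quantify the problem this patch targets: materializing the full log_softmax activation for a 152064-token vocabulary over a 32K-token sequence at micro-batch size 4 costs roughly 152064 * 32000 * 4 * 4 / 1024^3 ≈ 72.5 GB, whereas chunking the sequence bounds the peak per call at roughly chunk_size * vocab_size floats (about 1.2 GB per sample at the default chunk_size of 2048). The sketch below is a minimal, standalone illustration of the same chunking idea using plain PyTorch (torch.log_softmax + gather) instead of the repo's dist_log_prob, so it ignores tensor-parallel vocab sharding and temperature scaling; chunked_action_logprob is an illustrative name, not part of the coati API.

```python
import torch


def chunked_action_logprob(
    logits: torch.Tensor,     # [B, S, V] logits (temperature scaling assumed done by the caller)
    input_ids: torch.Tensor,  # [B, S] token ids; the last `num_action` positions are generated tokens
    num_action: int,
    chunk_size: int = 2048,
) -> torch.Tensor:
    """Per-token log probs of the last `num_action` tokens, computed chunk by chunk.

    Unsharded stand-in for memory_efficient_logprob: peak activation is
    O(chunk_size * vocab_size) instead of O(seq_len * vocab_size).
    """
    batch_size, seq_len, _ = logits.shape
    context_length = seq_len - num_action  # number of prompt tokens (assumed >= 1)
    out = torch.zeros((batch_size, num_action), device=logits.device, dtype=logits.dtype)
    for i in range(batch_size):
        for start in range(context_length, seq_len, chunk_size):
            end = min(start + chunk_size, seq_len)
            # The logit at position t predicts the token at position t + 1,
            # so logits[start - 1 : end - 1] line up with labels input_ids[start : end].
            chunk_logits = logits[i, start - 1 : end - 1]  # [chunk, V]
            chunk_labels = input_ids[i, start:end]         # [chunk]
            log_probs = torch.log_softmax(chunk_logits.float(), dim=-1)
            token_logps = log_probs.gather(-1, chunk_labels.unsqueeze(-1)).squeeze(-1)
            out[i, start - context_length : end - context_length] = token_logps.to(out.dtype)
    return out
```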
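Since log_softmax is computed row-wise over the vocabulary dimension, chunking over the sequence dimension must not change the result. A quick way to sanity-check the sketch above is to compare it against a single full-sequence computation on shapes small enough to afford the dense activation; the shapes below are arbitrary.

```python
# Equivalence check on small, arbitrary shapes.
torch.manual_seed(0)
B, S, V, num_action = 2, 128, 512, 48
logits = torch.randn(B, S, V)
input_ids = torch.randint(0, V, (B, S))

# Reference: full-sequence log_softmax, then keep the last `num_action` labels.
full = torch.log_softmax(logits[:, :-1].float(), dim=-1)
full = full.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -num_action:]

chunked = chunked_action_logprob(logits, input_ids, num_action, chunk_size=32)
assert torch.allclose(full, chunked, atol=1e-5)
```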