From 30a6859f7792a9026d388224f8fa50e4d6e711ac Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 13 Jun 2025 18:21:54 +0800
Subject: [PATCH 1/3] optimize pp log_softmax OOM

---
 .gitignore                                     |  2 +
 .../coati/distributed/grpo_consumer.py         | 37 +++++++++++++------
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index e603f5015..94d952295 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 51cdcf322..040ca61a4 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -280,13 +280,21 @@ class GRPOConsumer(BaseConsumer):
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
-                                input_ids_forward_micro_batch,
-                                num_action,
-                                self.plugin.shard_config,
+                            reference_action_log_probs = torch.zeros(
+                                (input_ids_forward_micro_batch.size(0), num_action),
+                                device=input_ids_forward_micro_batch.device,
                             )
+                            for i in range(reference_action_log_probs.size(0)):
+                                # activation for log_softmax is too large if vocab size and sequence length are large
+                                # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                                # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                                reference_action_log_probs[i, :] += calc_action_log_probs(
+                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                    / self.generate_config["temperature"],
+                                    input_ids_forward_micro_batch[i : i + 1],
+                                    num_action,
+                                    self.plugin.shard_config,
+                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
@@ -308,12 +316,19 @@ class GRPOConsumer(BaseConsumer):
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
-                            inputs["input_ids"],
-                            num_action,
-                            self.plugin.shard_config,
+                        action_log_probs = torch.zeros(
+                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
                         )
+                        for i in range(action_log_probs.size(0)):
+                            # activation for log_softmax is too large if vocab size and sequence length are large
+                            # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                            action_log_probs[i, :] += calc_action_log_probs(
+                                action_logits[i : i + 1] / self.generate_config["temperature"],
+                                inputs["input_ids"][i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)

From e3d56cbd860da04d1b81a7917fac10d8aaedaa9a Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 18 Jun 2025 10:24:48 +0000
Subject: [PATCH 2/3] implement memory efficient logprob

---
 .../coati/distributed/grpo_consumer.py         | 45 ++++++------
 .../ColossalChat/coati/distributed/utils.py    | 46 ++++++++++++-------
 2 files changed, 43 insertions(+), 48 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 040ca61a4..5e8f329eb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -6,7 +6,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -280,21 +280,12 @@ class GRPOConsumer(BaseConsumer):
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_action_log_probs = torch.zeros(
-                                (input_ids_forward_micro_batch.size(0), num_action),
-                                device=input_ids_forward_micro_batch.device,
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
+                                input_ids_forward_micro_batch,
+                                num_action,
+                                shard_config=self.plugin.shard_config,
                             )
-                            for i in range(reference_action_log_probs.size(0)):
-                                # activation for log_softmax is too large if vocab size and sequence length are large
-                                # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
-                                # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                                reference_action_log_probs[i, :] += calc_action_log_probs(
-                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
-                                    / self.generate_config["temperature"],
-                                    input_ids_forward_micro_batch[i : i + 1],
-                                    num_action,
-                                    self.plugin.shard_config,
-                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
@@ -316,19 +307,12 @@ class GRPOConsumer(BaseConsumer):
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = torch.zeros(
-                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
+                            inputs["input_ids"],
+                            num_action,
+                            shard_config=self.plugin.shard_config,
                         )
-                        for i in range(action_log_probs.size(0)):
-                            # activation for log_softmax is too large if vocab size and sequence length are large
-                            # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
-                            # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                            action_log_probs[i, :] += calc_action_log_probs(
-                                action_logits[i : i + 1] / self.generate_config["temperature"],
-                                inputs["input_ids"][i : i + 1],
-                                num_action,
-                                self.plugin.shard_config,
-                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -370,16 +354,15 @@ class GRPOConsumer(BaseConsumer):
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                 else:
-
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                         policy_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
 
                     if self.policy_loss_fn.beta > 0:
@@ -388,7 +371,7 @@ class GRPOConsumer(BaseConsumer):
                                 input_ids=input_ids_forward_micro_batch,
                                 attention_mask=attention_mask_forward_micro_batch,
                             ).logits
-                        reference_action_log_probs = calc_action_log_probs(
+                        reference_action_log_probs = memory_efficient_logprob(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index a40ebbcfb..d46243114 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return per_label_logps.squeeze(-1)
 
 
-def calc_action_log_probs(
+def memory_efficient_logprob(
     logits: torch.Tensor,
-    sequences: torch.LongTensor,
-    num_actions: int,
-    shard_config,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
     vocab_size: int = None,
 ) -> torch.Tensor:
-    """Calculate action log probs.
-
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
     Args:
         logits (torch.Tensor): Output tensor of Actor.forward.logits.
-        sequences (torch.LongTensor): Input sequences.
-        num_actions (int): Number of actions.
-        shard_config
-        vocab_size
-
-
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
     Returns:
         torch.Tensor: Action log probs.
""" - # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] - # logits: torch.Tensor, # [B, S, Vocab_size] - log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - log_probs = log_probs.squeeze(-1) - return log_probs[:, -num_actions:] + action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype) + context_length = logits.size(1) - num_action + for i in range(action_log_probs.size(0)): + # loop over each sample in the micro-batch + for start in range(context_length, logits.size(1), chunk_size): + end = min(start + chunk_size, logits.size(1)) + # calculate log probs in chunks to save memory + log_probs = dist_log_prob( + inputs[i : i + 1, start - 1 : end], + logits[i : i + 1, start - 1 : end], + shard_config, + vocab_size, + logits.dtype, + ) # [1, chunk_size, 1] + log_probs = log_probs.squeeze(-1) + action_log_probs[i, start - context_length : end - context_length] += log_probs[0] + return action_log_probs def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: From 6b06430ca418715240c10a28baf14ff3e4b35ca0 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 19 Jun 2025 01:37:52 +0000 Subject: [PATCH 3/3] fix small bug --- .../ColossalChat/coati/distributed/grpo_consumer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5e8f329eb..d7d6221a1 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer): if self.booster.plugin.stage_manager.is_last_stage(): reference_action_log_probs = memory_efficient_logprob( - reference_model_outputs["outputs"]["logits"], + reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, shard_config=self.plugin.shard_config, @@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer): def _criterion(outputs, inputs): action_logits = outputs.logits action_log_probs = memory_efficient_logprob( - action_logits, + action_logits / self.generate_config["temperature"], inputs["input_ids"], num_action, shard_config=self.plugin.shard_config, @@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer): reference_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs)