diff --git a/.gitignore b/.gitignore
index e603f5015..94d952295 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 51cdcf322..040ca61a4 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -280,13 +280,21 @@ class GRPOConsumer(BaseConsumer):
                         )
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
                         )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # the activation for log_softmax is too large when the vocab size and sequence length are large
+                            # e.g., with a 152064 vocab size, a 32K sequence length, and a micro batch size of 4 (e.g., for pp=4),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                     else:
                         # Dummy reference logprobs for data iterator.
                         reference_action_log_probs = None
 
@@ -308,12 +316,19 @@ class GRPOConsumer(BaseConsumer):
 
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
                     )
+                    for i in range(action_log_probs.size(0)):
+                        # the activation for log_softmax is too large when the vocab size and sequence length are large
+                        # e.g., with a 152064 vocab size, a 32K sequence length, and a micro batch size of 4 (e.g., for pp=4),
+                        # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
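
For reviewers, a minimal standalone sketch of the pattern both hunks apply: computing per-token action log-probs one sample at a time, so the full-vocab log_softmax activation is only ever materialized for a single sequence rather than the whole micro batch. The helper name `per_sample_action_log_probs` and the shape conventions below are illustrative assumptions; the actual code path goes through coati's `calc_action_log_probs`, which additionally handles the plugin's shard config.

```python
# Illustrative sketch only -- not the coati calc_action_log_probs helper.
import torch
import torch.nn.functional as F


def per_sample_action_log_probs(
    logits: torch.Tensor,     # (batch, seq_len, vocab) raw model logits (assumed layout)
    input_ids: torch.Tensor,  # (batch, seq_len) token ids
    num_action: int,          # number of trailing (generated) tokens to score
    temperature: float = 1.0,
) -> torch.Tensor:
    batch_size = input_ids.size(0)
    out = torch.zeros((batch_size, num_action), device=logits.device)
    for i in range(batch_size):
        # Slice one sample: the log_softmax peak memory is now
        # seq_len * vocab instead of batch * seq_len * vocab.
        sample_logits = logits[i : i + 1] / temperature
        # Logits at position t predict the token at position t + 1,
        # so score the last `num_action` tokens of the sequence.
        log_probs = F.log_softmax(sample_logits[:, -num_action - 1 : -1], dim=-1)
        actions = input_ids[i : i + 1, -num_action:]
        out[i, :] = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)[0]
    return out
```

The loop trades a few extra kernel launches for an O(batch_size) reduction in peak activation memory; if a single sequence's `seq_len * vocab` log_softmax were still too large, the same idea could be extended by chunking over the sequence dimension as well.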