From 4152c0b30f04a5be0a2ff6b7445f82de05f9599d Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 15 Aug 2025 10:11:54 +0800
Subject: [PATCH] fix dist log prob test

---
 .../test_layer/test_dist_log_prob.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py
index 05a6a5d47..f863ee555 100644
--- a/tests/test_shardformer/test_layer/test_dist_log_prob.py
+++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-from coati.distributed.utils import log_probs_from_logits
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
@@ -12,6 +11,22 @@ CONFIG = dict(
 )
 
 
+def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the log probabilities from logits for the given labels.
+
+    Args:
+        logits (torch.Tensor): The input logits.
+        labels (torch.Tensor): The target labels.
+
+    Returns:
+        torch.Tensor: The log probabilities corresponding to the labels.
+    """
+    log_probs = torch.log_softmax(logits, dim=-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)
+
+
 def check_dist_log_prob(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
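
Below the patch is a minimal, standalone sketch (not part of the patch itself) showing how the re-added log_probs_from_logits helper can be sanity-checked against torch.nn.functional.cross_entropy; the tensor shapes and tolerance are illustrative assumptions.

import torch
import torch.nn.functional as F


def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same computation as the helper added in the patch above.
    log_probs = torch.log_softmax(logits, dim=-1)
    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return per_label_logps.squeeze(-1)


# Illustrative shapes (batch=2, seq_len=5, vocab=32); any shape whose trailing
# dimension is the vocabulary works the same way.
logits = torch.randn(2, 5, 32)
labels = torch.randint(0, 32, (2, 5))

logps = log_probs_from_logits(logits, labels)

# The gathered log-probabilities should equal the negative per-token
# cross-entropy loss computed from the same logits and labels.
reference = -F.cross_entropy(
    logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
).reshape(labels.shape)

assert torch.allclose(logps, reference, atol=1e-5)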