fix dist log prob test

YeAnbang 2025-08-15 10:11:54 +08:00
parent 99ba48fc40
commit 4152c0b30f


@ -1,6 +1,5 @@
import pytest
import torch
from coati.distributed.utils import log_probs_from_logits
import colossalai
from colossalai.logging import disable_existing_loggers
@ -12,6 +11,22 @@ CONFIG = dict(
)
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the log probabilities from logits for the given labels.

    Args:
        logits (torch.Tensor): The input logits.
        labels (torch.Tensor): The target labels.

    Returns:
        torch.Tensor: The log probabilities corresponding to the labels.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return per_label_logps.squeeze(-1)


def check_dist_log_prob(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
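
For context, a minimal single-process sketch of what the reference helper computes (the tensor shapes below are illustrative, not taken from the test): the per-label log probability is the negative per-token cross entropy, which the distributed log_probs_from_logits from coati.distributed.utils is presumably checked against in this test.

import torch
import torch.nn.functional as F

def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same reference helper as the one added in this commit:
    # log-softmax over the vocabulary dimension, then gather the label entries.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Illustrative shapes: batch of 2 sequences, 4 tokens each, vocabulary of 8.
logits = torch.randn(2, 4, 8)
labels = torch.randint(0, 8, (2, 4))

logps = log_probs_from_logits(logits, labels)  # shape (2, 4)

# Sanity check: the per-label log prob equals the negative per-token cross entropy.
expected = -F.cross_entropy(
    logits.reshape(-1, 8), labels.reshape(-1), reduction="none"
).reshape(2, 4)
torch.testing.assert_close(logps, expected)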