mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-17 15:36:53 +00:00
fix dist log prob test
This commit is contained in:
parent
99ba48fc40
commit
4152c0b30f
@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
import torch
|
||||
from coati.distributed.utils import log_probs_from_logits
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
@ -12,6 +11,22 @@ CONFIG = dict(
|
||||
)
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the log probabilities from logits for the given labels.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): The input logits.
|
||||
labels (torch.Tensor): The target labels.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The log probabilities corresponding to the labels.
|
||||
"""
|
||||
log_probs = torch.log_softmax(logits, dim=-1)
|
||||
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return per_label_logps.squeeze(-1)
|
||||
|
||||
|
||||
def check_dist_log_prob(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
||||
|
Loading…
Reference in New Issue
Block a user