Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-17 23:46:52 +00:00
fix dist log prob test
This commit is contained in:
parent
99ba48fc40
commit
4152c0b30f
@@ -1,6 +1,5 @@
 import pytest
 import torch
-from coati.distributed.utils import log_probs_from_logits

 import colossalai
 from colossalai.logging import disable_existing_loggers
@@ -12,6 +11,22 @@ CONFIG = dict(
 )


+def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the log probabilities from logits for the given labels.
+
+    Args:
+        logits (torch.Tensor): The input logits.
+        labels (torch.Tensor): The target labels.
+
+    Returns:
+        torch.Tensor: The log probabilities corresponding to the labels.
+    """
+    log_probs = torch.log_softmax(logits, dim=-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)
+
+
 def check_dist_log_prob(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
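For reference, the helper inlined by this commit returns the log-probability of each label token under the logits. A minimal sketch (not part of the commit) that cross-checks it against torch.nn.functional.cross_entropy, which computes the negative of the same per-label log probability:

# Sanity check for the inlined helper, using only the shapes assumed here:
# logits (batch, seq_len, vocab) and labels (batch, seq_len).
import torch
import torch.nn.functional as F


def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same body as the function added in the diff above.
    log_probs = torch.log_softmax(logits, dim=-1)
    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return per_label_logps.squeeze(-1)


logits = torch.randn(2, 4, 8)          # (batch, seq_len, vocab)
labels = torch.randint(0, 8, (2, 4))   # (batch, seq_len)

out = log_probs_from_logits(logits, labels)
# cross_entropy returns -log p(label); negate and reshape to compare.
ref = -F.cross_entropy(logits.reshape(-1, 8), labels.reshape(-1), reduction="none").reshape(2, 4)
assert torch.allclose(out, ref, atol=1e-6)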