ColossalAI/colossalai/nn/metric/__init__.py
アマデウス 9ee197d0e9 moved env variables to global variables; (#215)
added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed few collective communicator bugs
2022-02-15 11:31:13 +08:00

27 lines
704 B
Python

from torch import nn
from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D
from colossalai.nn.layer.utils import get_tensor_parallel_mode
# Dispatch table: tensor-parallel mode name -> accuracy metric class built for
# that mode. Modes missing from this table are handled by the caller's fallback.
_parallel_accuracy = {
    '2d': Accuracy2D,
    '2.5d': Accuracy2p5D,
    '3d': Accuracy3D,
}
class Accuracy(nn.Module):
    """Accuracy metric that dispatches on the active tensor-parallel mode.

    At construction time the mode reported by ``get_tensor_parallel_mode()``
    is looked up in ``_parallel_accuracy``. When an entry exists, that class
    is instantiated and used as the metric; otherwise the plain ``calc_acc``
    function is used.
    """

    def __init__(self):
        super().__init__()
        metric_cls = _parallel_accuracy.get(get_tensor_parallel_mode())
        # No mode-specific metric registered -> fall back to the plain function.
        self.acc = calc_acc if metric_cls is None else metric_cls()

    def forward(self, *args):
        """Pass all positional arguments to the selected metric and return its result.

        NOTE(review): presumably called with (predictions, targets) — confirm
        against ``calc_acc`` and the parallel accuracy classes.
        """
        return self.acc(*args)