ColossalAI/colossalai/legacy/nn/metric/__init__.py
Hongxin Liu 554aa9592e [legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
2023-09-11 16:24:28 +08:00

29 lines · 687 B · Python

from torch import nn

from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode

from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D

# Map each tensor parallel mode to its parallel-aware accuracy implementation.
_parallel_accuracy = {
    '2d': Accuracy2D,
    '2.5d': Accuracy2p5D,
    '3d': Accuracy3D,
}


class Accuracy(nn.Module):

    def __init__(self):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel not in _parallel_accuracy:
            # No (or unsupported) tensor parallel mode: fall back to the plain helper.
            self.acc = calc_acc
        else:
            # Instantiate the accuracy module matching the configured parallel mode.
            self.acc = _parallel_accuracy[tensor_parallel]()

    def forward(self, *args):
        # Delegate to whichever accuracy callable was selected at construction time.
        return self.acc(*args)
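
A minimal usage sketch, not part of the file above: it assumes a ColossalAI legacy context has already been initialized so that get_tensor_parallel_mode() returns a meaningful value, and it assumes the selected accuracy callable accepts (logits, targets) and returns a count of correct predictions; both the argument order and the return semantics are illustrative assumptions.

# Illustrative usage only; argument order and return value are assumed.
import torch

from colossalai.legacy.nn.metric import Accuracy

metric = Accuracy()                    # resolves to Accuracy2D/2p5D/3D or calc_acc
logits = torch.randn(8, 10)            # e.g. a batch of 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))   # ground-truth class indices
correct = metric(logits, targets)      # forwards *args to the chosen callable
accuracy = correct / targets.numel()   # fraction correct, assuming correct is a count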