diff --git a/colossalai/nn/metric/_utils.py b/colossalai/nn/metric/_utils.py index d4a69f943..eac591b64 100644 --- a/colossalai/nn/metric/_utils.py +++ b/colossalai/nn/metric/_utils.py @@ -1,5 +1,6 @@ import torch + def calc_acc(logits, targets): preds = torch.argmax(logits, dim=-1) correct = torch.sum(targets == preds)