[test] add mixtral modelling test

This commit is contained in:
botbw
2024-07-15 06:43:27 +00:00
committed by Hongxin Liu
parent 102b784a10
commit 0b5bbe9ce4
2 changed files with 144 additions and 1 deletions

View File

@@ -2,7 +2,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
@@ -146,6 +145,10 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
else:
assert dtype is torch.float32
rtol = 1e-5
atol = 1e-5
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)