mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[test] add mixtral modelling test
This commit is contained in:
@@ -2,7 +2,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
|
||||
@@ -146,6 +145,10 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
|
||||
elif dtype is torch.bfloat16:
|
||||
rtol = 4e-3
|
||||
atol = 4e-3
|
||||
else:
|
||||
assert dtype is torch.float32
|
||||
rtol = 1e-5
|
||||
atol = 1e-5
|
||||
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype).to(a.device)
|
||||
|
Reference in New Issue
Block a user