mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-06 03:58:00 +00:00
* [elixir] add elixir * [elixir] add unit tests * remove useless code * fix python 3.8 issue * fix typo * add test skip * add docstrings * add docstrings * add readme * fix typo
42 lines
982 B
Python
42 lines
982 B
Python
import torch
|
|
from torch.testing import assert_close
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from . import gpt, mlp, opt, resnet, small
|
|
from .registry import TEST_MODELS
|
|
|
|
|
|
def to_cuda(input_dict):
    """Move every tensor found in ``input_dict`` to CUDA.

    The input is walked with ``tree_map``; each ``torch.Tensor`` leaf is
    replaced by the result of ``.cuda()`` and every other leaf is passed
    through unchanged.

    Args:
        input_dict: the (possibly nested) container to process.

    Returns:
        The structure returned by ``tree_map`` after the per-leaf conversion.
    """

    def _move_leaf(leaf):
        # Only tensors can be moved; anything else is returned untouched.
        return leaf.cuda() if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, input_dict)
|
|
|
|
|
|
def allclose(ta, tb, **kwargs):
    """Check two values with ``torch.testing.assert_close`` and report success.

    Args:
        ta: first value handed to ``assert_close``.
        tb: second value handed to ``assert_close``.
        **kwargs: forwarded verbatim to ``assert_close``.

    Returns:
        ``True`` whenever ``assert_close`` returns without raising.

    Raises:
        Whatever ``assert_close`` raises when the comparison fails.
    """
    # assert_close signals a mismatch by raising, so reaching the return
    # statement means the comparison succeeded.
    assert_close(ta, tb, **kwargs)
    return True
|
|
|
|
|
|
def assert_dict_keys(test_dict, keys):
    """Assert that ``test_dict`` has as many entries as ``keys`` and contains each of them.

    Args:
        test_dict: the mapping under test.
        keys: a sized iterable of the keys expected in ``test_dict``.

    Raises:
        AssertionError: if the sizes differ or any element of ``keys`` is
            missing from ``test_dict``.
    """
    assert len(test_dict) == len(keys)
    for expected_key in keys:
        assert expected_key in test_dict
|
|
|
|
|
|
def _mismatch_message(key, u, v):
    """Build the assertion text for a pair of tensors rejected by the comparison."""
    if u.shape != v.shape:
        # Subtracting tensors of different shapes can raise RuntimeError,
        # which would replace the assertion failure with an unrelated error.
        return f'key {key!r}: shape mismatch {tuple(u.shape)} vs {tuple(v.shape)}'
    return f'max diff {torch.max(torch.abs(u.data - v.data))}'


def assert_dict_values(da, db, fn):
    """Assert that the tensor values of ``da`` match those stored in ``db``.

    Both dicts must have the same number of entries and every key of ``da``
    must be present in ``db``. Values of ``da`` that are not tensors are
    skipped. For each tensor value, the counterpart from ``db`` is compared
    with ``fn`` after moving the ``da`` tensor to the ``db`` tensor's device.

    Args:
        da: dict whose tensor values are checked.
        db: dict providing the reference tensors.
        fn: callable ``fn(u, v)`` taking the ``db`` tensor and the ``da``
            tensor (both via ``.data``); a falsy result fails the assertion.

    Raises:
        AssertionError: if the sizes differ, a key is missing from ``db``,
            a tensor in ``da`` is paired with a non-tensor in ``db``, or
            ``fn`` rejects a pair.
    """
    assert len(da) == len(db)
    for k, v in da.items():
        assert k in db
        if not torch.is_tensor(v):
            continue
        u = db[k]
        # Fail with a clear message instead of an AttributeError on ``u.device``.
        assert torch.is_tensor(u), f'key {k!r}: expected a tensor, got {type(u).__name__}'
        if u.device != v.device:
            v = v.to(u.device)
        # The message is only built when the comparison fails.
        assert fn(u.data, v.data), _mismatch_message(k, u, v)
|