[shardformer] supported T5 and its variants (#4045)

This commit is contained in:
Frank Lee
2023-06-19 17:57:37 +08:00
parent c1d5453e9f
commit d857f3dbba
10 changed files with 316 additions and 221 deletions

View File

@@ -1,4 +1,4 @@
from typing import OrderedDict
from typing import Any, List, OrderedDict
import torch
import torch.distributed as dist
@@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
assert torch.equal(v, d2[k])
else:
assert v == d2[k]
def assert_hf_output_close(out1: Any,
                           out2: Any,
                           ignore_keys: List[str] = None,
                           track_name: str = "",
                           atol=1e-5,
                           rtol=1e-5):
    """
    Check if two outputs from huggingface are equal.

    Dicts and lists/tuples are compared recursively, tensors are compared with
    ``torch.allclose`` and any other pair of values is compared with ``==``.

    Args:
        out1 (Any): the first output
        out2 (Any): the second output
        ignore_keys (List[str]): the keys to ignore when comparing two dicts
        track_name (str): the name of the value compared, used to track the path
        atol (float): absolute tolerance forwarded to ``torch.allclose``
        rtol (float): relative tolerance forwarded to ``torch.allclose``

    Raises:
        AssertionError: if the two outputs differ. The message is prefixed with
            ``track_name`` so the offending entry can be located.
    """
    if isinstance(out1, dict) and isinstance(out2, dict):
        # Both values are dicts: the key sets must match (ignored keys included),
        # then every non-ignored entry is compared recursively.
        if set(out1.keys()) != set(out2.keys()):
            raise AssertionError(f"{track_name}: key mismatch: {list(out1.keys())} vs {list(out2.keys())}")
        for k in out1:
            if ignore_keys is not None and k in ignore_keys:
                continue
            assert_hf_output_close(out1[k],
                                   out2[k],
                                   track_name=f"{track_name}.{k}",
                                   ignore_keys=ignore_keys,
                                   atol=atol,
                                   rtol=rtol)
    elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
        # Both values are sequences: same length, then element-wise recursion.
        if len(out1) != len(out2):
            raise AssertionError(f"{track_name}: length mismatch: {len(out1)} vs {len(out2)}")
        for i, (item1, item2) in enumerate(zip(out1, out2)):
            assert_hf_output_close(item1,
                                   item2,
                                   track_name=f"{track_name}.{i}",
                                   ignore_keys=ignore_keys,
                                   atol=atol,
                                   rtol=rtol)
    elif isinstance(out1, torch.Tensor) and isinstance(out2, torch.Tensor):
        if out1.shape != out2.shape:
            raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
        if not torch.allclose(out1, out2, atol=atol, rtol=rtol):
            # ``mean()`` is only defined for floating point / complex tensors, so
            # cast integer or bool tensors first; otherwise building the message
            # would raise a RuntimeError and hide the actual mismatch.
            lhs, rhs = out1, out2
            if not (out1.is_floating_point() or out1.is_complex()):
                lhs, rhs = out1.float(), out2.float()
            mean_error = torch.abs(lhs - rhs).mean()
            raise AssertionError(
                f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {mean_error}")
    else:
        # Fallback for plain Python values (and mixed types): rely on ``==``.
        if not out1 == out2:
            raise AssertionError(f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}")