[tensor] torch function return colotensor (#1229)

Author: Jiarui Fang
Date: 2022-07-07 18:09:18 +08:00
Committed by: GitHub
Parent: 5581170890
Commit: a98319f023
8 changed files with 42 additions and 21 deletions
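
What the title refers to: after this change, torch functions applied to a ColoTensor (for example torch.abs(x), x + x, or x.cuda()) hand back a ColoTensor rather than a plain torch.Tensor, which is what the updated tests below assert. The standard hook for this kind of behaviour is PyTorch's __torch_function__ protocol. The following is only a minimal sketch of that protocol, using a hypothetical MetaTensor subclass with a made-up spec attribute; it is not ColoTensor's actual implementation.

import torch

class MetaTensor(torch.Tensor):
    # Hypothetical subclass: `spec` stands in for ColoTensor's dist spec.
    spec = None

    def __new__(cls, data, spec=None):
        # View `data` as the subclass (shares storage, no copy).
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    def __init__(self, data, spec=None):
        self.spec = spec

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # The stock Tensor.__torch_function__ runs the op and already
        # converts Tensor results back into `cls`.
        ret = super().__torch_function__(func, types, args, kwargs)
        # It knows nothing about `spec`, so copy it over from the first
        # MetaTensor argument when a single tensor comes back.
        if isinstance(ret, MetaTensor):
            ret.spec = next((a.spec for a in args if isinstance(a, MetaTensor)), None)
        return ret

The test changes in this commit check the same two properties for ColoTensor: the result type and the distribution spec both survive the op.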


@@ -113,6 +113,7 @@ def run_1d_hybrid_tp(model_name):
    torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
    torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
    # Bcast rank0 data to all processes
    if criterion:
        output = model(data)


@@ -39,7 +39,7 @@ def check_spec_eq(tensor, other):
    assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
    for k in dir(tensor.dist_spec):
        if not k.startswith('__'):
-           assert hasattr(other.dist_spec, k)
+           assert hasattr(other.dist_spec, k), f"{k}"
            assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
@@ -48,6 +48,7 @@ def check_element_wise_ops():
    pg = ProcessGroup(tp_degree=world_size)
    t = torch.rand(2, 2)
    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
    check_spec_eq(x, x.cuda())
    assert torch.equal(x.cuda(), t.cuda())
    check_spec_eq(x, torch.abs(x))


@@ -49,6 +49,8 @@ def _run_operand():
    t_ref_res = t_ref + t_ref
    t_res = t + t
    assert isinstance(t_res, ColoTensor)
    assert torch.allclose(t_ref_res, t_res)
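
Read against the sketch above (again, an illustration rather than ColossalAI's real code), these assertions fall out directly, because operators such as + dispatch through the same __torch_function__ path:

t_ref = torch.rand(2, 2)
t = MetaTensor(t_ref.clone(), spec='replicate')   # made-up spec value

t_ref_res = t_ref + t_ref
t_res = t + t                            # dispatches through __torch_function__
assert isinstance(t_res, MetaTensor)     # the subclass is preserved
assert t_res.spec == t.spec              # and so is its spec
assert torch.allclose(t_ref_res, t_res)  # values match the plain-tensor run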