Mirror of https://github.com/hpcaitech/ColossalAI.git

[tensor] torch function return colotensor (#1229)
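
As the commit title and the updated test assertions below indicate, the point of this change is that torch functions applied to a ColoTensor return a ColoTensor rather than a plain torch.Tensor. A minimal sketch of the general mechanism behind such behaviour, using a hypothetical WrapTensor subclass as a stand-in (this is not ColossalAI's actual ColoTensor implementation), built on PyTorch's __torch_function__ protocol:

import torch

class WrapTensor(torch.Tensor):
    # Hypothetical stand-in for ColoTensor: run the op on plain tensors,
    # then wrap the result back into the subclass.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        # Re-wrap plain tensor outputs so callers keep getting the subclass.
        if isinstance(ret, torch.Tensor) and not isinstance(ret, cls):
            ret = ret.as_subclass(cls)
        return ret

x = torch.rand(2, 2).as_subclass(WrapTensor)
assert isinstance(torch.abs(x), WrapTensor)    # torch function returns the subclass
assert isinstance(x + x, WrapTensor)           # operators dispatch the same way

The real ColoTensor also carries a distribution spec; the hunks below check that the spec survives element-wise ops and .cuda(), and that values match the plain-tensor reference.
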
@@ -113,6 +113,7 @@ def run_1d_hybrid_tp(model_name):
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)

@@ -39,7 +39,7 @@ def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
     for k in dir(tensor.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.dist_spec, k)
+            assert hasattr(other.dist_spec, k), f"{k}"
             assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)

@@ -48,6 +48,7 @@ def check_element_wise_ops():
    pg = ProcessGroup(tp_degree=world_size)
    t = torch.rand(2, 2)
    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))

    check_spec_eq(x, x.cuda())
    assert torch.equal(x.cuda(), t.cuda())
    check_spec_eq(x, torch.abs(x))

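check_element_wise_ops verifies two things: the values match the plain-tensor result, and the distribution spec is preserved across torch.abs and .cuda(). A toy sketch of how a subclass can propagate such metadata onto wrapped outputs (a hypothetical SpecTensor with a made-up dict as its "spec"; not ColossalAI's implementation):

import torch

class SpecTensor(torch.Tensor):
    # Hypothetical subclass carrying a 'dist_spec'-like attribute and
    # copying it onto the results of torch functions.
    dist_spec = None

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if isinstance(ret, torch.Tensor):
            if not isinstance(ret, cls):
                ret = ret.as_subclass(cls)
            # Propagate the spec from the first SpecTensor argument, if any.
            src = next((a for a in args if isinstance(a, cls)), None)
            if src is not None:
                ret.dist_spec = src.dist_spec
        return ret

t = torch.rand(2, 2)
x = t.as_subclass(SpecTensor)
x.dist_spec = {"dims": [0], "num_partitions": [1]}    # made-up spec payload
y = torch.abs(x)
assert isinstance(y, SpecTensor) and y.dist_spec == x.dist_spec
assert torch.equal(y, torch.abs(t))
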
@@ -49,6 +49,8 @@ def _run_operand():
    t_ref_res = t_ref + t_ref
    t_res = t + t

    assert isinstance(t_res, ColoTensor)
    assert torch.allclose(t_ref_res, t_res)

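The _run_operand hunk adds an explicit isinstance check on the result of t + t. A self-contained illustration of that property, reusing the hypothetical WrapTensor pattern from the first sketch (not the real ColoTensor, which is constructed with a ProcessGroup and spec in these tests):

import torch

class WrapTensor(torch.Tensor):
    # Same hypothetical pattern as in the first sketch, repeated so this
    # snippet runs on its own.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if isinstance(ret, torch.Tensor) and not isinstance(ret, cls):
            ret = ret.as_subclass(cls)
        return ret

t_ref = torch.rand(2, 2)
t = t_ref.clone().as_subclass(WrapTensor)

t_ref_res = t_ref + t_ref    # plain-tensor reference result
t_res = t + t                # subclass result

assert isinstance(t_res, WrapTensor)       # the '+' result keeps the subclass
assert torch.allclose(t_ref_res, t_res)    # and matches the reference values
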