mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[tensor] add zero_like colo op, important for Optimizer (#1236)
This commit is contained in:
@@ -42,7 +42,7 @@ def _run_wrapped_tensor_func():
|
||||
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
|
||||
|
||||
|
||||
def _run_operand():
|
||||
def _run_operand(world_size):
|
||||
pg = ProcessGroup()
|
||||
t_ref = torch.randn(4, 5)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||
@@ -53,6 +53,13 @@ def _run_operand():
|
||||
assert isinstance(t_res, ColoTensor)
|
||||
assert torch.allclose(t_ref_res, t_res)
|
||||
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||
t.set_dist_spec(distspec.shard([0], [world_size]))
|
||||
t_new = torch.zeros_like(t)
|
||||
assert isinstance(t_new, ColoTensor)
|
||||
assert t_new.is_sharded()
|
||||
|
||||
|
||||
#### Test Distributed init a Colotensor
|
||||
|
||||
@@ -105,9 +112,8 @@ def run_dist_tests(rank, world_size, port):
|
||||
_run_view(world_size)
|
||||
_run_process_group(world_size)
|
||||
_run_tensor_indexing()
|
||||
_run_operand()
|
||||
# TODO not passed
|
||||
# _run_wrapped_tensor_func()
|
||||
_run_operand(world_size)
|
||||
_run_wrapped_tensor_func()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@@ -119,4 +125,4 @@ def test_dist_cases(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dist_cases(2)
|
||||
test_dist_cases(1)
|
||||
|
Reference in New Issue
Block a user