Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[tensor] distributed checkpointing for parameters (#1240)
@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
     assert t1.is_replicate()
+
+
+def _run_set_tensor_spec(world_size):
+    if world_size != 4:
+        return
+    pg = ProcessGroup(tp_degree=2, dp_degree=2)
+    spec1 = ColoTensorSpec(pg)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+
+    dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
+    assert t1.is_replicate()
+    t1.set_dist_spec(*dist_spec2)
+    assert t1.is_shard_1dcol()
 
 
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
     _run_operand(world_size)
     _run_wrapped_tensor_func()
     _run_redistributed(world_size)
+    _run_set_tensor_spec(world_size)
 
 
 @pytest.mark.dist
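The second hunk stops just before the @pytest.mark.dist decorator, so the diff does not show how run_dist_tests is launched. A test laid out this way is normally driven by a pytest entry point that spawns one process per rank with torch.multiprocessing.spawn. The sketch below illustrates that pattern under stated assumptions: the name test_dist_cases, the parametrized world sizes, and the inline free-port helper are illustrative and are not taken from this commit.

# Hypothetical driver for the run_dist_tests function extended above.
# Assumes run_dist_tests(rank, world_size, port) is defined in the same
# file, exactly as shown in the second hunk.
from functools import partial
import socket

import pytest
import torch.multiprocessing as mp


def _free_port():
    # Bind to port 0 so the OS picks an unused TCP port for the
    # torch.distributed rendezvous on localhost.
    with socket.socket() as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
def test_dist_cases(world_size):
    # mp.spawn calls run_func(rank) in each child process; partial fills in
    # the remaining keyword arguments expected by run_dist_tests.
    run_func = partial(run_dist_tests, world_size=world_size, port=_free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dist_cases(4)

Note that _run_set_tensor_spec returns early unless world_size == 4, so the new assertions are only exercised for the four-rank case (tp_degree=2, dp_degree=2).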