mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-05 12:24:38 +00:00
[refactor] move process group from _DistSpec to ColoTensor. (#1203)
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
from colossalai.tensor import ColoParameter, ColoTensor
|
||||
from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
import torch
|
||||
from numpy import allclose
|
||||
import pytest
|
||||
from _utils import tensor_equal
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_multiinheritance():
|
||||
colo_param = ColoParameter()
|
||||
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
colo_param = ColoParameter(None, requires_grad=True)
|
||||
assert colo_param.dist_spec.placement.value == 'r'
|
||||
assert isinstance(colo_param, ColoTensor)
|
||||
assert isinstance(colo_param, torch.nn.Parameter)
|
||||
|
||||
@@ -22,5 +28,6 @@ def test_multiinheritance():
|
||||
clone_param = torch.clone(colo_param)
|
||||
assert isinstance(clone_param, ColoTensor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_multiinheritance()
|
||||
test_multiinheritance()
|
||||
|
||||
Reference in New Issue
Block a user