[tensor] remove gpc in tensor tests (#1186)

This commit is contained in:
Jiarui Fang
2022-06-29 14:08:40 +08:00
committed by GitHub
parent 372f791444
commit c463f8adf9
4 changed files with 26 additions and 20 deletions

View File

@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec)