[zero] fix error for BEiT models (#2169)

* [zero] fix error for BEiT models

* [ColoParameter] add unpack operation for tuple arguments

* fix bugs

* fix chunkv2 unit testing

* add assertion for gradient state
This commit is contained in:
HELSON
2022-12-26 15:03:54 +08:00
committed by GitHub
parent 4363ff3e41
commit 2458659919
7 changed files with 82 additions and 32 deletions

View File

@@ -57,7 +57,7 @@ class ColoTensor(torch.Tensor):
The Colotensor can be initialized with a PyTorch tensor in the following ways.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],