[Doc] add more doc for ColoTensor. (#1458)

This commit is contained in:
Jiarui Fang
2022-08-16 10:38:41 +08:00
committed by GitHub
parent a1476ea882
commit 36824a304c
4 changed files with 46 additions and 18 deletions

View File

@@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
def shard_param(p: ColoParameter) -> None:
pg = p.get_process_group()
p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()