[zero] global model data memory tracer (#360)

This commit is contained in:
Jiarui Fang
2022-03-10 11:20:04 +08:00
committed by Frank Lee
parent cb34cd384d
commit ea2872073f
5 changed files with 94 additions and 4 deletions

View File

@@ -7,7 +7,7 @@ class ShardedTensor(object):
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
r"""
A tensor sharded in multiple processes.
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
self._payload = tensor
self.process_group = process_group