[tensor] distributed checkpointing for parameters (#1240)

Author: Jiarui Fang
Date: 2022-07-12 15:51:06 +08:00
Committed by: GitHub
Parent: 49114d8df0
Commit: c92f84fcdb
6 changed files with 72 additions and 155 deletions


@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
         self._redistribute(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec:
+        if dist_spec is not None:
             assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
             self.set_dist_spec(dist_spec)
-        if compute_spec:
+        if compute_spec is not None:
             self.compute_spec = compute_spec
 
     def has_compute_pattern(self, compute_pattern):
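
The hunk above swaps truthiness checks for explicit "is not None" checks when applying a tensor spec. A minimal sketch of why that matters (the FalsySpec class below is hypothetical, for illustration only): a spec object that happens to evaluate as falsy would be silently ignored by the old check even though the caller did pass a spec.

    # Hypothetical spec-like object whose truth value is False.
    class FalsySpec:
        def __len__(self):
            return 0    # bool(FalsySpec()) is False because len() == 0

    spec = FalsySpec()

    if spec:                  # old-style check: body is skipped
        print("old check: spec applied")

    if spec is not None:      # new-style check: body runs, spec is applied
        print("new check: spec applied")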


@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 __all__ = ['replicate', 'shard']
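
The added Optional import indicates that type hints in this module now allow None-valued arguments. A hedged sketch of that annotation pattern (the helper below is illustrative only, not the module's actual replicate/shard API):

    from typing import List, Optional

    # Hypothetical helper: an optional argument annotated with Optional[...]
    # and checked with "is not None", mirroring the ColoTensor change above.
    def make_shard_dims(dims: Optional[List[int]] = None) -> List[int]:
        if dims is not None:
            return dims
        return [0]    # fall back to sharding along the first dimension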