Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-31 16:40:41 +00:00

Commit: [tensor] distributed checkpointing for parameters (#1240)
@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
         self._redistribute(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec:
+        if dist_spec is not None:
             assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
             self.set_dist_spec(dist_spec)
-        if compute_spec:
+        if compute_spec is not None:
             self.compute_spec = compute_spec
 
     def has_compute_pattern(self, compute_pattern):
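A possible reading of this hunk (my interpretation; the commit message does not spell it out): replacing the truthiness checks `if dist_spec:` / `if compute_spec:` with explicit `is not None` tests avoids silently skipping a spec object that happens to evaluate as falsy, e.g. one whose class defines `__bool__` or `__len__`. Below is a minimal sketch of the difference, using a hypothetical `FakeSpec` class rather than the real `_DistSpec`:

class FakeSpec:
    """Hypothetical stand-in for a spec object; not the real _DistSpec."""

    def __init__(self, dims):
        self.dims = dims

    def __len__(self):
        # An empty spec has length 0, so bool(spec) is False.
        return len(self.dims)


def set_spec_truthy(spec):
    # Truthiness check: silently skips an explicitly passed but empty spec.
    return "applied" if spec else "skipped"


def set_spec_is_not_none(spec):
    # Identity check: only skips when no spec was supplied at all.
    return "applied" if spec is not None else "skipped"


empty = FakeSpec(dims=[])
print(set_spec_truthy(empty))       # skipped  -- surprising
print(set_spec_is_not_none(empty))  # applied  -- matches caller intent
print(set_spec_is_not_none(None))   # skipped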
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 __all__ = ['replicate', 'shard']
 
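The added `Optional` import suggests that signatures in this module now annotate arguments or return values that may be `None`; the annotated signatures themselves are outside the hunk shown. A minimal sketch of the pattern, with a hypothetical function name that is not taken from the file:

from typing import List, Optional


def shard_dims(dims: Optional[List[int]] = None) -> List[int]:
    """Hypothetical helper: return the sharding dims, defaulting to [] when None."""
    # Optional[List[int]] is shorthand for Union[List[int], None].
    if dims is None:
        return []
    return dims


print(shard_dims())        # []
print(shard_dims([0, 1]))  # [0, 1]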