[tensor] add zero_like colo op, important for Optimizer (#1236)

This commit is contained in:
Jiarui Fang
2022-07-08 14:55:27 +08:00
committed by GitHub
parent 3b500984b1
commit 4a76084dc9
4 changed files with 16 additions and 6 deletions

View File

@@ -55,7 +55,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
return tensor
def __repr__(self):
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
return f'ColoParameter: {ColoTensor.__repr__(self)}'
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):

View File

@@ -271,3 +271,6 @@ class ColoTensor(torch.Tensor):
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def is_sharded(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD