[Tensor] add embedding tp1d row (#904)

This commit is contained in:
Ziyue Jiang
2022-04-29 14:10:05 +08:00
committed by GitHub
parent 16122d5fac
commit f593a5637e
5 changed files with 108 additions and 8 deletions

View File

@@ -166,6 +166,7 @@ class ColoTensor(object):
dim = -1
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
self._shard_pattern = ShardPattern.NA
self._size = self._torch_tensor.size()
def is_gathered(self) -> bool:
return self._shard_pattern == ShardPattern.NA