[rename] convert_to_dist -> redistribute (#1243)

Jiarui Fang 2022-07-11 13:05:44 +08:00 committed by GitHub
parent f6add9b720
commit 2699dfbbfd
7 changed files with 18 additions and 18 deletions
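
In short: the out-of-place ColoTensor method convert_to_dist_spec is renamed to redistribute, and the in-place helper _convert_to_dist_spec to _redistribute; every call site in the 1D tensor-parallel ops and the tests is updated. A minimal before/after sketch of a call site (colo_t and pg are placeholder names for a ColoTensor and its ProcessGroup; the distspec helpers are the ones used in the hunks below):

    # before this commit
    sharded = colo_t.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
    # after this commit
    sharded = colo_t.redistribute(distspec.shard([-1], [pg.tp_world_size()]))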

@@ -11,7 +11,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
-    mat1 = mat1.convert_to_dist_spec(distspec.shard([-1], [mat2.get_tp_world_size()]))
+    mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
     # Output:P
     partial_output = torch.mm(mat1, mat2)
@@ -28,7 +28,7 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate())
+    mat1 = mat1.redistribute(distspec.replicate())
     mat1 = reduce_grad(mat1, mat1.get_process_group())
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
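
The comments in colo_addmm_1Drow state the sharding identity mat1:S[1] x mat2:S[0] = Output:P, i.e. each rank multiplies its column shard of mat1 with its row shard of mat2 and the partial products are summed by an All-Reduce. A single-process sketch of that identity in plain torch (hypothetical shapes; the real op performs the sum via an all-reduce across the TP group):

    import torch

    world = 2
    mat1 = torch.randn(4, 6)             # sharded along dim -1 in the op
    mat2 = torch.randn(6, 8)             # sharded along dim 0 in the op
    shards1 = mat1.chunk(world, dim=-1)  # per-rank column shards of mat1
    shards2 = mat2.chunk(world, dim=0)   # per-rank row shards of mat2
    partials = [a @ b for a, b in zip(shards1, shards2)]            # per-rank torch.mm
    assert torch.allclose(sum(partials), mat1 @ mat2, atol=1e-5)    # the All-Reduce is this sum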

@@ -14,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -46,7 +46,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     tensor_parallel_rank = weight.get_process_group().tp_local_rank()

@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     output_parallel = F.embedding_bag(input_tensor,
                                       weight,

@@ -16,7 +16,7 @@ def colo_layernorm(
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
     bias = convert_to_colo_tensor(bias, weight.get_process_group())
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))

@@ -12,7 +12,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.shard([-1], [weight.get_tp_world_size()]))
+    input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -33,7 +33,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     input_parallel = reduce_grad(input_tensor, weight.get_process_group())
     output_parallel = F.linear(input_parallel, weight, bias)

@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
         Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
     }
@@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
         """
         assert isinstance(dist_spec, _DistSpec)
         assert self.process_group is not None
-        self._convert_to_dist_spec(dist_spec)
+        self._redistribute(dist_spec)

     def set_tensor_spec(self, dist_spec, compute_spec):
         if dist_spec:
@@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'

-    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
-        """_convert_to_dist_spec
+    def _redistribute(self, dist_spec: _DistSpec) -> None:
+        """_redistribute
         Note the function will not handle the logic of backward propagation!
         It is used during model tensor initializations as an internal function.
         Args:
@@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

-    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
@@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._convert_to_dist_spec(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=distspec.replicate())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate())
+        return self.redistribute(distspec.replicate())

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=distspec.replicate())
         return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
@@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
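
As the hunks above show, redistribute(dist_spec) is the out-of-place variant (it returns a new ColoTensor built from DistSpecManager.handle_trans_spec), while _redistribute mutates self.data in place and, per its docstring, does not handle backward propagation. A hedged usage sketch of the renamed public method (the import path and tensor shape are assumptions; the constructor calls mirror the test hunk below, and torch.distributed is assumed to be initialized):

    import torch
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec  # assumed import path

    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    t = ColoTensor.from_torch_tensor(torch.randn(4, 8), ColoTensorSpec(pg))
    # out-of-place: returns a new ColoTensor sharded along the last dim
    t_shard = t.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
    # convert back to a replicated ColoTensor
    t_rep = t_shard.to_replicate()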

@@ -22,7 +22,7 @@ def check_cross_entropy():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
     input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

     output = F.cross_entropy(input_t, target)