Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-23 14:10:29 +00:00
[rename] convert_to_dist -> redistribute (#1243)
parent f6add9b720
commit 2699dfbbfd
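This commit renames ColoTensor.convert_to_dist_spec to ColoTensor.redistribute (and the internal _convert_to_dist_spec to _redistribute) and updates the call sites in the ops and tests shown in the hunks below. A minimal sketch of the call-site change follows; it assumes the usual colossalai.tensor imports and an already-initialized distributed process group, and the tensor name t is only for illustration.

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

# assumes torch.distributed has been initialized (e.g. via colossalai.launch)
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
t = ColoTensor.from_torch_tensor(torch.randn(4, 8), ColoTensorSpec(pg))

# old API (before this commit):
#   t_shard = t.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
# new API (after this commit):
t_shard = t.redistribute(distspec.shard([-1], [pg.tp_world_size()]))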
@@ -11,7 +11,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

-    mat1 = mat1.convert_to_dist_spec(distspec.shard([-1], [mat2.get_tp_world_size()]))
+    mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))

     # Output:P
     partial_output = torch.mm(mat1, mat2)
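The comments in this hunk describe the row-parallel addmm scheme: mat1 is sharded along its columns (S[1]) and mat2 along its rows (S[0]), so each rank's torch.mm produces a partial sum (Output:P) that an all-reduce completes before the beta/alpha scaling. The sketch below illustrates that arithmetic with plain torch.distributed; it is not ColossalAI's implementation, and the function and argument names are made up for illustration.

import torch
import torch.distributed as dist

def addmm_1drow_sketch(input_tensor, mat1_shard, mat2_shard, beta=1.0, alpha=1.0):
    # mat1:S[1] x mat2:S[0] = Output:P -- each rank holds a column slice of mat1
    # and the matching row slice of mat2, so the local product is a partial sum
    partial_output = torch.mm(mat1_shard, mat2_shard)
    # All-Reduce(Output) sums the partial results across the tensor-parallel group
    dist.all_reduce(partial_output, op=dist.ReduceOp.SUM)
    # beta * input + alpha * All-Reduce(Output) = res
    return beta * input_tensor + alpha * partial_output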
@@ -28,7 +28,7 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate())
+    mat1 = mat1.redistribute(distspec.replicate())
     mat1 = reduce_grad(mat1, mat1.get_process_group())

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
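The column-parallel variant above keeps mat1 replicated and passes it through reduce_grad before the matmul. The snippet below sketches the usual meaning of such a helper: an identity in the forward pass whose backward all-reduces the gradient across the tensor-parallel group, so the replicated input receives the full gradient. This is an illustrative stand-in, not ColossalAI's reduce_grad.

import torch
import torch.distributed as dist

class _ReduceGradSketch(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient in backward."""

    @staticmethod
    def forward(ctx, tensor):
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output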
@@ -14,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())

     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -46,7 +46,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())

     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     tensor_parallel_rank = weight.get_process_group().tp_local_rank()
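The comments in this hunk ("Find index in this shard and mask those not here", "Reduce all") summarize the row-parallel embedding: the vocabulary rows are split across ranks, out-of-range indices are masked locally, and an all-reduce merges the per-rank lookups since each index is served by exactly one shard. Below is a plain-PyTorch illustration of that idea with made-up shard bounds; it is not the ColossalAI code.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def embedding_1drow_sketch(indices, weight_shard, vocab_start, vocab_end):
    # Find index in this shard and mask those not here
    mask = (indices < vocab_start) | (indices >= vocab_end)
    local_indices = (indices - vocab_start).clamp(min=0)
    local_indices[mask] = 0                       # any valid local row; zeroed out below
    out = F.embedding(local_indices, weight_shard)
    out = out.masked_fill(mask.unsqueeze(-1), 0.0)
    # Reduce all: each index was looked up by exactly one shard
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out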
@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())

     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -16,7 +16,7 @@ def colo_layernorm(
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
     bias = convert_to_colo_tensor(bias, weight.get_process_group())
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())

     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
@@ -12,7 +12,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     pg = weight.get_process_group()
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.shard([-1], [weight.get_tp_world_size()]))
+    input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))

     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -33,7 +33,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+    input_tensor = input_tensor.redistribute(distspec.replicate())
     input_parallel = reduce_grad(input_tensor, weight.get_process_group())

     output_parallel = F.linear(input_parallel, weight, bias)
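In the column-parallel linear above, the input stays replicated (Input:B), each rank holds a slice of the weight's output features, and the comment notes an All-Gather(Output) to rebuild the full result; how and when the real op gathers depends on the compute spec it reads from the weight. A plain-PyTorch sketch of the pattern, with the gather shown eagerly and illustrative names, follows; it is not the ColossalAI implementation.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def linear_1dcol_sketch(input_replicated, weight_shard, bias_shard=None):
    # Input:B -- every rank sees the same input; the weight rows (output
    # features) are split, so each rank computes a slice of the output
    output_parallel = F.linear(input_replicated, weight_shard, bias_shard)
    # All-Gather(Output): concatenate the per-rank feature slices
    gathered = [torch.empty_like(output_parallel) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, output_parallel.contiguous())
    return torch.cat(gathered, dim=-1)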
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
         Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
     }


@@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
         """
         assert isinstance(dist_spec, _DistSpec)
         assert self.process_group is not None
-        self._convert_to_dist_spec(dist_spec)
+        self._redistribute(dist_spec)

     def set_tensor_spec(self, dist_spec, compute_spec):
         if dist_spec:
@@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'

-    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
-        """_convert_to_dist_spec
+    def _redistribute(self, dist_spec: _DistSpec) -> None:
+        """_redistribute
         Note the function will not handle the logic of backward propagation!
         It is used during model tensor initializations as an internal function.
         Args:
@@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

-    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))

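The rename keeps the split visible in the hunk above: _redistribute works in place, swapping self.data and self.dist_spec through DistSpecManager.handle_trans_spec with no backward handling (per its docstring), while redistribute is out of place and wraps the converted data in a new ColoTensor. A short usage sketch, reusing the pg and t names from the sketch near the top of this page; the shard spec on the first line is illustrative.

# in-place, internal use only (no autograd handling), mutates t itself
t._redistribute(distspec.shard([0], [pg.tp_world_size()]))

# out-of-place: returns a new ColoTensor, t keeps its current dist spec
replica = t.redistribute(distspec.replicate())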
@@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._convert_to_dist_spec(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=distspec.replicate())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate())
+        return self.redistribute(distspec.replicate())

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=distspec.replicate())
         return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
@@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
@@ -22,7 +22,7 @@ def check_cross_entropy():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
     input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

     output = F.cross_entropy(input_t, target)
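The test itself only swaps the redistribution call; the rest of its setup is unchanged. Extending the sketch near the top of this page (same pg and imports, plus ComputeSpec and ComputePattern, which are assumed to also come from colossalai.tensor), the sharded tensor is then tagged for 1D tensor parallelism; the tensor shape is arbitrary.

from colossalai.tensor import ComputeSpec, ComputePattern

input_ct = torch.randn(4, 8)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))

# the only change in this hunk: convert_to_dist_spec(...) becomes redistribute(...)
input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))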