remove gather_out from ParallelAction (#1163)

Jiarui Fang
2022-06-23 16:35:05 +08:00
committed by GitHub
parent 51f1ec96b0
commit 177c374401
8 changed files with 43 additions and 32 deletions

View File

@@ -37,10 +37,10 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if parallel_action.gather_out:
# All-Gather(Output)
output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
return output
# TODO(jiaruifang) addmm is a special case
# since gpt calls view after the Op.
return output.to_replicate()
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -62,11 +62,6 @@ def colo_addmm(input_tensor: GeneralTensor,
"""
input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))
# building the computing graph, inputs -> op
# if GraphGlobalEnv().graph_building:
# cur_op_node = GraphOpNode('linear', [weight, bias])
# cur_op_node.add_prev_tensor(input_tensor)
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_spec(): # No Model Parallel Applied
@@ -84,8 +79,4 @@ def colo_addmm(input_tensor: GeneralTensor,
else:
raise NotImplementedError
# building the computing graph, op -> output
# if GraphGlobalEnv().graph_building:
# cur_op_node.add_post_tensor(ret_tensor)
return ret_tensor

View File

@@ -30,9 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.spec.parallel_action.gather_out:
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
return output.to_replicate()
def colo_embedding_1Drow(input_tensor: ColoTensor,

View File

@@ -36,9 +36,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.spec.parallel_action.gather_out:
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
return output.to_replicate()
def colo_embedding_bag_1d(tp_mode: str,

View File

@@ -42,10 +42,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
distspec.shard(weight.spec.get_process_group(), [-1],
[weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D)))
if parallel_action.gather_out:
# All-Gather(Output)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
return output.to_replicate()
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
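All four 1D-column ops above (addmm, embedding, embedding_bag, linear) now end the same way: the column-sharded output is wrapped in a ColoTensor and returned through the new to_replicate() helper instead of branching on parallel_action.gather_out. As a rough sketch built only from calls that appear in this diff (the wrapper name and the import path are assumptions, not library code), the deleted branch amounts to:

from colossalai.tensor import ColoTensor, distspec  # import path assumed

def gather_1dcol_output(output: ColoTensor) -> ColoTensor:
    # What the removed "if parallel_action.gather_out:" branch did, and what
    # output.to_replicate() does now: all-gather the column-sharded activation
    # back into a replicated tensor over the op's process group.
    return output.convert_to_dist_spec(
        distspec.replicate(output.spec.get_process_group()))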

View File

@@ -92,10 +92,13 @@ class ColoTensor(torch.Tensor):
def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
"""_convert_to_dist_spec
Note that this function does not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
Args:
dist_spec (_DistSpec): the target dist. spec.
"""
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
self._tensor_spec.dist_spec = dist_spec
@@ -106,6 +109,19 @@ class ColoTensor(torch.Tensor):
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, tensor_spec)
def to_replicate_(self):
"""to_replicate_
an in-place member function, converting the dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting the dist spec of the tensor to REPLICATE
"""
return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
@@ -121,3 +137,13 @@ class ColoTensor(torch.Tensor):
tensor = ColoTensor(data, spec=copy(self.spec))
memo[id(self)] = tensor
return tensor
# TODO(jiaruifang) a patch for the gpt test.
# We need to override the member functions that must operate on a replicated tensor
# def view(self, *args, **kwargs):
# self.data = DistSpecManager.handle_trans_spec(self,
# self.spec.dist_spec,
# distspec.replicate(self.spec.get_process_group()))
# # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
# self.data.view(*args, **kwargs)
# return ColoTensor.from_torch_tensor(self.data)
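A minimal usage sketch of the two new ColoTensor helpers (the import path is assumed; with the default replicated spec the conversion is effectively a no-op, so no multi-process launch is needed to try it):

import torch
from colossalai.tensor import ColoTensor  # import path assumed

t = ColoTensor.from_torch_tensor(torch.randn(4, 4))  # default spec is replicate
r = t.to_replicate()   # out-of-place: returns a new ColoTensor with a REPLICATE dist spec
t.to_replicate_()      # in-place: rewrites t.data and t's dist spec to REPLICATE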

View File

@@ -13,13 +13,12 @@ class ComputePattern(Enum):
class ParallelAction(object):
def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
def __init__(self, compute_pattern: ComputePattern) -> None:
assert isinstance(compute_pattern, ComputePattern)
self.compute_pattern = compute_pattern
self.gather_out = gather_out
def __repr__(self):
return f'compute pattern: {self.compute_pattern}, gather out: {self.gather_out}'
return f'compute pattern: {self.compute_pattern}'
class TensorSpec(object):
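After this change a ParallelAction only records the compute pattern; the decision to gather the output has moved onto the tensor via to_replicate(). A minimal construction sketch (the import path is assumed):

from colossalai.tensor import ComputePattern, ParallelAction  # import path assumed

action = ParallelAction(ComputePattern.TP1D)  # the gather_out keyword argument is gone
print(action)  # -> compute pattern: ComputePattern.TP1D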