[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

Jiarui Fang 2022-06-24 13:08:54 +08:00 committed by GitHub
parent f4ef224358
commit 4b9bba8116
23 changed files with 116 additions and 105 deletions
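
For readers skimming the diff: the commit renames the ColoTensor spec accessors to more explicit names (the `spec` property becomes `tensor_spec`, `set_spec()` becomes `set_tensor_spec()`, `has_spec()` becomes `has_compute_spec()`) and adds an `output_replicate` flag, defaulting to True, to `ComputeSpec`. Below is a minimal sketch of the renamed calls; the import paths and the already-launched 1D tensor-parallel process group are assumptions based on the test files touched here, not something this diff shows.

# Sketch only: assumes these classes import from colossalai.tensor and that the
# process was launched with a PARALLEL_1D group (e.g. via colossalai.launch).
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec,
                               DistSpecManager, TensorSpec, distspec)

spec = TensorSpec(
    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
    ComputeSpec(ComputePattern.TP1D))

weight = ColoTensor.from_torch_tensor(torch.randn(16, 16))
with DistSpecManager.no_grad():
    weight.set_tensor_spec(spec)    # was: weight.set_spec(spec)

assert weight.has_compute_spec()    # was: weight.has_spec()
print(weight.tensor_spec)           # was: weight.spec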


@@ -13,34 +13,37 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # beta * input + alpha * All-Reduce(Output) = res
     mat1 = mat1.convert_to_dist_spec(
-        distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
-    assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
+    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
     return output
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
+    compute_spec = mat2.tensor_spec.compute_spec
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
     mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
+    output_spec = TensorSpec(
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    # TODO(jiaruifang) addam is special case
-    # since gpt call view after the Op.
-    return output.to_replicate()
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -64,14 +67,15 @@ def colo_addmm(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not mat2.has_spec():  # No Model Parallel Applied
-        assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
+    if not mat2.has_compute_spec():  # No Model Parallel Applied
+        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
-    elif mat2.spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
+    elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
             mode = 'row'
-        elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
+        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
+                                               or input_tensor.tensor_spec.is_1D_row()):
             mode = 'col'
         else:
             raise NotImplementedError


@@ -18,7 +18,7 @@ def register_elementwise_op(op):
         """
         output = op(input_tensor, *args, **kwargs)
         if isinstance(input_tensor, ColoTensor):
-            spec = copy(input_tensor.spec)
+            spec = copy(input_tensor.tensor_spec)
             return ColoTensor.from_torch_tensor(output, spec=spec)
         return ColoTensor.from_torch_tensor(output)


@@ -17,7 +17,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -27,10 +27,15 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+    compute_spec = weight.tensor_spec.compute_spec
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 def colo_embedding_1Drow(input_tensor: ColoTensor,
@@ -43,7 +48,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
@@ -70,7 +75,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
@@ -108,8 +114,8 @@ def colo_embedding(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
-    if not weight.has_spec():  # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():  # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,
@@ -118,10 +124,10 @@ def colo_embedding(input_tensor: GeneralTensor,
                         norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.spec.is_1D_row():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_row():
             mode = 'row'
-        elif weight.spec.is_1D_col():
+        elif weight.tensor_spec.is_1D_col():
             mode = 'col'
         else:
             raise NotImplementedError


@@ -19,7 +19,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                              padding_idx: Optional[int] = None) -> ColoTensor:
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -33,11 +33,14 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+    if weight.tensor_spec.compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 def colo_embedding_bag_1d(tp_mode: str,
@@ -86,8 +89,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
-    if not weight.has_spec():  # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():  # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,
@@ -100,8 +103,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                             per_sample_weights=per_sample_weights,
                             include_last_offset=include_last_offset,
                             padding_idx=padding_idx))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.spec.is_1D_col():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col():
             tp_mode = 'col'
         else:
             raise NotImplementedError


@@ -17,8 +17,8 @@ def colo_layernorm(
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
     # TODO (ver217): check dist spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
-    output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
+    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
     return output


@@ -13,7 +13,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     input_tensor = input_tensor.convert_to_dist_spec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -21,10 +21,11 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
-        assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
+        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
@@ -32,17 +33,20 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.compute_spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    compute_spec = weight.tensor_spec.compute_spec
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=TensorSpec(
-                                              distspec.shard(weight.spec.get_process_group(), [-1],
-                                                             [weight.spec.get_process_group_size()]),
+                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1],
+                                                             [weight.tensor_spec.get_process_group_size()]),
                                               ComputeSpec(ComputePattern.TP1D)))
-    return output.to_replicate()
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -62,14 +66,15 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not weight.has_spec():  # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
+    if not weight.has_compute_spec():  # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
             mode = 'row'
-        elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
+        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
+                                                 or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
             raise NotImplementedError


@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
-    if input_tensor.spec.is_gathered():  # Input is gathered
+    if input_tensor.tensor_spec.is_gathered():  # Input is gathered
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -28,8 +28,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec():  # Single Model Parallel Applied
-        if input_tensor.spec.is_1D_col():
+    elif input_tensor.has_compute_spec():  # Single Model Parallel Applied
+        if input_tensor.tensor_spec.is_1D_col():
            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
            return ColoTensor.from_torch_tensor(output)
        else:


@@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             param = module.get_parameter(param_name)
             if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
-            if param.has_spec():
-                cur_compute_pattern = param.spec.compute_spec.compute_pattern
+            if param.has_compute_spec():
+                cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:
@@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             cur_match = True
             for param_name, dist_spec in param_specs.items():
                 param = module.get_parameter(param_name)
-                if param.has_spec():
-                    if dist_spec != param.spec.dist_spec:
+                if param.has_compute_spec():
+                    if dist_spec != param.tensor_spec.dist_spec:
                         cur_match = False
                         break
                 else:
@@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
         param = module.get_parameter(param_name)
         if isinstance(param, ColoParameter):
             spec = TensorSpec(dist_spec, parallel_action)
-            param.set_spec(spec)
+            param.set_tensor_spec(spec)
             for mod in param.shared_param_modules:
                 modules_update_param.add(mod)
     for mod in modules_update_param:


@@ -82,7 +82,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
             return tensor


@@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
         self._graph_node = None
     @property
-    def spec(self) -> TensorSpec:
+    def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec
-    def set_spec(self, spec: TensorSpec) -> None:
+    def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)
         self._tensor_spec = spec
-    def has_spec(self) -> bool:
+    def has_compute_spec(self) -> bool:
         return self._tensor_spec.compute_spec is not None
     def is_model_data(self) -> bool:
@@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
            dist_spec (_DistSpec): the target dist. spec.
        """
        with DistSpecManager.no_grad():
-           self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+           self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
        self._tensor_spec.dist_spec = dist_spec
    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
        tensor_spec = copy(self._tensor_spec)
        tensor_spec.dist_spec = dist_spec
-       ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+       ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
        return ColoTensor.from_torch_tensor(ret, tensor_spec)
    def to_replicate_(self):
        """to_replicate_
        an inline member function, converting dist spec of the tensor to REPLICATE
        """
-       self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
+       self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
        self._tensor_spec.dist_spec = distspec.replicate()
    def to_replicate(self) -> 'ColoTensor':
        """to_replicate
        converting dist spec of the tensor to REPLICATE
        """
-       return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
+       return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
@@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoTensor(data, spec=copy(self.spec))
+            tensor = ColoTensor(data, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
             return tensor
-    # TODO(jiaruifang) a patch for gpt test.
-    # We need to override the member function must operate on a replicated tensor
-    # def view(self, *args, **kwargs):
-    #     self.data = DistSpecManager.handle_trans_spec(self,
-    #                                                   self.spec.dist_spec,
-    #                                                   distspec.replicate(self.spec.get_process_group()))
-    #     # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
-    #     self.data.view(*args, **kwargs)
-    #     return ColoTensor.from_torch_tensor(self.data)


@@ -18,6 +18,8 @@ class ComputeSpec(object):
     def __init__(self, compute_pattern: ComputePattern) -> None:
         assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
+        # Make sure output tensors are replicate
+        self.output_replicate = True
     def __repr__(self):
         return f'compute pattern: {self.compute_pattern}'
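
The flag added above defaults to True, which preserves the previous behaviour: the TP1D column-parallel ops (linear, addmm, embedding, embedding_bag) still gather their shard-parallel result back to a replicated tensor before returning. Setting it to False makes those ops return the output still sharded along the last dimension and leaves the gather to the caller. A hedged sketch follows, under the same import-path assumptions as the sketch near the top of this page; it must run inside a colossalai-launched PARALLEL_1D group.

# Sketch only, mirroring the addmm test touched by this commit.
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec,
                               DistSpecManager, TensorSpec, distspec)

compute_spec = ComputeSpec(ComputePattern.TP1D)
compute_spec.output_replicate = False    # default is True, see __init__ above

col_spec = TensorSpec(
    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
    compute_spec)

weight = ColoTensor.from_torch_tensor(torch.randn(16, 16).cuda())
bias = ColoTensor.from_torch_tensor(torch.randn(16).cuda())
with DistSpecManager.no_grad():
    weight.set_tensor_spec(col_spec)
    bias.set_tensor_spec(col_spec)

x = torch.rand(2, 16).cuda()
out = torch.addmm(bias, x, weight)    # dispatches to colo_addmm_1Dcol
# With output_replicate=False the result stays sharded along dim -1; gather it
# explicitly when a full tensor is needed, as the updated test below does.
out = out.to_replicate()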


@@ -129,7 +129,7 @@ def _get_colo_tensors_info(*args) -> list:
     info = []
     for arg in args:
         if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, arg.spec))
+            info.append((arg.__class__, arg.tensor_spec))
         else:
             info.append(None)
     return info


@@ -42,10 +42,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
     has_dist_parameter = False
     with torch.no_grad():
         for param in self.parameters():
-            if isinstance(param, ColoParameter) and param.has_spec():
+            if isinstance(param, ColoParameter) and param.has_compute_spec():
                 has_dist_parameter = True
-                mapping[id(param)] = copy(param.spec)
-                param.set_spec(TensorSpec(distspec.replicate()))
+                mapping[id(param)] = copy(param.tensor_spec)
+                param.set_tensor_spec(TensorSpec(distspec.replicate()))
     # TODO: fix when keep_vars = True
     # when keep_vars = False, the state_dict_func will call detach to create
@@ -62,7 +62,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
             param_id = id(param)
             if param_id in mapping:
                 spec = mapping[id(param)]
-                param.set_spec(spec)
+                param.set_tensor_spec(spec)
     return ret


@@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_col(weight, bias):
@@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 def run_with_spec(spec_init_func):
@@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)


@@ -20,7 +20,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def run_with_spec(spec_init_func):


@@ -20,7 +20,7 @@ def init_1d_row(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_col(weight):
@@ -28,7 +28,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def run_with_spec(spec_init_func):


@@ -22,7 +22,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 def init_1d_col_spec(model):
@@ -32,7 +32,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 def check_param_equal(model, torch_model):


@@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_col(weight, bias):
@@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 def run_with_spec(spec_init_func):


@@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_col_linear(weight):
@@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_row_embedding(weight):
@@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def init_1d_col_embedding(weight):
@@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 def run_1d_hybrid_tp(model_name):


@@ -157,7 +157,7 @@ def run_check_shared_param():
     col_spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
-    model.cls.predictions.bias.set_spec(col_spec)
+    model.cls.predictions.bias.set_tensor_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)
     except Exception as e:


@@ -36,10 +36,10 @@ def test_layernorm():
 def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
-    for k in dir(tensor.spec.dist_spec):
+    for k in dir(tensor.tensor_spec.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.spec.dist_spec, k)
-            assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
+            assert hasattr(other.tensor_spec.dist_spec, k)
+            assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
 def check_element_wise_ops():


@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
     shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
     assert t.shape == torch.Size((4 * world_size, 5))


@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 def init_1d_col_spec(model):
@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 @parameterize('use_chunk', [False, True])