[Tensor] rename some APIs in TensorSpec and Polish view unittest (#1176)

commit 0dd4e2bbfb
parent dd0420909f
Author: Jiarui Fang
Date:   2022-06-27 15:56:11 +08:00
Committed by: GitHub

9 changed files with 26 additions and 18 deletions

View File

@@ -72,10 +72,10 @@ def colo_addmm(input_tensor: GeneralTensor,
         assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
+        if mat2.tensor_spec.is_shard_1drow() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
-        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
-                                               or input_tensor.tensor_spec.is_1D_row()):
+        elif mat2.tensor_spec.is_shard_1dcol() and (input_tensor.tensor_spec.is_shard_1dcol()
+                                                    or input_tensor.tensor_spec.is_shard_1drow()):
             mode = 'col'
         else:
             raise NotImplementedError

View File

@@ -32,6 +32,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
    compute_spec = weight.tensor_spec.compute_spec
    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
@@ -125,9 +126,9 @@ def colo_embedding(input_tensor: GeneralTensor,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_row():
+        if weight.tensor_spec.is_shard_1drow():
             mode = 'row'
-        elif weight.tensor_spec.is_1D_col():
+        elif weight.tensor_spec.is_shard_1dcol():
             mode = 'col'
         else:
             raise NotImplementedError

View File

@@ -104,7 +104,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                         include_last_offset=include_last_offset,
                         padding_idx=padding_idx))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col():
+        if weight.tensor_spec.is_shard_1dcol():
             tp_mode = 'col'
         else:
             raise NotImplementedError

View File

@@ -71,10 +71,10 @@ def colo_linear_imp(input_tensor: GeneralTensor,
         assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
+        if weight.tensor_spec.is_shard_1dcol() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
-        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
-                                                 or bias.tensor_spec.is_1D_col()):
+        elif weight.tensor_spec.is_shard_1drow() and (bias is None or bias.tensor_spec.is_shard_1drow()
+                                                      or bias.tensor_spec.is_shard_1dcol()):
             mode = 'col'
         else:
             raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
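For reference, a minimal sketch (not from this commit) of the dispatch that the branch above implements with the renamed predicates: a column-sharded weight selects the row-parallel mode and a row-sharded weight the column-parallel mode. The helper name `pick_linear_mode` and the bare `weight_spec`/`bias_spec` parameters are illustrative only.

```python
# Illustrative only: condenses the weight/bias spec checks from colo_linear_imp above.
def pick_linear_mode(weight_spec, bias_spec=None) -> str:
    if weight_spec.is_shard_1dcol() and (bias_spec is None or bias_spec.is_replicate()):
        return 'row'
    if weight_spec.is_shard_1drow() and (bias_spec is None or bias_spec.is_shard_1drow()
                                         or bias_spec.is_shard_1dcol()):
        return 'col'
    raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight_spec}, bias {bias_spec}")
```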

View File

@@ -29,7 +29,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
     elif input_tensor.has_compute_spec():  # Single Model Parallel Applied
-        if input_tensor.tensor_spec.is_1D_col():
+        if input_tensor.tensor_spec.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:

View File

@@ -116,6 +116,7 @@ class Chunk:
         if self.is_src_rank:
             self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
+            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
             tensor.storage().resize_(0)
@@ -131,6 +132,7 @@ class Chunk:
         self._update_tensors_state(TensorState.FREE)
 
     def _update_tensors_ptr(self) -> None:
+        assert type(self._payload) == torch.Tensor
         for tensor, tensor_info in self.tensors_info.items():
             tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
@@ -228,7 +230,7 @@ class Chunk:
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         tensor_info = self.tensors_info[tensor]
-        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
         tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
     @property
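The switch from `data_slice.view(-1)` to `data_slice.flatten()` matters because `view(-1)` raises on non-contiguous tensors while `flatten()` falls back to a copy. A self-contained PyTorch sketch of the flat-payload pattern used above; the buffer and variable names are illustrative, not the Chunk API.

```python
import torch

# A flat buffer standing in for Chunk._payload.
payload = torch.zeros(16)

# A non-contiguous (2, 4) tensor: .view(-1) would raise here, .flatten() copies instead.
param = torch.randn(4, 2).t()
offset, end = 0, param.numel()

# Copy the data into a slice of the buffer, then rebind the tensor's storage to
# that slice so it aliases the buffer rather than owning its own memory.
payload[offset:end].copy_(param.flatten())
param.data = payload[offset:end].view(param.shape)

assert param.is_contiguous() and param.data_ptr() == payload.data_ptr()
```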

View File

@@ -54,5 +54,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size()
+    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
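The strengthened assert now reports the offending values, e.g. `[3] 4` for three partitions on a four-rank group. A hedged usage sketch of the constraint it enforces; the `colossalai.tensor.distspec` import path and the process-group setup are assumptions, and it presumes `torch.distributed` is already initialized with four ranks.

```python
import torch.distributed as dist
from colossalai.tensor import distspec  # assumed import path

pg = dist.new_group(ranks=[0, 1, 2, 3])  # requires an initialized 4-rank job

# prod(num_partitions) must equal pg.size(): 2 * 2 == 4, so this passes.
spec = distspec.shard(pg, dims=[0, -1], num_partitions=[2, 2])

# distspec.shard(pg, dims=[0], num_partitions=[3])  # AssertionError: "[3] 4"
```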

View File

@@ -32,11 +32,11 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)
 
-    def is_1D_col(self):
+    def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
-    def is_1D_row(self):
+    def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
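With the renames in place, the shard direction is read straight off the dist spec: a shard along the last dim (-1) is "1dcol", along the first dim (0) is "1drow". A minimal sketch; the `TensorSpec` constructor usage and import paths are assumptions based only on this diff, and the snippet presumes an initialized four-rank `torch.distributed` job.

```python
import torch.distributed as dist
from colossalai.tensor import distspec, TensorSpec  # assumed import paths

pg = dist.new_group(ranks=[0, 1, 2, 3])  # requires an initialized 4-rank job

col_spec = TensorSpec(distspec.shard(pg, dims=[-1], num_partitions=[4]))  # assumed ctor
row_spec = TensorSpec(distspec.shard(pg, dims=[0], num_partitions=[4]))

assert col_spec.is_shard_1dcol() and not col_spec.is_shard_1drow()
assert row_spec.is_shard_1drow() and not row_spec.is_shard_1dcol()
```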