[Tensor] rename some APIs in TensorSpec and Polish view unittest (#1176)
@@ -116,6 +116,7 @@ class Chunk:
         if self.is_src_rank:
             self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
+            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
             tensor.storage().resize_(0)
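The hunk above guards the point where an appended tensor's `.data` is remapped onto the chunk's flat payload buffer. Below is a minimal sketch of that aliasing pattern, with stand-in names in place of the real Chunk internals (`payload`, `offset`, `end` are my own placeholders, not repo code):

import torch

# Stand-in for self._payload: one flat buffer that will own the memory
# of the logical tensor appended to the chunk.
payload = torch.zeros(8)
t = torch.ones(2, 2)
offset, end = 0, t.numel()   # stand-ins for the utilized_size bounds

payload[offset:end].copy_(t.flatten())       # copy the data into the chunk slice
t.data = payload[offset:end].view(t.shape)   # alias the tensor onto that slice

t += 1                                       # writes now land in the shared payload
assert payload[0].item() == 2.0

After the remap, the original storage of `t` is unreferenced and the chunk slice is the single source of truth, which is why the non-source branch can safely call `tensor.storage().resize_(0)`.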
@@ -131,6 +132,7 @@ class Chunk:
         self._update_tensors_state(TensorState.FREE)
 
     def _update_tensors_ptr(self) -> None:
+        assert type(self._payload) == torch.Tensor
         for tensor, tensor_info in self.tensors_info.items():
             tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
@@ -228,7 +230,7 @@ class Chunk:
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         tensor_info = self.tensors_info[tensor]
-        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
         tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
     @property
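This hunk swaps `view(-1)` for `flatten()`. The short illustration below is mine, not part of the commit: `view(-1)` requires contiguous memory and raises on, say, a transposed tensor, while `flatten()` falls back to returning a contiguous copy, so the `copy_` into the chunk slice succeeds either way:

import torch

x = torch.arange(6).reshape(2, 3).t()   # transpose -> non-contiguous
try:
    x.view(-1)                          # raises: view needs compatible strides
except RuntimeError as e:
    print("view(-1) fails:", e)
print(x.flatten())                      # works: returns a contiguous copy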
@@ -54,5 +54,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size()
+    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
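The only change here is an f-string message on the existing assert, so a mismatch now reports both operands. A hedged illustration of the invariant it checks, with made-up values rather than a real process group: the partition counts must multiply out to the group size.

from math import prod

# Illustrative values, not from the commit: sharding one dim into 4
# partitions is only valid on a process group of exactly 4 ranks.
num_partitions, group_size = [4], 4
assert prod(num_partitions) == group_size, f"{num_partitions} {group_size}"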
@@ -32,11 +32,11 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)
 
-    def is_1D_col(self):
+    def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
-    def is_1D_row(self):
+    def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
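The last hunk renames the two shard predicates: `is_shard_1dcol` checks a one-dimensional shard along the last dim (`dims[0] == -1`) and `is_shard_1drow` along the first (`dims[0] == 0`). A hypothetical call site after the rename (the dispatch helper below is my own illustration, not repo code):

# `spec` is assumed to be a TensorSpec instance from the hunk above.
def choose_parallel_path(spec) -> str:
    if spec.is_shard_1dcol():    # formerly spec.is_1D_col()
        return "column-parallel"
    if spec.is_shard_1drow():    # formerly spec.is_1D_row()
        return "row-parallel"
    return "replicated"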