[NFC] polish colossalai/gemini/update/chunkv2.py code style (#1565)

This commit is contained in:
Zangwei Zheng 2022-09-08 16:23:41 +08:00 committed by Frank Lee
parent f586887a90
commit 9823cbf24b

View File

@@ -9,6 +9,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
class ChunkV2:
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
@@ -177,9 +178,7 @@ class ChunkV2:
shard_dev = torch.device('cpu')
if self.pin_memory or shard_dev.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size,
dtype=self.dtype,
pin_memory=self.pin_memory)
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
@@ -260,8 +259,7 @@ class ChunkV2:
# we use all-reduce here
dist.all_reduce(self.chunk_total, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(
self.shard_size, dtype=self.dtype, device=get_current_device())
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -346,8 +344,7 @@ class ChunkV2:
self.chunk_total = self.cuda_shard
else:
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(
input=self.chunk_total, chunks=self.pg_size, dim=0))
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
@@ -361,9 +358,7 @@ class ChunkV2:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size,
dtype=self.dtype,
device=self.chunk_total.device)
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
@@ -412,15 +407,15 @@ class ChunkV2:
def __repr__(self, detailed: bool = False):
output = [
"AgChunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
self.chunk_size, self.dtype, self.pg_size),
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
]
def print_tensor(tensor, prefix=''):
output.append("{}shape: {}, dtype: {}, device: {}\n".format(
prefix, tensor.shape, tensor.dtype, tensor.device))
output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
tensor.device))
if self.chunk_temp is not None:
output.append("\tchunk temp:\n")