mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -20,27 +20,28 @@ class ChunkManager:
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_configuration: Dict[int, dict], init_device: Optional[torch.device] = None) -> None:
    """Set up the chunk manager from a chunk configuration.

    Args:
        chunk_configuration: mapping from a DP-degree key (int) to a kwargs
            dict for that chunk group. Each dict must contain a ``chunk_size``
            entry; it is popped here, and the dict is mutated in place to
            carry ``init_device``.
        init_device: device chunks are initialized on. Falls back to
            ``get_current_device()`` when not given (presumably the current
            CUDA device when one is available — confirm against that helper).
    """
    self.device = init_device or get_current_device()
    self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
    self.kwargs_config = chunk_configuration
    # NOTE: this loop mutates the caller's configuration dicts in place:
    # 'chunk_size' is removed and 'init_device' is injected for later use
    # when chunks of this group are constructed.
    for k, v in self.kwargs_config.items():
        self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
        v["init_device"] = self.device

    # Bookkeeping: chunks grouped by group name, tensor -> owning chunk
    # lookup, the set of chunks currently opened for compute, and running
    # memory counters per device type.
    self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
    self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
    self.accessed_chunks: Set[Chunk] = set()
    self.accessed_mem: int = 0
    self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
||||
def register_tensor(self,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
def register_tensor(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Register a tensor to the chunk manager.
|
||||
Then, the tensor should be accessed by `get_chunks`.
|
||||
@@ -94,25 +95,22 @@ class ChunkManager:
|
||||
self.tensor_chunk_map[tensor] = chunk_group[-1]
|
||||
|
||||
def close_all_groups(self):
    """Close the open chunk of every chunk group.

    NOTE(review): despite the original docstring ("close all the chunks of
    all groups"), only the last chunk of each group is passed to
    ``__close_one_chunk`` — earlier chunks are presumably closed when a new
    chunk is appended to the group; confirm against ``register_tensor``.
    """
    # Iterate the deques directly instead of looking each one up by key.
    for chunks in self.chunk_groups.values():
        self.__close_one_chunk(chunks[-1])
||||
def access_chunk(self, chunk: Chunk) -> None:
    """Make the chunk usable for calculation.

    Idempotent: a chunk that is already accessed is left untouched.
    """
    if chunk in self.accessed_chunks:
        return
    # Memory accounting brackets the state change: remove the old usage,
    # re-add the new one afterwards, because accessing may move the chunk
    # and therefore change where/how much memory it occupies.
    self.__sub_memory_usage(chunk.memory_usage)
    if chunk.device_type == "cpu":
        # The chunk's shard lives on CPU — bring it to the current device
        # before it is used for compute.
        chunk.shard_move(get_current_device())
    self.__add_accessed_chunk(chunk)
    self.__add_memory_usage(chunk.memory_usage)
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""Scatter the chunk in CUDA.
|
||||
"""
|
||||
"""Scatter the chunk in CUDA."""
|
||||
if chunk not in self.accessed_chunks:
|
||||
return
|
||||
if chunk.can_release:
|
||||
@@ -121,8 +119,7 @@ class ChunkManager:
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
|
||||
"""Move the shard of the chunk to the target device.
|
||||
"""
|
||||
"""Move the shard of the chunk to the target device."""
|
||||
if not chunk.can_move or chunk.device_type == device.type:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
@@ -130,14 +127,12 @@ class ChunkManager:
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
    """Transit tensor state according to pre-defined state machine.

    Delegates to the chunk that owns ``tensor``; raises ``KeyError`` if the
    tensor was never registered with this manager.
    """
    self.tensor_chunk_map[tensor].tensor_trans_state(tensor, state)
|
||||
def reduce_chunk(self, chunk: Chunk) -> bool:
|
||||
"""Reduce or all reduce the chunk.
|
||||
"""
|
||||
"""Reduce or all reduce the chunk."""
|
||||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
@@ -213,18 +208,17 @@ class ChunkManager:
|
||||
|
||||
def __repr__(self) -> str:
    """Render a human-readable summary: total memory per device type,
    followed by every chunk of every group."""
    totals = ", ".join(f"{k}={v}B" for k, v in self.total_mem.items())
    parts = [
        "Chunk Manager Information:\n",
        "Total memory: " + totals + "\n",
    ]
    for group_name, group in self.chunk_groups.items():
        parts.append(f"Group {group_name}:\n")
        parts.extend(f"[{i}] {chunk}\n" for i, chunk in enumerate(group))
    return "".join(parts)
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
    """Return the chunk deque registered under ``group_name``.

    An empty group is created and registered on first use, so this never
    fails for an unknown name.
    """
    try:
        return self.chunk_groups[group_name]
    except KeyError:
        self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]
|
||||
|
Reference in New Issue
Block a user