[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -20,27 +20,28 @@ class ChunkManager:
"""
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size")
v["init_device"] = self.device
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
def register_tensor(self,
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
def register_tensor(
self,
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
"""
Register a tensor to the chunk manager.
Then, the tensor should be accessed by `get_chunks`.
@@ -94,25 +95,22 @@ class ChunkManager:
self.tensor_chunk_map[tensor] = chunk_group[-1]
def close_all_groups(self):
"""Close all the chunks of all groups.
"""
"""Close all the chunks of all groups."""
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
"""Make the chunk can be used for calculation.
"""
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
if chunk.device_type == "cpu":
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
"""
"""Scatter the chunk in CUDA."""
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
@@ -121,8 +119,7 @@ class ChunkManager:
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
"""Move the shard of the chunk to the target device."""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memory_usage(chunk.memory_usage)
@@ -130,14 +127,12 @@ class ChunkManager:
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""Transit tensor state according to pre-defined state machine.
"""
"""Transit tensor state according to pre-defined state machine."""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.
"""
"""Reduce or all reduce the chunk."""
if not chunk.can_reduce:
return False
self.__sub_memory_usage(chunk.memory_usage)
@@ -213,18 +208,17 @@ class ChunkManager:
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
"Chunk Manager Information:\n",
"Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n",
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
msg.append(f"Group {group_name}:\n")
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
msg.append(f"[{i}] {chunk}\n")
return "".join(msg)
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
"""Register a chunk group.
"""
"""Register a chunk group."""
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
return self.chunk_groups[group_name]