Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
```diff
@@ -26,12 +26,13 @@ class GeminiManager:
         memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
     """
 
-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 memstats: Optional[MemStats] = None,
-                 **placement_kwargs) -> None:
+    def __init__(
+        self,
+        placement_policy: str,
+        chunk_manager: ChunkManager,
+        memstats: Optional[MemStats] = None,
+        **placement_kwargs,
+    ) -> None:
         assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
```
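The signature change above is purely stylistic: the removed lines follow a paren-aligned continuation style (consistent with the yapf setup this commit retires), while the added lines follow black's convention of one parameter per line, a trailing comma, and a dedented closing parenthesis. A minimal, self-contained sketch of the two styles on a hypothetical function (not ColossalAI code):

```python
# Hypothetical function, for illustration only; not part of GeminiManager.

# Old style: continuation lines aligned under the opening parenthesis.
def build_manager_old(placement_policy: str,
                      memstats=None,
                      **placement_kwargs) -> dict:
    return {"policy": placement_policy, "memstats": memstats, **placement_kwargs}


# black style: one parameter per line, trailing comma,
# closing parenthesis dedented to the level of the "def".
def build_manager_new(
    placement_policy: str,
    memstats=None,
    **placement_kwargs,
) -> dict:
    return {"policy": placement_policy, "memstats": memstats, **placement_kwargs}
```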
```diff
@@ -39,8 +40,9 @@ class GeminiManager:
 
         self._premade_memstats_ = memstats is not None
         self._memstats = memstats
-        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
-                                                           self._memstats) if policy_cls.need_mem_stats else None
+        self._mem_stats_collector = (
+            ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
+        )
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
```
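Nothing changes behaviourally in this hunk either: the conditional expression is merely wrapped in parentheses so black can split it across lines. The pattern itself, creating a memory-stats collector only when the selected placement policy declares `need_mem_stats`, can be sketched in isolation (the classes below are placeholders, not the real ColossalAI policies):

```python
class _StaticLikePolicy:
    need_mem_stats = False  # placement decided up front, no runtime stats needed

class _AutoLikePolicy:
    need_mem_stats = True   # placement driven by collected memory statistics

def build_collector(policy_cls, chunk_manager, memstats=None):
    # Same shape as the __init__ line above: a collector only if the policy needs one.
    return {"chunk_manager": chunk_manager, "memstats": memstats} if policy_cls.need_mem_stats else None

assert build_collector(_StaticLikePolicy, chunk_manager=object()) is None
assert build_collector(_AutoLikePolicy, chunk_manager=object()) is not None
```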
```diff
@@ -62,7 +64,7 @@ class GeminiManager:
 
     @property
     def need_warmup(self) -> bool:
-        return self.policy_name in ('auto', 'const')
+        return self.policy_name in ("auto", "const")
 
     def is_warmup(self):
         return self._warmup
```
```diff
@@ -85,15 +87,14 @@ class GeminiManager:
             self._mem_stats_collector.start_collection()
 
     def post_iter(self):
-        """This function must be called when each iteration finishes
-        """
+        """This function must be called when each iteration finishes"""
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.finish_collection()
         self._warmup = False
         self.reset_attributes()
 
     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of stateful tensors according to the information provided
+        """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
```
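`post_iter` is part of the warmup protocol visible in the surrounding methods: policies covered by `need_warmup` ("auto", "const") collect memory statistics during the first iteration, and calling `post_iter` at the end of every iteration is what flips the manager out of warmup. A hedged sketch of how a training loop would drive this (the loop itself is illustrative, not taken from ColossalAI):

```python
def run_steps(manager, num_steps):
    for step in range(num_steps):
        if manager.need_warmup and manager.is_warmup():
            # First iteration for "auto"/"const": memory statistics are still
            # being collected, so eviction decisions are not yet stats-driven.
            pass
        # ... forward / backward for this step would run here ...
        manager.post_iter()  # must be called when each iteration finishes
```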
```diff
@@ -102,11 +103,13 @@ class GeminiManager:
         cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         self._layout_time += time() - start
 
-        vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
-                                                               cuda_demand=cuda_demand,
-                                                               warmup=self._warmup,
-                                                               compute_list=self._compute_list,
-                                                               compute_idx=self._compute_idx)
+        vol, evict_time = self._placement_policy.evict_tensors(
+            can_evict_chunks=hold_cuda_tensor_list,
+            cuda_demand=cuda_demand,
+            warmup=self._warmup,
+            compute_list=self._compute_list,
+            compute_idx=self._compute_idx,
+        )
 
         self._d2h_volume += vol
         self._evict_time += evict_time
```
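The reflowed call also spells out the contract that `adjust_layout` relies on: `evict_tensors` receives the evictable CUDA chunks plus the computed CUDA demand and returns how many bytes were moved to the host and how long the eviction took, which are accumulated into `_d2h_volume` and `_evict_time`. A toy stand-in for that contract (not the real PlacementPolicy implementation):

```python
from time import time

class _ToyPolicy:
    def evict_tensors(self, can_evict_chunks, cuda_demand, warmup, compute_list, compute_idx):
        start = time()
        # Pretend to evict everything we are allowed to whenever there is any demand.
        evicted = sum(c["mem"] for c in can_evict_chunks) if cuda_demand > 0 else 0
        return evicted, time() - start

vol, evict_time = _ToyPolicy().evict_tensors(
    can_evict_chunks=[{"mem": 1 << 20}],
    cuda_demand=1 << 21,
    warmup=False,
    compute_list=[],
    compute_idx=-1,
)
print(vol, evict_time)  # 1048576 bytes and a small elapsed time
```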
```diff
@@ -118,12 +121,12 @@ class GeminiManager:
         start = time()
         cuda_demand = 0
         for chunk in chunks:
-            if chunk.device_type == 'cuda':
+            if chunk.device_type == "cuda":
                 if chunk.is_gathered:
                     pass
                 else:
                     cuda_demand += chunk.chunk_mem - chunk.shard_mem
-            elif chunk.device_type == 'cpu':
+            elif chunk.device_type == "cpu":
                 cuda_demand += chunk.chunk_mem
             else:
                 raise RuntimeError
```
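The loop above computes how much extra CUDA memory the upcoming compute needs: a gathered CUDA chunk costs nothing, a sharded CUDA chunk needs the difference between its full size and its local shard, and a CPU chunk must be moved over entirely. A self-contained sketch of the same arithmetic (the Chunk stand-in below is hypothetical):

```python
from dataclasses import dataclass

@dataclass
class FakeChunk:
    device_type: str   # "cuda" or "cpu"
    is_gathered: bool
    chunk_mem: int
    shard_mem: int

def cuda_demand_of(chunks):
    demand = 0
    for chunk in chunks:
        if chunk.device_type == "cuda":
            if not chunk.is_gathered:
                demand += chunk.chunk_mem - chunk.shard_mem  # gather the missing shards
        elif chunk.device_type == "cpu":
            demand += chunk.chunk_mem                        # whole chunk must move to CUDA
        else:
            raise RuntimeError(f"unexpected device type: {chunk.device_type}")
    return demand

# e.g. a sharded 1 MiB CUDA chunk (256 KiB local) plus a 512 KiB CPU chunk:
print(cuda_demand_of([
    FakeChunk("cuda", False, 1 << 20, 1 << 18),
    FakeChunk("cpu", False, 1 << 19, 1 << 19),
]))  # 786432 + 524288 = 1310720
```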
```diff
@@ -159,6 +162,7 @@ class GeminiManager:
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
 
-    def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
-                                                                                     torch.device]) -> None:
+    def setup_grads_device(
+        self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
+    ) -> None:
         self._placement_policy.setup_grads_device(params, grads_device_map)
```
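For orientation, the reformatted signature implies a simple mapping contract: each parameter tensor is associated with the device its gradient should be kept on, and the placement policy fills or consumes that mapping. A hedged sketch of building such a map (how the real policy populates it is not shown in this diff):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]
grads_device_map: dict = {}
for p in params:
    grads_device_map[p] = torch.device("cpu")  # e.g. keep gradients on the host
```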