From 54aabb8da461524a526302a1ec3b5d46b649feeb Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 22 Jun 2022 11:54:36 +0800
Subject: [PATCH] [gemini] refactor gemini mgr (#1151)

* refactor gemini mgr

* update __init__
---
 colossalai/gemini/__init__.py           |  3 ++-
 colossalai/gemini/gemini_mgr.py         | 22 ++++++++++++----------
 colossalai/gemini/placement_policy.py   |  5 ++---
 colossalai/nn/parallel/data_parallel.py |  2 +-
 4 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py
index 8fe68cbb3..b3ea38935 100644
--- a/colossalai/gemini/__init__.py
+++ b/colossalai/gemini/__init__.py
@@ -1,4 +1,5 @@
 from .stateful_tensor_mgr import StatefulTensorMgr
 from .tensor_placement_policy import TensorPlacementPolicyFactory
+from .gemini_mgr import GeminiManager
 
-__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory']
\ No newline at end of file
+__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager']
diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 481761c37..42ae598db 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -1,4 +1,5 @@
 import torch
+import functools
 from .memory_tracer.memstats_collector import MemStatsCollectorV2
 from typing import List, Optional, Tuple
 from time import time
@@ -23,10 +24,12 @@ class GeminiManager:
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
 
-        self._cpu_gpu_move_volume = 0
+        self._h2d_volume = 0
+        self._d2h_volume = 0
         self._layout_time = 0
         self._evict_time = 0
         self._warmup = True
+        self._comp_cuda_demand_time = 0
 
     def pre_iter(self):
         if self._mem_stats_collector and self._warmup:
@@ -39,9 +42,11 @@ class GeminiManager:
             self._mem_stats_collector.finish_collection()
         self._warmup = False
         self._compute_idx = -1
-        self._cpu_gpu_move_volume = 0
+        self._h2d_volume = 0
+        self._d2h_volume = 0
         self._layout_time = 0
         self._evict_time = 0
+        self._comp_cuda_demand_time = 0
 
     def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
@@ -57,22 +62,19 @@ class GeminiManager:
                                                                warmup=self._warmup,
                                                                compute_list=self._compute_list,
                                                                compute_idx=self._compute_idx)
-        self._cpu_gpu_move_volume += vol
+        self._d2h_volume += vol
         self._evict_time += evict_time
         # move COMPUTE tensors to CUDA
-        self._cpu_gpu_move_volume += cuda_demand
+        self._h2d_volume += cuda_demand
 
-    @property
-    def cpu_gpu_move_volume(self):
-        return self._cpu_gpu_move_volume
-
-    # @functools.lru_cache(maxsize=None)
-    # TODO: test lru
+    @functools.lru_cache(maxsize=None)
     def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
+        start = time()
         cuda_demand = 0
         for chunk in chunks:
             if chunk.device_type == 'cpu' or chunk.is_empty:
                 cuda_demand += chunk.mem
+        self._comp_cuda_demand_time += time() - start
         can_evict_chunks = []
         for chunk in self._chunk_manager.chunk_groups[group_name]:
             if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py
index 28b841c8d..7e8a0fc61 100644
--- a/colossalai/gemini/placement_policy.py
+++ b/colossalai/gemini/placement_policy.py
@@ -102,7 +102,7 @@ class AutoPlacementPolicy(PlacementPolicy):
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
         freed_cuda_model_data = 0
-        end = time()
+
         if avail_cuda_model_data < cuda_demand:
             # Move cuda_demand - avail_cuda_model_data volume of tensors
             # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
@@ -111,7 +111,6 @@ class AutoPlacementPolicy(PlacementPolicy):
             if not warmup:
                 to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
                 # print(self._sort_can_evict_chunks.cache_info())
-            end = time()
             for chunk in to_free_chunks:
                 if freed_cuda_model_data >= to_free_cuda_model_data:
                     break
@@ -121,7 +120,7 @@ class AutoPlacementPolicy(PlacementPolicy):
                 raise RuntimeError(
                     f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                 )
-        return freed_cuda_model_data, end - start
+        return freed_cuda_model_data, time() - start
 
     @staticmethod
     @functools.lru_cache(maxsize=None)
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 0c0a3e33a..f983a782c 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -225,7 +225,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.exec_lazy_release()
         self._setup_grads_ptr()
         self._logger.debug(
-            f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
+            f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}B'
         )
         self.gemini_manager.post_iter()
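
Reviewer note: the core bookkeeping change in this patch splits the old bidirectional
_cpu_gpu_move_volume counter into _h2d_volume (CPU->CUDA uploads, driven by
cuda_demand) and _d2h_volume (CUDA->CPU evictions, reported by the placement
policy), and adds _comp_cuda_demand_time to time the CUDA-demand scan. The sketch
below is a minimal self-contained illustration of that pattern, not ColossalAI
code: FakeChunk and VolumeTracker are hypothetical stand-ins for Chunk and
GeminiManager.

from time import time
from typing import Tuple


class FakeChunk:
    """Hypothetical stand-in for colossalai's Chunk: a size plus a device tag."""

    def __init__(self, mem: int, device_type: str):
        self.mem = mem                  # payload size in bytes
        self.device_type = device_type  # 'cpu' or 'cuda'


class VolumeTracker:
    """One counter per transfer direction, mirroring GeminiManager's new fields."""

    def __init__(self):
        self._h2d_volume = 0               # bytes moved CPU -> CUDA (host to device)
        self._d2h_volume = 0               # bytes moved CUDA -> CPU (device to host)
        self._comp_cuda_demand_time = 0.0  # seconds spent computing the CUDA demand

    def record_eviction(self, evicted_bytes: int) -> None:
        # Evictions go device -> host; previously they were lumped together
        # with uploads in a single _cpu_gpu_move_volume counter.
        self._d2h_volume += evicted_bytes

    def record_uploads(self, chunks: Tuple[FakeChunk, ...]) -> None:
        # Chunks needed for compute but resident on CPU must be uploaded.
        start = time()
        cuda_demand = sum(c.mem for c in chunks if c.device_type == 'cpu')
        self._comp_cuda_demand_time += time() - start
        self._h2d_volume += cuda_demand


tracker = VolumeTracker()
tracker.record_eviction(4096)
tracker.record_uploads((FakeChunk(1024, 'cpu'), FakeChunk(2048, 'cuda')))
print(f'CPU->CUDA vol: {tracker._h2d_volume}B, CUDA->CPU vol: {tracker._d2h_volume}B')
# prints: CPU->CUDA vol: 1024B, CUDA->CPU vol: 4096B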
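
The other change worth calling out is that _get_layout_info is now actually
memoized with @functools.lru_cache(maxsize=None), which the old code carried only
as a commented-out decorator and a "TODO: test lru" note. Memoization is sound
here only because every argument is hashable; in particular, chunks is typed
Tuple[Chunk, ...] rather than a list. A minimal sketch of the same idea, with a
hypothetical pure cuda_demand function standing in for _get_layout_info:

import functools
from typing import Tuple


@functools.lru_cache(maxsize=None)
def cuda_demand(chunk_sizes: Tuple[int, ...], on_cpu: Tuple[bool, ...]) -> int:
    # For a fixed compute step the inputs fully determine the answer,
    # so repeated lookups can be served from the cache.
    return sum(size for size, cpu in zip(chunk_sizes, on_cpu) if cpu)


sizes = (1024, 2048, 512)
placement = (True, False, True)
assert cuda_demand(sizes, placement) == 1536   # computed on the first call
assert cuda_demand(sizes, placement) == 1536   # served from the cache
print(cuda_demand.cache_info())                # hits=1, misses=1
# Passing lists instead of tuples would raise "TypeError: unhashable type".

One caveat of the real decorator: cache hits skip the function body, so side
effects such as the _comp_cuda_demand_time accumulation in _get_layout_info only
run on cache misses, and with maxsize=None the cache holds a strong reference to
every (self, chunks, ...) key for the lifetime of the process.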