From 56b8863b87b2a8dabe2a9a5acbc0c1c6ae288493 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 2 Aug 2022 10:40:27 +0800
Subject: [PATCH] [zero] chunk manager allows filtering ex-large params (#1393)

---
 colossalai/gemini/chunk_mgr.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/colossalai/gemini/chunk_mgr.py b/colossalai/gemini/chunk_mgr.py
index 2fb8772a0..4e236e5cd 100644
--- a/colossalai/gemini/chunk_mgr.py
+++ b/colossalai/gemini/chunk_mgr.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 
@@ -61,9 +62,6 @@ class ChunkManager:
         if isinstance(tensor, ColoTensor):
             assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
             ), f"Chunk Manager can only manage ColoTensor with the same DP process group"
-        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
-            raise ValueError(
-                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
         try:
             # append the tensor to the last chunk
             self.chunk_groups[group_name][-1].append(tensor)
@@ -71,7 +69,10 @@
             # the except statement will be triggered when there is no chunk or
             # the last chunk in the chunk group is full
             # this will create a new chunk and allocate this chunk to its corresponding process
-            chunk_size = self.chunk_size or tensor.numel()
+            if self.chunk_size is not None and tensor.numel() > self.chunk_size:
+                chunk_size = tensor.numel()
+            else:
+                chunk_size = self.chunk_size or tensor.numel()
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size,
                           src_rank,
@@ -263,7 +264,8 @@
     def search_chunk_size(module: torch.nn.Module,
                           search_range: int,
                           n_grids: int,
-                          min_chunk_size: Optional[int] = None) -> int:
+                          min_chunk_size: Optional[int] = None,
+                          filter_exlarge_params: bool = True) -> int:
         """
         Search for the chunk size for optimal chunk utilization.
 
@@ -278,6 +280,8 @@
         assert search_range % n_grids == 0
         # TODO(ver217): sort params and filter unused ones
         params_numel = [p.numel() for p in module.parameters()]
+        if filter_exlarge_params:
+            params_numel = _filter_exlarge_params(params_numel)
         max_param_numel = max(params_numel)
         if min_chunk_size is not None:
             assert min_chunk_size >= max_param_numel
@@ -330,3 +334,11 @@
         """
         assert tensor not in self.tensor_chunk_map
         self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
+
+
+def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
+    params_numel_arr = np.array(params_numel)
+    std = np.std(params_numel_arr)
+    mean = np.mean(params_numel_arr)
+    upper_limit = mean + 3 * std
+    return list(filter(lambda x: x <= upper_limit, params_numel))
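
Note on the allocation change: registering a tensor larger than the configured chunk size no longer raises ValueError; the except branch now sizes a dedicated chunk to fit the oversized tensor. Below is a minimal sketch of the new size-selection rule, extracted into a hypothetical standalone helper `_pick_chunk_size` (not part of the patch) so it can be exercised in isolation:

```python
from typing import Optional

def _pick_chunk_size(chunk_size: Optional[int], tensor_numel: int) -> int:
    # Mirrors the new branch in ChunkManager: an ex-large tensor gets a
    # dedicated chunk of its own size instead of raising ValueError.
    if chunk_size is not None and tensor_numel > chunk_size:
        return tensor_numel
    return chunk_size or tensor_numel

assert _pick_chunk_size(1024, 4096) == 4096  # oversized: dedicated chunk (was ValueError)
assert _pick_chunk_size(1024, 512) == 1024   # normal: the fixed chunk size wins
assert _pick_chunk_size(None, 512) == 512    # no fixed size: chunk matches the tensor
```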
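
And a sketch of what `_filter_exlarge_params` does inside `search_chunk_size`: parameter sizes more than three standard deviations above the mean are dropped from the search, so a single huge parameter (e.g. an embedding table) no longer inflates `max_param_numel` and hence the searched chunk size. The sample sizes below are made up for illustration:

```python
import numpy as np

# Hypothetical parameter sizes: twenty small layers plus one ex-large embedding.
params_numel = [4096] * 20 + [50_000_000]

arr = np.array(params_numel)
upper_limit = arr.mean() + 3 * arr.std()  # same 3-sigma cutoff as the patch

filtered = [n for n in params_numel if n <= upper_limit]
print(len(filtered), max(filtered))  # 20 4096 -> the 50M-element outlier is excluded
```

With the outlier filtered out, the search can settle on a chunk size near the typical parameter size, while the oversized parameter itself is handled by the dedicated-chunk path shown above.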