diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index 8167bf507..5d228e937 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -268,3 +268,41 @@ class ChunkManager:
             for i, chunk in enumerate(group):
                 msg += f'[{i}] {chunk}\n'
         return msg
+
+    @staticmethod
+    def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
+        assert len(params_numel) > 0
+        total_size = 0
+        total_utilized_size = 0
+        cur_chunk_utilized_size = 0
+        for size in params_numel:
+            assert chunk_size >= size
+            total_utilized_size += size
+            if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
+                total_size += chunk_size
+                cur_chunk_utilized_size = 0
+            cur_chunk_utilized_size += size
+        return total_utilized_size / total_size
+
+    @staticmethod
+    def search_chunk_size(module: torch.nn.Module,
+                          search_range: int,
+                          n_grids: int,
+                          min_chunk_size: Optional[int] = None) -> int:
+        assert search_range % n_grids == 0
+        # TODO(ver217): sort params and filter unused ones
+        params_numel = [p.numel() for p in module.parameters()]
+        max_param_numel = max(params_numel)
+        if min_chunk_size is not None:
+            assert min_chunk_size >= max_param_numel
+        else:
+            min_chunk_size = max_param_numel
+        step_size = search_range // n_grids
+        max_chunk_util = -1
+        best_chunk_size = -1
+        for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
+            chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
+            if chunk_util > max_chunk_util:
+                max_chunk_util = chunk_util
+                best_chunk_size = chunk_size
+        return best_chunk_size
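
For context, a minimal usage sketch of the two new helpers follows. The toy model and all concrete numbers are illustrative assumptions, not part of this diff; only the import path follows from the file being modified (colossalai/tensor/chunk.py).

import torch.nn as nn

from colossalai.tensor.chunk import ChunkManager

# Hypothetical toy model, used only to produce some parameter sizes.
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 10))
params_numel = [p.numel() for p in model.parameters()]  # [16384, 256, 2560, 10]

# Greedy packing at chunk_size=16384: chunk 1 holds 16384 (full),
# chunk 2 holds 256 + 2560 + 10 = 2826 of 16384, so
# utilization = (16384 + 2826) / (2 * 16384) ≈ 0.586.
util = ChunkManager.get_chunk_util(16384, params_numel)

# Grid search over chunk sizes 16384, 16640, ..., 17408: the smallest
# candidate is max(params_numel) and the step is search_range // n_grids.
# In this toy case utilization is highest at the smallest candidate, so
# 16384 is returned.
best_size = ChunkManager.search_chunk_size(model, search_range=1024, n_grids=4)

Note that ties are resolved toward the smaller candidate: the comparison chunk_util > max_chunk_util is strict, so the search only grows the chunk size when utilization actually improves.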