diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index a82640d67..746b3e02a 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,6 +1,10 @@ -from .chunk import TensorInfo, TensorState +from .chunk import TensorInfo, Chunk, TensorState +from .chunk_mgr import ChunkManager from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory from .gemini_mgr import GeminiManager -__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState'] +__all__ = [ + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk', + 'TensorState' +] diff --git a/colossalai/gemini/chunk.py b/colossalai/gemini/chunk.py new file mode 100644 index 000000000..b454fc988 --- /dev/null +++ b/colossalai/gemini/chunk.py @@ -0,0 +1,316 @@ +import torch +import torch.distributed as dist +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Dict, List + +from colossalai.utils import get_current_device +from colossalai.tensor import ProcessGroup as ColoProcessGroup + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + """ + A chunk is a contiguous memory space which contains multiple tensors. + + Args: + chunk_size (int): the number of elements in a chunk + src_rank (int): the process which owns the chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. + force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False. 
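# [editor's sketch, not part of the patch] A standalone illustration of the storage-resizing
# trick that free_storage/alloc_storage above rely on: the tensor object keeps its shape and
# dtype while its backing memory is released and later re-allocated. Assumes only PyTorch.
import torch

t = torch.empty(1024, dtype=torch.float16)
assert t.storage().size() == 1024        # payload is allocated
t.storage().resize_(0)                   # free the payload, keep the tensor object alive
assert t.storage().size() == 0 and t.numel() == 1024
t.storage().resize_(t.numel())           # re-allocate before the tensor is touched again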
+ """ + + def __init__(self, + chunk_size: int, + src_rank: int, + process_group: ColoProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + force_data_on_cuda: bool = False) -> None: + self.size = chunk_size + self.utilized_size = 0 + self.src_rank = src_rank + self.process_group = process_group + self.is_src_rank = process_group.dp_local_rank() == src_rank + self.global_src_rank = process_group.get_ranks_in_dp()[src_rank] + self.dtype = dtype + device = init_device or get_current_device() + if force_data_on_cuda: + self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device()) + self._cpu_data = torch.empty(chunk_size, dtype=dtype) + if device.type == 'cuda': + free_storage(self._cpu_data) + else: + free_storage(self.data) + else: + self.data = torch.empty(chunk_size, dtype=dtype, device=device) + self._cpu_data = None + + # we only keep the chunk in full in the process by which the tensor is owned + if not self.is_src_rank: + free_storage(self._payload) + + # each tensor is associated with a TensorInfo to track meta info + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + self.mem = self.size * self.data.element_size() + + def append(self, tensor: torch.Tensor) -> None: + """ + Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + assert tensor.dtype == self.dtype + new_utilized_size = self.utilized_size + tensor.numel() + + # raise exception when the chunk size is exceeded + if new_utilized_size > self.size: + raise ChunkFullError + + # set tensor state + tensor_state = TensorState.FREE + + # if the process owns the rank, then copy the tensor to its chunk buffer + # otherwise set its storage size to 0 to reduce memory consumption + if self.is_src_rank: + self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten()) + tensor_state = TensorState.HOLD + assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape) + else: + tensor.storage().resize_(0) + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.utilized_size = new_utilized_size + + def release(self) -> None: + """ + Release the memory space on processes which do not own the chunk. + """ + if not self.is_src_rank: + free_storage(self._payload) + self._update_tensors_state(TensorState.FREE) + + def _update_tensors_ptr(self) -> None: + assert type(self._payload) == torch.Tensor + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + tensor_info.state = next_state + + def access(self) -> None: + """ + Broadcast the chunk to synchronize the tensors across data parallel processes. 
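# [editor's sketch, not part of the patch] The bookkeeping Chunk.append performs above,
# reduced to plain PyTorch: tensors are flattened into one contiguous buffer, their .data
# is re-pointed at a view of that buffer, and (offset, end) pairs record where each lives.
import torch

buffer = torch.empty(16, dtype=torch.float32)
tensors_info, offset = {}, 0
for name, t in [('weight', torch.randn(2, 3)), ('bias', torch.randn(3))]:
    end = offset + t.numel()
    buffer[offset:end].copy_(t.flatten())
    t.data = buffer[offset:end].view(t.shape)   # the tensor now aliases the chunk payload
    tensors_info[name] = (offset, end)
    offset = end
assert tensors_info == {'weight': (0, 6), 'bias': (6, 9)}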
+ """ + # recover the chunk on non-owner processes + # and broadcast the chunk from the source to all processes + if not self.is_src_rank: + alloc_storage(self._payload) + self.move_device(get_current_device(), update_ptr=False) + dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group()) + + # update tensor meta info + self._update_tensors_ptr() + if not self.is_src_rank: + self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) + + def move_device(self, device: torch.device, update_ptr: bool = True) -> None: + """ + Move the chunk to a target device. + + Args: + device (torch.device): the target device for data movement. + """ + if self._payload.device == device: + return + if self._cpu_data is None: + self.data.data = self.data.to(device) + else: + if device.type == 'cuda': + # cpu -> cuda + src = self._cpu_data + dest = self.data + else: + # cuda -> cpu + src = self.data + dest = self._cpu_data + alloc_storage(dest) + dest.copy_(src) + free_storage(src) + + if update_ptr: + self._update_tensors_ptr() + + def reduce(self, is_all_reduce: bool = False) -> None: + """ + Reduce or all-reduce the chunk. + + Args: + is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. + """ + self.move_device(get_current_device(), update_ptr=False) + if is_all_reduce: + dist.all_reduce(self.data, group=self.process_group.dp_process_group()) + else: + dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group()) + self._update_tensors_ptr() + self._update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ + + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + # print( + # f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' + # ) + return + self.tensors_info[tensor].state = tensor_state + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + tensor_info = self.tensors_info[tensor] + self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten()) + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) + + @property + def can_release(self) -> bool: + """ + Check whether the chunk can be released. + """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state != TensorState.HOLD: + return False + return True + + @property + def can_move_device(self) -> bool: + """ + Check whether the chunk can be moved across devices. 
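# [editor's sketch, not part of the patch] The transition check used by tensor_trans_state
# above: only pairs listed in STATE_TRANS are applied, anything else (e.g. a late gradient
# hook asking for READY_FOR_REDUCE -> HOLD_AFTER_BWD) is silently ignored. Assumes a
# colossalai build that contains this patch.
from colossalai.gemini.chunk import STATE_TRANS, TensorState

assert (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD) in STATE_TRANS
assert (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE) in STATE_TRANS
assert (TensorState.READY_FOR_REDUCE, TensorState.HOLD_AFTER_BWD) not in STATE_TRANS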
+ """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE): + return False + return True + + @property + def can_reduce(self) -> bool: + """ + Check whether the chunk can be reduced. + """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state != TensorState.READY_FOR_REDUCE: + return False + return True + + @property + def is_empty(self) -> bool: + """ + Check whether the chunk is empty. + """ + return is_storage_empty(self._payload) + + def __repr__(self) -> str: + return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' + + @property + def has_inf_or_nan(self) -> bool: + """ + Check if the chunk has inf or nan values. + """ + return torch.isinf(self._payload[:self.utilized_size]).any().item() or \ + torch.isnan(self._payload[:self.utilized_size]).any().item() + + def copy_(self, dest_chunk: 'Chunk'): + """ + Copy the data of this chunk to a destination chunk. + """ + assert not self.is_empty + assert not dest_chunk.is_empty + assert self.size == dest_chunk.size + assert self.utilized_size == dest_chunk.utilized_size + self._payload.copy_(dest_chunk._payload) + self._update_tensors_ptr() + + @property + def device_type(self) -> str: + """ + Get the device type of the chunk. + """ + return self._payload.device.type + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + @property + def _payload(self) -> torch.Tensor: + if self._cpu_data is None or is_storage_empty(self._cpu_data): + return self.data + return self._cpu_data diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py deleted file mode 100644 index 8468a6815..000000000 --- a/colossalai/gemini/chunk/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk -from .manager import ChunkManager -from .search_utils import clasify_params, search_chunk_configuration diff --git a/colossalai/gemini/chunk_mgr.py b/colossalai/gemini/chunk_mgr.py new file mode 100644 index 000000000..4e236e5cd --- /dev/null +++ b/colossalai/gemini/chunk_mgr.py @@ -0,0 +1,344 @@ +import torch +import numpy as np +from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable +from collections import deque + +from colossalai.utils import get_current_device +from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor +from .chunk import Chunk, ChunkFullError, TensorState + + +class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_size (int): the size of a chunk. + process_group (ColoProcessGroup): process group of the chunk. + enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. 
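# [editor's sketch, not part of the patch] A hedged example of constructing the manager
# documented above. It assumes a colossalai distributed context has already been launched,
# so ColoProcessGroup() can pick up the default data-parallel group; the chunk size is an
# arbitrary illustrative value in elements.
import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini import ChunkManager

chunk_manager = ChunkManager(chunk_size=1024 * 1024,
                             process_group=ColoProcessGroup(),
                             enable_distributed_storage=True,
                             init_device=torch.device('cpu'))
chunk_manager.create_group('fp16_param', force_data_on_cuda=True)   # group API defined below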
+ """ + + def __init__(self, + chunk_size: Optional[int], + process_group: ColoProcessGroup, + enable_distributed_storage: bool = False, + init_device: Optional[torch.device] = None) -> None: + assert chunk_size is None or chunk_size > 0 + assert isinstance(process_group, ColoProcessGroup) + self.chunk_size = chunk_size + self.process_group = process_group + self.enable_distributed_storage = enable_distributed_storage + self.device = init_device or get_current_device() + self.chunk_groups: Dict[str, Deque[Chunk]] = {} + self.groups_force_data_on_cuda: Dict[str, bool] = {} + self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {} + self.accessed_chunks: Set[Chunk] = set() + self.lazy_release_tensors: List[torch.Tensor] = [] + if enable_distributed_storage and chunk_size is None: + self.rank_load: Dict[str, torch.Tensor] = {} + self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + + def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None: + """Create a chunk group. + + Args: + group_name (str): group name + force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False. + """ + assert group_name not in self.chunk_groups + self.chunk_groups[group_name] = deque() + self.groups_force_data_on_cuda[group_name] = force_data_on_cuda + + def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None: + """ + Append a tensor to a chunk. + + Args: + tensor (torch.Tensor): a tensor to append to the chunk. + group_name (str): the name of the chunk group. + """ + assert tensor not in self.tensor_chunk_map + if isinstance(tensor, ColoTensor): + assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group( + ), f"Chunk Manager can only manage ColoTensor with the same DP process group" + try: + # append the tensor to the last chunk + self.chunk_groups[group_name][-1].append(tensor) + except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process + if self.chunk_size is not None and tensor.numel() > self.chunk_size: + chunk_size = tensor.numel() + else: + chunk_size = self.chunk_size or tensor.numel() + src_rank = self._get_next_src_rank(group_name) + chunk = Chunk(chunk_size, + src_rank, + self.process_group, + tensor.dtype, + self.device, + force_data_on_cuda=self.groups_force_data_on_cuda[group_name]) + + if self.enable_distributed_storage and self.chunk_size is None: + self.rank_load[group_name][src_rank] += chunk_size + + self.chunk_groups[group_name].append(chunk) + chunk.append(tensor) + if not chunk.is_empty: + self.total_mem[chunk.device_type] += chunk.mem + self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1] + if not self.enable_distributed_storage: + # as distributed storage is not enabled, there is no need to broadcast + # chunks, thus we set these chunks as accessed + self.accessed_chunks.add(self.chunk_groups[group_name][-1]) + + def _get_next_src_rank(self, group_name: str) -> int: + if not self.enable_distributed_storage: + # the chunk is owned by the current rank if no distributed storage is enabled + return self.process_group.dp_local_rank() + if self.chunk_size is None: + if group_name not in self.rank_load: + self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64) + + # the process owning the tensor will be the process with the smallest number of 
elements + src_rank = torch.argmin(self.rank_load[group_name]).item() + else: + # chunk is owned by processes in a round-robin fashion + chunk_idx = len(self.chunk_groups[group_name]) + src_rank = chunk_idx % self.process_group.dp_world_size() + return src_rank + + def access_chunk(self, chunk: Chunk) -> None: + """ + Synchronize the chunks via broadcast. + + Args: + chunk (Chunk): the chunk to synchronize. + """ + if chunk in self.accessed_chunks: + if chunk.device_type != 'cuda': + self.total_mem[chunk.device_type] -= chunk.mem + chunk.move_device(get_current_device()) + self.total_mem[chunk.device_type] += chunk.mem + return + if not chunk.is_empty: + # as tensor is moved to the target device + # the memory consumption of the original device is reduced + self.total_mem[chunk.device_type] -= chunk.mem + chunk.access() + self.accessed_chunks.add(chunk) + self.total_mem[chunk.device_type] += chunk.mem + + def release_chunk(self, chunk: Chunk) -> None: + """ + Release the memory space of a chunk. + + Args: + chunk (Chunk): the chunk to release memory space + """ + + if not self.enable_distributed_storage: + return + if chunk not in self.accessed_chunks: + return + if chunk.can_release: + chunk.release() + self.accessed_chunks.remove(chunk) + if chunk.is_empty: + # update the memory consumption after releasing + self.total_mem[chunk.device_type] -= chunk.mem + + def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None: + """ + Move the chunk to the target device. + + Args: + chunk (Chunk): the chunk to move to target device + device (torch.device): target device + """ + if chunk.device_type == device.type: + return + if chunk.can_move_device and not chunk.is_empty: + self.total_mem[chunk.device_type] -= chunk.mem + chunk.move_device(device, update_ptr=update_ptr) + self.total_mem[chunk.device_type] += chunk.mem + + def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """ + Transit tensor state according to pre-defined state machine. + + Args: + tensor (torch.Tensor): the tensor for state transititon + state (TensorState): next tensor state for transtition + """ + chunk = self.tensor_chunk_map[tensor] + chunk.tensor_trans_state(tensor, state) + + def reduce_chunk(self, chunk: Chunk) -> bool: + """ + Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used. + Otherwise, this method uses reduce. + + Args: + chunk (Chunk): the chunk for reduction. + """ + if not chunk.can_reduce: + return False + self.total_mem[chunk.device_type] -= chunk.mem + chunk.reduce(is_all_reduce=not self.enable_distributed_storage) + self.total_mem[chunk.device_type] += chunk.mem + return True + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ + chunk = self.tensor_chunk_map[tensor] + chunk.copy_tensor_to_chunk_slice(tensor, data) + + def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ + return self.tensor_chunk_map[tensor] + + def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None: + """ + Add tensors to the buffer for lazy release. 
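# [editor's sketch, not part of the patch] How the manager methods above fit together for one
# registered parameter. Heavily hedged: it assumes a launched distributed context, a
# ChunkManager instance `chunk_manager`, and a ColoParameter `p` already added with
# append_tensor; reduce_chunk only fires once every tensor in the chunk is ready.
from colossalai.gemini import TensorState

chunk = chunk_manager.get_chunk(p)
chunk_manager.access_chunk(chunk)                          # recover/broadcast the payload
chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)   # pin while the parameter is in use
# ... forward and backward run here ...
chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
if chunk_manager.reduce_chunk(chunk):                      # reduce or all-reduce the gradients
    chunk_manager.release_chunk(chunk)                     # drop the payload on non-owner ranks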
+ + Args: + tensors (List[torch.Tensor]): the tensors to be released lazily + """ + self.lazy_release_tensors.extend(tensors) + + def exec_lazy_release(self) -> None: + """ + Execute release for tensors added to the lazy release buffer. + """ + + for chunk in self.get_chunks(self.lazy_release_tensors): + self.release_chunk(chunk) + self.lazy_release_tensors.clear() + + def __repr__(self) -> str: + msg = f'Rank {self.process_group.dp_local_rank()}:\n' + msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + for group_name, group in self.chunk_groups.items(): + msg += f'Group {group_name}:\n' + for i, chunk in enumerate(group): + msg += f'[{i}] {chunk}\n' + return msg + + @staticmethod + def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float: + """ + Calculate the utilization rate of a chunk. + + Args: + chunk_size (int): the size of a chunk + params_numel (List[int]): the list of integers representing the number of elements of parameters + """ + assert len(params_numel) > 0 + total_size = 0 + total_utilized_size = 0 + cur_chunk_utilized_size = 0 + for size in params_numel: + assert chunk_size >= size + total_utilized_size += size + if total_size == 0 or cur_chunk_utilized_size + size > chunk_size: + total_size += chunk_size + cur_chunk_utilized_size = 0 + cur_chunk_utilized_size += size + return total_utilized_size / total_size + + @staticmethod + def search_chunk_size(module: torch.nn.Module, + search_range: int, + n_grids: int, + min_chunk_size: Optional[int] = None, + filter_exlarge_params: bool = True) -> int: + """ + Search for the chunk size for optimal chunk utilization. + + Args: + module (torch.nn.Module): a torch module object + search_range (int): the range of chunk size to search. The actual search range will be from + max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range. + n_grids (int): the number of intervals in the search range + min_chunk_size (int): optional, the minimum size for a chunk. The default is None. + + """ + assert search_range % n_grids == 0 + # TODO(ver217): sort params and filter unused ones + params_numel = [p.numel() for p in module.parameters()] + if filter_exlarge_params: + params_numel = _filter_exlarge_params(params_numel) + max_param_numel = max(params_numel) + if min_chunk_size is not None: + assert min_chunk_size >= max_param_numel + else: + min_chunk_size = max_param_numel + step_size = search_range // n_grids + max_chunk_util = -1 + best_chunk_size = -1 + for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size): + chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel) + if chunk_util > max_chunk_util: + max_chunk_util = chunk_util + best_chunk_size = chunk_size + return best_chunk_size + + def copy_chunk_group(self, dest_group_name: str, src_group_name: str): + """ + Copy chunk data from one group to another group. + + Args: + dest_group_name (str): the destination group which receives the copied data + src_group_name (str): the source group which provides the data to copy + """ + for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]): + if not dest_chunk.is_empty: + dest_chunk.copy_(src_chunk) + + def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. 
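# [editor's sketch, not part of the patch] A worked example of the static utilization helper
# defined above; it needs no distributed setup, only an importable colossalai.
from colossalai.gemini import ChunkManager

util = ChunkManager.get_chunk_util(chunk_size=1024, params_numel=[300, 500, 400])
# 300 and 500 share the first chunk, 400 spills into a second one:
# 1200 useful elements inside 2048 allocated elements
assert abs(util - 1200 / 2048) < 1e-6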
+ + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ + chunks = [] + for tensor in tensors: + chunk = self.get_chunk(tensor) + if chunk not in chunks: + chunks.append(chunk) + return tuple(chunks) + + def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: + """Add extern static tensor to chunk manager. + Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. + They are "static", which means their shape, dtype, device never change. + Thus, their memory usage never changes. + + Args: + tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. + """ + assert tensor not in self.tensor_chunk_map + self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + + +def _filter_exlarge_params(params_numel: List[int]) -> List[int]: + params_numel_arr = np.array(params_numel) + std = np.std(params_numel_arr) + mean = np.mean(params_numel_arr) + upper_limit = mean + 3 * std + return list(filter(lambda x: x <= upper_limit, params_numel)) diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 0bdddd9a7..4717e6f24 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -3,7 +3,7 @@ import functools from .memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import List, Optional, Tuple from time import time -from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini import Chunk, ChunkManager from .placement_policy import PlacementPolicyFactory @@ -56,44 +56,37 @@ class GeminiManager: self._evict_time = 0 self._comp_cuda_demand_time = 0 - def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None: + def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None: """ Adjust the layout of statefuil tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. 
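# [editor's sketch, not part of the patch] The mean + 3*std outlier filter implemented by
# _filter_exlarge_params above, reduced to plain numpy; the embedding-sized parameter is
# excluded so it does not distort the chunk-size statistics.
import numpy as np

params_numel = [1000] * 12 + [1_000_000]
arr = np.array(params_numel)
upper_limit = arr.mean() + 3 * arr.std()
kept = [n for n in params_numel if n <= upper_limit]
assert 1_000_000 not in kept and len(kept) == 12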
""" # find stateful tensor in state COMPUTE start = time() self._record_chunks_order(chunks) - cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type) + cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name) self._layout_time += time() - start - - vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, + vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list, cuda_demand=cuda_demand, warmup=self._warmup, compute_list=self._compute_list, compute_idx=self._compute_idx) - self._d2h_volume += vol self._evict_time += evict_time # move COMPUTE tensors to CUDA self._h2d_volume += cuda_demand @functools.lru_cache(maxsize=None) - def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str): + def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str): start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == 'cuda': - if chunk.is_gathered: - pass - else: - cuda_demand += chunk.chunk_mem - chunk.shard_mem - elif chunk.device_type == 'cpu': - cuda_demand += chunk.chunk_mem - else: - raise RuntimeError + if chunk.device_type == 'cpu' or chunk.is_empty: + cuda_demand += chunk.mem self._comp_cuda_demand_time += time() - start - - can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type) + can_evict_chunks = [] + for chunk in self._chunk_manager.chunk_groups[group_name]: + if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device: + can_evict_chunks.append(chunk) return cuda_demand, can_evict_chunks def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 4366956fe..1ff88bd3f 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity from colossalai.utils import get_current_device from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini import ChunkManager import torch import time diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index 1a7e172ed..ec6afbc07 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import Type import functools -from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini import Chunk, ChunkManager class PlacementPolicy(ABC): @@ -19,7 +19,7 @@ class PlacementPolicy(ABC): self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector @abstractmethod - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None: raise NotImplementedError @staticmethod @@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy): def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: 
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int: volume = 0 start = time() for chunk in can_evict_chunks: - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - volume += chunk.shard_mem + self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) + volume += chunk.mem return volume, time() - start @@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy): assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int: return 0, 0 @staticmethod @@ -59,8 +59,7 @@ class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase - # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() - # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() + # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio() _warmup_non_model_data_ratio: float = 0.8 _steady_cuda_cap_ratio: float = 0.9 @@ -71,14 +70,14 @@ class AutoPlacementPolicy(PlacementPolicy): can_evict_chunks: List[Chunk], cuda_demand: int = 0, warmup: bool = True, - compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_list: List[Tuple[Chunk, ...]] = [], compute_idx: int = 0, - **kwargs) -> Tuple[int, float]: + **kwargs) -> int: """ Evict tensors from CUDA device. Args: - can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted. + hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. compute_list (List[StatefulTensor], optional): TODO. Defaults to []. @@ -115,12 +114,12 @@ class AutoPlacementPolicy(PlacementPolicy): for chunk in to_free_chunks: if freed_cuda_model_data >= to_free_cuda_model_data: break - - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - freed_cuda_model_data += chunk.shard_mem + freed_cuda_model_data += chunk.mem + self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " - f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! 
Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) return freed_cuda_model_data, time() - start @staticmethod @@ -148,7 +147,7 @@ class AutoPlacementPolicy(PlacementPolicy): class PlacementPolicyFactory: - policies: Dict[str, Type[PlacementPolicy]] = { + policies: Dict[str, PlacementPolicy] = { 'cpu': CPUPlacementPolicy, 'cuda': CUDAPlacementPolicy, 'auto': AutoPlacementPolicy diff --git a/colossalai/gemini/stateful_tensor_container.py b/colossalai/gemini/stateful_tensor_container.py new file mode 100644 index 000000000..c82113028 --- /dev/null +++ b/colossalai/gemini/stateful_tensor_container.py @@ -0,0 +1,131 @@ +import queue +import heapq +from abc import ABC, abstractmethod +from typing import Optional, List, Dict +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState + + +def evict_check(st: StatefulTensor) -> bool: + if st.state is not TensorState.COMPUTE and st.device.type == 'cuda': + return True + return False + + +# Here ST means Stateful Tensor +class BaseSTContainer(ABC): + """A type of container that store all potential stateful tensors which can be evicted from + CUDA. This kind of stateful tensor should satisfy two conditions. One is that it hasn't been + evicted, meaning the type of its device is CUDA, the other is that it isn't pinned in CUDA + memory, meaning its state isn't COMPUTE. + + This container should get a stateful tensor when it become HOLD_LIKE from COMPUTE. + And it pops stateful tensors in function, `evict_tensors`. + + In order to acquire an optimal eviction policy, users may need to offer computation step + index of each stateful tensor. So we can use a heap to maintain all potential evictable + statefule tensors. When poping, we can get the stateful tensor that used furthest in + current computation step. + """ + + def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): + self.compute_step_dict = compute_step_dict + self.total_step = total_step + + @abstractmethod + def empty(self) -> bool: + pass + + @abstractmethod + def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: + pass + + @abstractmethod + def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: + pass + + @abstractmethod + def pop(self) -> Optional[StatefulTensor]: + pass + + +class QueueSTContainer(BaseSTContainer): + """Queue type stateful tensor container. This is used in 'cpu' tensor placement policy. + It pops potential evictable stateful tensors in FIFO. + """ + + def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): + super().__init__(compute_step_dict, total_step) + self.container = None + + def empty(self) -> bool: + assert self.container is not None + return self.container.empty() + + def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: + self.container = queue.SimpleQueue() + for stateful_tensor in stateful_tensor_list: + self.container.put(stateful_tensor) + + def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: + self.container.put(stateful_tensor) + + def pop(self) -> Optional[StatefulTensor]: + ret = None + while not self.empty(): + out_tensor = self.container.get() + if evict_check(out_tensor): + ret = out_tensor + break + + return ret + + +class HeapSTContainer(BaseSTContainer): + """Heap type stateful tensor container. This is used in 'auto' tensor placement policy. + It pops potential evictable stateful tensors in the order of the distance between current + step and next used step. 
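# [editor's sketch, not part of the patch] The eviction order the heap container below builds:
# candidates are keyed by -next_use_step, so popping returns the tensor whose next use lies
# furthest in the future. Tensor names and step lists here are made up for illustration.
import heapq

total_step = 10
compute_steps = {'a': [2, 7], 'b': [3, 4], 'c': [5]}
cur_step = 3

def next_compute_step(steps, cur):
    return next((s for s in steps if s > cur), total_step)

heap = [(-next_compute_step(steps, cur_step), name) for name, steps in compute_steps.items()]
heapq.heapify(heap)
assert heapq.heappop(heap)[1] == 'a'   # next needed at step 7, the furthest away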
+ """ + + def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): + super().__init__(compute_step_dict, total_step) + self.container = None + + def empty(self) -> bool: + assert self.container is not None + return self.container == [] + + def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: + self.container = [] + for stateful_tensor in stateful_tensor_list: + # we want to pop the tensor which has the greatest next_step + # so the weight is next_step multiplied by -1 + weight = -self.__get_next_compute_step(stateful_tensor, -1) + self.container.append((weight, stateful_tensor)) + heapq.heapify(self.container) + + def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: + # we want to pop the tensor which has the greatest next_step + # so the weight is next_step multiplied by -1 + weight = -self.__get_next_compute_step(stateful_tensor, cur_step) + heapq.heappush(self.container, (weight, stateful_tensor)) + + def pop(self) -> Optional[StatefulTensor]: + ret = None + while not self.empty(): + _, out_tensor = heapq.heappop(self.container) + if evict_check(out_tensor): + ret = out_tensor + break + return ret + + def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int): + # compute the id of next step + # if the tensor is not used in the furture + # next_step is set to the maximum + next_step = self.total_step + step_list = self.compute_step_dict[stateful_tensor] + for step in step_list: + if step > cur_step: + next_step = step + break + return next_step diff --git a/colossalai/gemini/update/__init__.py b/colossalai/gemini/update/__init__.py new file mode 100644 index 000000000..20e3abccb --- /dev/null +++ b/colossalai/gemini/update/__init__.py @@ -0,0 +1,3 @@ +from .chunkv2 import ChunkV2 +from .chunk_mgrv2 import ChunkManagerV2 +from .search_utils import clasify_params, search_chunk_configuration diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/update/chunk_mgrv2.py similarity index 80% rename from colossalai/gemini/chunk/manager.py rename to colossalai/gemini/update/chunk_mgrv2.py index 2d75dcce5..d6cd0745c 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/update/chunk_mgrv2.py @@ -4,19 +4,23 @@ from collections import deque from colossalai.utils import get_current_device from colossalai.tensor import ColoTensor -from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk +from colossalai.gemini.chunk import ChunkFullError, TensorState +from colossalai.gemini.update import ChunkV2 as Chunk -class ChunkManager: +class ChunkManagerV2: """ A manager class to manipulate the tensors in chunks. Args: chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. + pin_memory (bool): if ture, all chunks have a piece of pinned memory in CPU. 
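# [editor's sketch, not part of the patch] A hedged example of wiring the v2 manager described
# above. The configuration dict is written by hand to mirror what search_chunk_configuration
# produces: one entry per DP world size (here assumed to be 4) whose value carries at least
# 'chunk_size' in elements; any remaining keys are forwarded to each chunk's constructor.
import torch
from colossalai.gemini.update import ChunkManagerV2

chunk_config = {4: {'chunk_size': 1024 * 1024}}
chunk_manager = ChunkManagerV2(chunk_config,
                               init_device=torch.device('cpu'),
                               pin_memory=True)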
""" - def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None: + def __init__(self, chunk_configuration: Dict[int, Dict], + init_device: Optional[torch.device] = None, + pin_memory: bool = False) -> None: self.device = init_device or get_current_device() self.size_config: Dict[int, int] = dict() @@ -24,6 +28,7 @@ class ChunkManager: for k, v in self.kwargs_config.items(): self.size_config[k] = v.pop('chunk_size') v['init_device'] = self.device + v['pin_memory'] = pin_memory self.chunk_groups: Dict[str, Deque] = dict() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() @@ -31,14 +36,8 @@ class ChunkManager: self.lazy_release_tensors: List[torch.Tensor] = list() self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None: + def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None: """Append a tensor to a chunk. - - Args: - tensor: the tensor appended to the chunk - group_type: the data type of the group - config_key: the key of the group's name, usually the size of the dp world - pin_memory: whether the chunk is pinned in the cpu memory """ assert tensor not in self.tensor_chunk_map assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" @@ -67,8 +66,7 @@ class ChunkManager: chunk_size=chunk_size, process_group=tensor.process_group, dtype=tensor.dtype, - pin_memory=pin_memory, - **chunk_kwargs, + **chunk_kwargs ) chunk_group.append(chunk) @@ -89,8 +87,6 @@ class ChunkManager: if chunk in self.accessed_chunks: return self.__sub_memroy_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': - chunk.shard_move(get_current_device()) chunk.access_chunk() self.__add_memory_usage(chunk.memory_usage) self.accessed_chunks.add(chunk) @@ -106,13 +102,13 @@ class ChunkManager: self.__add_memory_usage(chunk.memory_usage) self.accessed_chunks.remove(chunk) - def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + def move_chunk(self, chunk: Chunk, device: torch.device) -> None: """Move the shard of the chunk to the target device. 
""" if not chunk.can_move or chunk.device_type == device.type: return self.__sub_memroy_usage(chunk.memory_usage) - chunk.shard_move(device, force_copy) + chunk.shard_move(device) self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: @@ -127,7 +123,7 @@ class ChunkManager: if not chunk.can_reduce: return False self.__sub_memroy_usage(chunk.memory_usage) - chunk.reduce() + chunk.release_chunk() self.__add_memory_usage(chunk.memory_usage) return True @@ -169,14 +165,14 @@ class ChunkManager: self.release_chunk(chunk) self.lazy_release_tensors.clear() - def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]: - chunk_list = [] - for group_name in self.chunk_groups: - if group_type in group_name: - for chunk in self.chunk_groups[group_name]: - if chunk.device_type == 'cuda' and chunk.can_move: - chunk_list.append(chunk) - return chunk_list + def __repr__(self) -> str: + msg = ['Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: """ @@ -204,17 +200,6 @@ class ChunkManager: assert tensor not in self.tensor_chunk_map self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() - def __repr__(self) -> str: - msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' - ] - for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') - for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) - def __get_chunk_group(self, group_name: str) -> Deque: """Register a chunk group. 
""" @@ -223,9 +208,8 @@ class ChunkManager: return self.chunk_groups[group_name] def __close_one_chunk(self, chunk: Chunk): - device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda self.__sub_memroy_usage(chunk.memory_usage) - chunk.close_chunk(device) + chunk.close_chunk(self.device) self.__add_memory_usage(chunk.memory_usage) def __sub_memroy_usage(self, usage: Dict[str, int]): diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/update/chunkv2.py similarity index 75% rename from colossalai/gemini/chunk/chunk.py rename to colossalai/gemini/update/chunkv2.py index 610e83ce5..25f7858ea 100644 --- a/colossalai/gemini/chunk/chunk.py +++ b/colossalai/gemini/update/chunkv2.py @@ -1,55 +1,14 @@ import torch import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum from typing import Optional, Dict, List from colossalai.utils import get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \ + free_storage, alloc_storage -class TensorState(Enum): - FREE = 0 - COMPUTE = 1 - HOLD = 2 - HOLD_AFTER_BWD = 3 - READY_FOR_REDUCE = 4 - - -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) - - -@dataclass -class TensorInfo: - state: TensorState - offset: int - end: int - - -class ChunkFullError(Exception): - pass - - -def is_storage_empty(tensor: torch.Tensor) -> bool: - return tensor.storage().size() == 0 - - -def free_storage(tensor: torch.Tensor) -> None: - if not is_storage_empty(tensor): - tensor.storage().resize_(0) - - -def alloc_storage(tensor: torch.Tensor) -> None: - if is_storage_empty(tensor): - tensor.storage().resize_(tensor.numel()) - - -class Chunk: +class ChunkV2: def __init__(self, chunk_size: int, @@ -60,18 +19,18 @@ class Chunk: pin_memory: bool = False) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors - Here we use all-gather operation to gather the whole chunk. - Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. + AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk. + This kind of chunk is exclusively used for DDP and ZeRO DDP. It is designed to make the full use of communication and PCIE bandwidth. 
Args: - chunk_size (int): the number of elements in the chunk + chunk_size (int): the number of elements in a chunk process_group (ColoProcessGroup): the process group of this chunk dtype (torch.dtype): the data type of the chunk init_device (torch.device): optional, the device where the tensor is initialized The default value is None, which is the current GPU keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory - pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory + pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory """ self.chunk_size = chunk_size @@ -83,8 +42,7 @@ class Chunk: self.pg_rank = dist.get_rank(self.torch_pg) # the chunk size should be able to be divied by the size of GPU - if not keep_gathered: - assert chunk_size % self.pg_size == 0 + assert chunk_size % self.pg_size == 0 self.shard_size = chunk_size // self.pg_size self.shard_begin = self.shard_size * self.pg_rank self.shard_end = self.shard_begin + self.shard_size @@ -122,15 +80,18 @@ class Chunk: # we introduce the paired chunk here # it refers to another chunk having the same parameters - # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + # but with different dtype(such as fp16_chunk.mapping_chunk -> fp32_chunk self.paired_chunk = None + # if the the gradient of this chunk is reduced, the flag is True + # so the flag is False for unused parameters + self.grad_reduced_flag = False # if this chunk is synchronized with the optimizer, the flag is True self.optim_sync_flag = True # if the cpu_shard has been visited during the training step, the flag is True self.cpu_vis_flag = False @property - def memory_usage(self) -> Dict[str, int]: + def memory_usage(self): cuda_memory = 0 cpu_memory = 0 @@ -151,7 +112,7 @@ class Chunk: return dict(cuda=cuda_memory, cpu=cpu_memory) @property - def device_type(self) -> str: + def device_type(self): if self.chunk_temp is not None: return self.chunk_temp.device.type else: @@ -162,56 +123,6 @@ class Chunk: else: return 'cpu' - @property - def payload(self) -> torch.Tensor: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_total - elif self.cuda_shard is not None: - return self.cuda_shard - else: - return self.cpu_shard - - @property - def payload_mem(self) -> int: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_mem - else: - return self.shard_mem - - @property - def can_move(self) -> bool: - return not self.is_gathered - - @property - def can_release(self) -> bool: - if self.keep_gathered: - return False - else: - return self.tensors_state_monitor[TensorState.HOLD] + \ - self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors - - @property - def can_reduce(self): - return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors - - @property - def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values in CUDA. - """ - if self.is_gathered: - valid_tensor = self.chunk_total[:self.utilized_size] - else: - assert self.cuda_shard is not None # only check in CUDA - valid_tensor = self.cuda_shard[:self.valid_end] - - return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() - def append_tensor(self, tensor: torch.Tensor): """Add a tensor to the chunk. 
@@ -239,10 +150,7 @@ class Chunk: self.utilized_size = new_utilized_size def close_chunk(self, shard_dev: Optional[torch.device] = None): - """Close the chunk. Any tensor can't be appended to a closed chunk later. - - Args: - shard_dev: the device where the shard locates + """Close the chunk. Any tensor can't be appended to a closed chunk. """ # sanity check assert self.chunk_temp is not None @@ -255,7 +163,6 @@ class Chunk: if self.chunk_temp.device.type == 'cpu': self.chunk_total = self.chunk_temp.to(get_current_device()) - self.__update_tensors_ptr() else: self.chunk_total = self.chunk_temp self.chunk_temp = None @@ -279,12 +186,6 @@ class Chunk: self.cuda_shard = None def shard_move(self, device: torch.device, force_copy: bool = False): - """Move the shard tensor in the chunk. - - Args: - device: the device to which the shard will move - force_copy: if True, copy function is called mandatorily - """ # sanity check assert not self.is_gathered # when the current chunk is not synchronized with the optimizer @@ -322,7 +223,8 @@ class Chunk: raise NotImplementedError def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. + """Make the chunk usable for the parameters inside it. + It is an operation done in CUDA. """ # sanity check assert self.chunk_temp is None @@ -332,7 +234,8 @@ class Chunk: self.__update_tensors_ptr() def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. + """Release the usable chunk. + It is an operation done in CUDA. """ # sanity check assert self.chunk_temp is None @@ -341,7 +244,8 @@ class Chunk: self.__scatter() def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. + """Reduce scatter all the gradients. + It is an operation done in CUDA. """ # sanity check assert self.is_gathered @@ -363,6 +267,7 @@ class Chunk: free_storage(self.chunk_total) self.is_gathered = False self.__update_tensors_state(TensorState.HOLD) + self.grad_reduced_flag = True def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: """ @@ -380,6 +285,9 @@ class Chunk: # this function only apply valid state transformation # invalid calls will be ignored and nothing changes if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + # print( + # f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' + # ) return self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) @@ -398,56 +306,46 @@ class Chunk: self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. - """ + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: if self.keep_gathered: - return self.utilized_size + return False else: - return self.valid_end + return self.tensors_state_monitor[TensorState.HOLD] + \ + self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. 
+ @property + def can_reduce(self): + return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: """ - if self.paired_chunk is None and friend_chunk.paired_chunk is None: - self.paired_chunk = friend_chunk - friend_chunk.paired_chunk = self - else: - assert self.paired_chunk is friend_chunk - assert friend_chunk.paired_chunk is self - - def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. + Check if the chunk has inf or nan values in CUDA. """ - # sanity check - assert self.paired_chunk is not None - - friend_chunk = self.paired_chunk - if self.is_gathered is True: - assert friend_chunk.is_gathered is True - self.chunk_total.copy_(friend_chunk.chunk_total) - self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': - self.cuda_shard.copy_(friend_chunk.cuda_shard) - self.optim_sync_flag = True - self.cpu_vis_flag = False + if self.is_gathered: + valid_tensor = self.chunk_total[:self.utilized_size] else: - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' - self.optim_sync_flag = False - self.cpu_vis_flag = False + assert self.cuda_shard is not None # only check in CUDA + valid_tensor = self.cuda_shard[:self.valid_end] - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() def __gather(self): if not self.is_gathered: # sanity check assert self.cuda_shard is not None - alloc_storage(self.chunk_total) - gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + if self.pg_size == 1: + self.chunk_total = self.cuda_shard + else: + alloc_storage(self.chunk_total) + gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) self.cuda_shard = None self.is_gathered = True @@ -506,9 +404,9 @@ class Chunk: def __eq__(self, __o: object) -> bool: return self is __o - def __repr__(self, detailed: bool = True): + def __repr__(self, detailed: bool = False): output = [ - "Chunk Information:\n", + "AgChunk Information:\n", "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, self.pg_size), "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( @@ -544,3 +442,6 @@ class Chunk: output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st])) return ''.join(output) + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/update/search_utils.py similarity index 77% rename from colossalai/gemini/chunk/search_utils.py rename to colossalai/gemini/update/search_utils.py index f309872a4..fdbbf0817 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/update/search_utils.py @@ -1,4 +1,3 @@ -import math from typing import Dict, List import numpy as np import torch.nn as nn @@ -8,7 +7,7 @@ from colossalai.tensor import ColoParameter def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: """Filter those parameters whose size is too large from others. 
""" - params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)] + params_size = [p.numel() for p in model.parameters()] params_size_arr = np.array(params_size) std = np.std(params_size_arr) @@ -37,9 +36,6 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: params_dict: Dict[int, List[ColoParameter]] = dict() for param in model.parameters(): assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" - if getattr(param, '_ddp_to_ignore', False): - continue - param_key = param.process_group.dp_world_size() if param_key not in params_dict: @@ -51,13 +47,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: def search_chunk_configuration( model: nn.Module, - search_range_mb: float, + search_range_mb: int, search_interval_byte: int, # hidden size is the best value for the interval - min_chunk_size_mb: float = 32, - filter_exlarge_params: bool = True) -> Dict: - search_range_byte = round(search_range_mb * 1024**2) - min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) - assert search_range_byte >= 0 + min_chunk_size_mb: int = 32, + filter_exlarge_params: bool = True): + search_range_byte = search_range_mb * 1024**2 + min_chunk_size_byte = min_chunk_size_mb * 1024**2 + assert search_range_byte % search_interval_byte == 0 params_dict = clasify_params(model) config_dict: Dict[int, Dict] = dict() @@ -79,12 +75,11 @@ def search_chunk_configuration( max_size = min_chunk_size_byte for key in size_dict: max_size = max(max_size, max(size_dict[key])) - start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) min_chunk_waste = float('+inf') - best_chunk_size = start_size + best_chunk_size = max_size - for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte): temp_waste = 0 for key in size_dict: temp_waste += _get_unused_byte(size_dict[key], chunk_size) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 18a7250af..378f186a8 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -3,18 +3,16 @@ import itertools import torch.distributed as dist from functools import partial from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 +from colossalai.gemini.chunk import TensorState, Chunk from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.gemini.gemini_mgr import GeminiManager from typing import Dict, Iterable, List, Optional, Set from colossalai.logging import get_dist_logger from collections import OrderedDict -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup from .reducer import Reducer -from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager -from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda - try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -210,34 +208,28 @@ class ZeroDDP(ColoDDP): def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager, - pin_memory: bool = False, force_outputs_fp32: bool = False) -> None: - super().__init__(module, process_group=ColoProcessGroup()) + super().__init__(module, process_group=gemini_manager.chunk_manager.process_group) 
self.gemini_manager = gemini_manager - self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + self.chunk_manager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = ZeROHookV2(gemini_manager) - self.fp32_params: List[ColoTensor] = [] + self.fp32_params: List[ColoParameter] = [] self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} - + self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True) + self.chunk_manager.create_group('fp32_param') # TODO: get param order and filter unused params for p in module.parameters(): - assert isinstance(p, ColoParameter) if getattr(p, '_ddp_to_ignore', False): p.data = p.half() continue - - dp_world_size = p.process_group.dp_world_size() - fp32_data = p.float().data + fp32_p = p.float().detach() p.data = p.half() - fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) - self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) - self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) + self.chunk_manager.append_tensor(p, 'fp16_param') + self.chunk_manager.append_tensor(fp32_p, 'fp32_param') self.fp32_params.append(fp32_p) self.grads_device[p] = self.gemini_manager.default_device - self.chunk_manager.close_all_groups() - self._cast_buffers() self._logger = get_dist_logger() @@ -256,7 +248,10 @@ class ZeroDDP(ColoDDP): for p in self.module.parameters(): if getattr(p, '_ddp_to_ignore', False): continue - p.grad = None + if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad: + p.grad = None + else: + p.grad = p.data def _post_backward(self): self.chunk_manager.exec_lazy_release() @@ -281,22 +276,21 @@ class ZeroDDP(ColoDDP): free_storage(empty_grad) with torch._C.DisableTorchFunction(): self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) + if self.dp_world_size > 1: + grad = grad / self.dp_world_size + self.chunk_manager.copy_tensor_to_chunk_slice(p, grad) chunk = self.chunk_manager.get_chunk(p) - chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) - if reduced: - if chunk.is_gathered: - chunk.chunk_total.div_(chunk.pg_size) - else: - chunk.cuda_shard.div_(chunk.pg_size) + self.chunk_manager.release_chunk(chunk) + if reduced and not chunk.is_empty: self.overflow_counter += chunk.has_inf_or_nan - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + self.chunk_manager.move_chunk(chunk, self.grads_device[p]) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) - def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device @@ -317,11 +311,14 @@ class ZeroDDP(ColoDDP): ['bias', 'weight'] """ + is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 + record_flag = (not only_rank_0) or is_rank_0 + if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) + self._save_to_state_dict(destination, prefix, keep_vars, record_flag) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) @@ -329,7 +326,7 @@ class ZeroDDP(ColoDDP): destination = hook_result 
return destination - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): + def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -342,30 +339,30 @@ class ZeroDDP(ColoDDP): prefix (str): the prefix for parameters and buffers used in this module """ - assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." - # save parameters param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(self.fp32_params) for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) + # record the original device of the chunk + org_chunk_dev_typ = chunk.device_type + self.chunk_manager.access_chunk(chunk) - for tensor, tensor_info in chunk.tensors_info.items(): - record_tensor = torch.empty([0]) - record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + for tensor in chunk.get_tensors(): + rec_p = torch.empty([0]) if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() - + rec_p = tensor.cpu() # move the whole tensor to CPU mem assert tensor not in param_to_save_data - param_to_save_data[tensor] = record_tensor - - del temp_chunk + param_to_save_data[tensor] = rec_p + # release the actual memory of the chunk + self.chunk_manager.release_chunk(chunk) + if not chunk.is_empty and org_chunk_dev_typ == 'cpu': + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - destination[prefix + name] = record_parameter + rec_p = param_to_save_data[fp32_p] + destination[prefix + name] = rec_p if keep_vars else rec_p.detach() # save all buffers for name, buf in self.named_buffers(): @@ -469,61 +466,40 @@ class ZeroDDP(ColoDDP): local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(param_name, dest_tensor, copy_func): - state_key = prefix + param_name - if state_key in state_dict: - input_param = state_dict[state_key] + def load(name, dest_tensor, copy_func): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] if input_param.shape != dest_tensor.shape: # local shape should match the one in checkpoint error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, + 'the shape in current model is {}.'.format(key, input_param.shape, dest_tensor.shape)) return try: with torch.no_grad(): + # self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param) copy_func(input_param) except Exception as ex: error_msgs.append('While copying the parameter named "{}", ' 'whose dimensions in the model are {} and ' 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) + 'an exception occurred : {}.'.format(key, dest_tensor.size(), 
input_param.size(), + ex.args)) elif strict: - missing_keys.append(state_key) + missing_keys.append(key) - def load_fp32_parameter(chunk_slice, data): - chunk_slice.copy_(data.flatten()) + def load_fp32_p(fp32_p, data): + if fp32_p.storage().size() > 0: + self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data) - fp32_to_name = dict() for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: - fp32_to_name[fp32_p] = name - - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] - load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) - - if chunk.is_gathered: - chunk.chunk_total.copy_(temp_chunk) - elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - - del temp_chunk - - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.optim_update() + load(name, fp32_p, partial(load_fp32_p, fp32_p)) + self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param') for name, buf in persistent_buffers.items(): if buf is not None: diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py deleted file mode 100644 index 587339549..000000000 --- a/colossalai/nn/parallel/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -import torch.distributed as dist -from colossalai.gemini.chunk import Chunk -from colossalai.utils import get_current_device - - -def get_temp_total_chunk_on_cuda(chunk: Chunk): - if chunk.is_gathered: - return chunk.chunk_total - - if chunk.cuda_shard is not None: - shard_temp = chunk.cuda_shard - else: - shard_temp = chunk.cpu_shard.to(get_current_device()) - - total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) - gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) - dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) - - return total_temp diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py index 3f3472f0e..af0187b4f 100644 --- a/colossalai/zero/utils/zero_hook_v2.py +++ b/colossalai/zero/utils/zero_hook_v2.py @@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): - old_training_phase = self._training_phase try: + old_training_phase = self._training_phase self._training_phase = training_phase yield finally: diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index 3dc7b322c..55b4d7ee9 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -2,14 +2,17 @@ import torch import torch.distributed as dist from enum import Enum from torch.optim import Optimizer -from torch.nn import Parameter from colossalai.nn.parallel.data_parallel import ZeroDDP -from typing import Dict, Tuple, Set +from typing import Dict from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import get_current_device, disposable -from colossalai.gemini.chunk import Chunk, ChunkManager 
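The utils.common imports added just below feed the compute_grad_norm and clip_grad_norm methods introduced later in this file: each rank computes the Lp power sum of the gradients it stores, those contributions are all-reduced (summed for finite p, maxed for the inf norm), and the global norm is the p-th root of that total (or simply the max itself for the inf norm). A single-process sketch of that reduction, with a plain Python list standing in for dist.all_reduce and with grad_power_sum / global_grad_norm as illustrative names:

import math
from typing import List

import torch


def grad_power_sum(grads: List[torch.Tensor], p: float) -> float:
    # local contribution of one rank: sum of |g|**p, or the max for the inf norm
    if p == math.inf:
        return max((g.abs().max().item() for g in grads), default=0.0)
    return sum(g.abs().pow(p).sum().item() for g in grads)


def global_grad_norm(per_rank_grads: List[List[torch.Tensor]], p: float = 2.0) -> float:
    # the list comprehension stands in for dist.all_reduce over rank-local sums
    local_sums = [grad_power_sum(grads, p) for grads in per_rank_grads]
    if p == math.inf:
        return max(local_sums)
    return sum(local_sums) ** (1.0 / p)


shards = [[torch.randn(4)] for _ in range(2)]                 # two simulated ranks
distributed_style = global_grad_norm(shards, p=2.0)
reference = torch.cat([g for rank in shards for g in rank]).norm(2).item()
assert abs(distributed_style - reference) < 1e-4
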
+from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm +from collections import defaultdict, abc as container_abcs +from copy import deepcopy +from itertools import chain +from torch._six import inf class OptimState(Enum): @@ -30,8 +33,8 @@ class ZeroOptimizer(ColossalaiOptimizer): Args: optim (Optimizer): An Optimizer instance. module (ZeroDDP): A ``ZeroDDP`` instance. - gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) - which will be used when using hybrid CPU optimizer. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". Defaults to 0.0. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. @@ -58,20 +61,11 @@ class ZeroOptimizer(ColossalaiOptimizer): assert isinstance(module, ZeroDDP) self.module = module self.gemini_manager = module.gemini_manager - self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager + self.chunk_manager = self.gemini_manager.chunk_manager self.optim_state = OptimState.UNSCALED - self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() - self.param_to_chunk32: Dict[Parameter, Chunk] = dict() - self.chunk16_set: Set[Chunk] = set() - + self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {} for p, fp32_p in zip(module.parameters(), module.fp32_params): - chunk_16 = self.chunk_manager.get_chunk(p) - chunk_32 = self.chunk_manager.get_chunk(fp32_p) - chunk_32.init_pair(chunk_16) - if chunk_16 not in self.chunk16_set: - self.chunk16_set.add(chunk_16) - - self.__init__optimizer() + self.fp16_param_to_fp32_param[p] = fp32_p # Grad scaler self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, @@ -81,7 +75,7 @@ class ZeroOptimizer(ColossalaiOptimizer): growth_interval=growth_interval, hysteresis=hysteresis, max_scale=max_scale) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()) self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) @@ -96,26 +90,16 @@ class ZeroOptimizer(ColossalaiOptimizer): self._register_states = disposable(self._register_states_) - def _set_grad_ptr(self): - for group in self.param_groups: - for fake_param in group['params']: - chunk32 = self.param_to_chunk32[fake_param] - begin, end = self.param_to_range[fake_param] - chunk16 = chunk32.paired_chunk - - fake_param.data = chunk16.payload[begin:end] - fake_param.grad = fake_param.data - fake_param.data = chunk32.payload[begin:end] + def _update_params_ptr(self): + for group in self.optim.param_groups: + for p in group['params']: + if not self.module.chunk_manager.get_chunk(p).is_empty: + p.data = self.fp16_param_to_fp32_param[p] + else: + assert p.grad is None def _update_fp16_params(self): - none_tensor = torch.empty([0]) - for group in self.param_groups: - for fake_param in group['params']: - assert fake_param.grad is None - fake_param.data = none_tensor - - for chunk16 in self.chunk16_set: - chunk16.optim_update() + self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param') def _check_overflow(self): # clear previous overflow record @@ -144,7 +128,6 @@ class ZeroOptimizer(ColossalaiOptimizer): def step(self, *args, 
**kwargs): self._maybe_move_fp32_params() - self._set_grad_ptr() # unscale grads if scaled if self.optim_state == OptimState.SCALED: self._unscale_grads() @@ -155,14 +138,45 @@ class ZeroOptimizer(ColossalaiOptimizer): self.zero_grad() self._update_fp16_params() return + self._update_params_ptr() ret = self.optim.step(*args, **kwargs) self._register_states() self.zero_grad() self._update_fp16_params() return ret + def compute_grad_norm(self, norm_type: float = 2.0) -> float: + norm_type = float(norm_type) + if not self.chunk_manager.enable_distributed_storage: + return compute_grad_norm(self.module.parameters(), norm_type) + + non_distributed_params = [] + distributed_params = [] + for p in self.module.parameters(): + if getattr(p, '_ddp_to_ignore', False): + non_distributed_params.append(p) + else: + distributed_params.append(p) + non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type) + distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)], + device=get_current_device()) + if norm_type == inf: + dist.all_reduce(distributed_norm_tensor, + op=dist.ReduceOp.MAX, + group=self.chunk_manager.process_group.dp_process_group()) + total_norm = max(non_distributed_norm, distributed_norm_tensor.item()) + else: + dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group()) + total_norm = non_distributed_norm + distributed_norm_tensor.item() + total_norm = total_norm**(1 / norm_type) + return total_norm + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): - raise NotImplementedError + if self.optim_state == OptimState.SCALED: + self._unscale_grads() + total_norm = self.compute_grad_norm(norm_type) + _clip_grad_norm(self.module.parameters(), max_norm, total_norm) + return total_norm def backward(self, loss: torch.Tensor): loss = self.loss_scale * loss @@ -183,31 +197,24 @@ class ZeroOptimizer(ColossalaiOptimizer): available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param fp32_params_used_cuda_margin_mem = 0 - - for group in self.param_groups: - for fake_param in group['params']: - chunk32 = self.param_to_chunk32[fake_param] - chunk16 = chunk32.paired_chunk - - if chunk32.device_type == 'cuda': - continue - - if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(chunk32, get_current_device()) - # stores grad now - self.chunk_manager.move_chunk(chunk16, get_current_device()) - self.module.set_chunk_grad_device(chunk16, get_current_device()) - fp32_params_used_cuda_margin_mem += chunk32.payload_mem - - for group in self.param_groups: - for fake_param in group['params']: - chunk32 = self.param_to_chunk32[fake_param] - if chunk32.device_type == 'cuda': - state = self.optim.state[fake_param] + for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'], + self.chunk_manager.chunk_groups['fp32_param']): + if fp32_param_chunk.is_empty: + continue + if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem: + self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device()) + # stores grad now + self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device()) + self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device()) + fp32_params_used_cuda_margin_mem += 
fp32_param_chunk.mem + for p in fp16_param_chunk.get_tensors(): + state = self.optim.state[p] for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(get_current_device()) + self.module._setup_grads_ptr() + def _register_states_(self): for group in self.optim.param_groups: for p in group['params']: @@ -216,27 +223,110 @@ class ZeroOptimizer(ColossalaiOptimizer): if isinstance(val, torch.Tensor): self.chunk_manager.add_extern_static_tensor(val) - def __init__optimizer(self): + def state_dict(self, only_rank_0: bool = True): + r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None. + This saves memory usage. - def get_range_pair(local_chunk: Chunk, local_param: Parameter): - param_info = local_chunk.tensors_info[local_param] - begin = max(0, param_info.offset - local_chunk.shard_begin) - end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) - return begin, end + It contains two entries: - for group in self.optim.param_groups: - fake_params_list = list() + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + * param_groups - a list containing all parameter groups where each + parameter group is a dict + """ + is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 + if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0: + return + optim_state_dict = super().state_dict() + scaler_state_dict = self.grad_scaler.state_dict() + optim_state_dict['scaler'] = scaler_state_dict + if not self.chunk_manager.enable_distributed_storage: + return optim_state_dict + local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0} + if not self.chunk_manager.process_group.has_cpu_groups: + self.chunk_manager.process_group.set_cpu_groups() + output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())] + if only_rank_0: + dst_rank = self.chunk_manager.process_group.dp_rank_list()[0] + dist.gather_object(local_state, + output if self.chunk_manager.process_group.dp_local_rank() == 0 else None, + dst=dst_rank, + group=self.chunk_manager.process_group.cpu_dp_process_group()) + if not is_rank_0: + return + else: + dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group()) + for state in output: + optim_state_dict['state'].update(state) + return optim_state_dict - for param in group['params']: - chunk16 = self.chunk_manager.get_chunk(param) - range_pair = get_range_pair(chunk16, param) - if range_pair[0] >= range_pair[1]: - continue + def load_state_dict(self, state_dict): + r"""Loads the optimizer state. - fake_param = torch.nn.Parameter(torch.empty([0])) - self.param_to_chunk32[fake_param] = chunk16.paired_chunk - self.param_to_range[fake_param] = range_pair + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + if 'scaler' not in state_dict: + self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + else: + self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler'])) - fake_params_list.append(fake_param) + # Validate the state_dict + groups = self.param_groups + saved_groups = deepcopy(state_dict['param_groups']) - group['params'] = fake_params_list + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " + "parameter groups") + param_lens = (len(g['params']) for g in groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Update the state + id_map = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in groups))) + } + + def cast(param, value): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + for k, v in state_dict['state'].items(): + if k in id_map: + param = self.fp16_param_to_fp32_param[id_map[k]] + if param.storage().size() > 0: + state[param] = cast(param, deepcopy(v)) + else: + state[k] = deepcopy(v) + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({'state': state, 'param_groups': param_groups}) + + +def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]): + return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()} diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index d98018adf..8789c18a6 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini import ChunkManager from functools import partial from colossalai.nn.parallel import ColoDDP, ZeroDDP from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Callable, Type +from typing import Callable import torch.distributed as dist import os import random @@ -32,9 +32,10 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: return ColoDDP(module, process_group=pg) -def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: - chunk_config = search_chunk_configuration(module, 4, 1024) - 
chunk_manager = ChunkManager(chunk_config) +def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP: + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None + chunk_manager = ChunkManager(chunk_size, pg) gemini_manager = GeminiManager('cuda', chunk_manager) return ZeroDDP(module, gemini_manager) @@ -50,7 +51,7 @@ class Net(torch.nn.Module): return self.fc2(self.fc1(x)) -def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): +def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): with ColoInitContext(device=get_current_device()): model = Net().cuda() w1 = model.fc1.weight @@ -61,14 +62,8 @@ def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module logits = model(x) loss = torch.sum(logits) model.backward(loss) - - if ddp_cls is ZeroDDP: - w1s_grad = w1 - else: - w1s_grad = w1.grad - w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())] - dist.all_gather(w1_grads, w1s_grad) + dist.all_gather(w1_grads, w1.grad) assert torch.equal(w1_grads[0], w1_grads[1]) w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())] dist.all_gather(w2_grads, w2.grad) @@ -79,7 +74,8 @@ def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') set_seed(dist.get_rank()) run_fwd_bwd(ColoDDP, init_ddp) - run_fwd_bwd(ZeroDDP, init_ddpv2) + run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False)) + run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True)) @pytest.mark.dist diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index f229364c6..c13f7a72c 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -8,11 +8,14 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.gemini import ChunkManager from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ColoDDP +from colossalai.nn.parallel import ZeroDDP, ColoDDP +from colossalai.gemini.gemini_mgr import GeminiManager from collections import OrderedDict from colossalai.tensor import ProcessGroup, ColoParameter +from colossalai.testing import parameterize def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): @@ -27,11 +30,42 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) +def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True): + for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()): + assert na == nb + + if not allow_empty: + assert pa.storage().size() > 0 + assert pb.storage().size() > 0 + else: + if pa.storage().size() == 0 or pb.storage().size() == 0: + continue + + if same_dtype: + assert pa.dtype == pb.dtype + temp_pb = pb + else: + temp_pb = pb.to(pa.dtype) + + assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb) + + def init_ddp(module: torch.nn.Module) -> ColoDDP: pg = ProcessGroup() return ColoDDP(module, process_group=pg) +def init_ddpv2(module: torch.nn.Module, + use_chunk: bool = False, + 
use_zero: bool = False, + placement_policy: str = 'cuda') -> ZeroDDP: + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None + chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + return ZeroDDP(module, gemini_manager) + + def run_ddp_state_dict(): get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -54,9 +88,44 @@ def run_ddp_state_dict(): check_state_dict_equal(torch_state_dict, state_dict) +@parameterize('use_chunk', [False, True]) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('use_zero', [False, True]) +@parameterize('only_rank_0', [False, True]) +def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0): + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + org_torch_model = copy.deepcopy(torch_model) + torch_state_dict = torch_model.state_dict() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = init_ddpv2(model, use_chunk, use_zero, placement_policy) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + model.load_state_dict(torch_state_dict, strict=False) + check_model_equal(model, torch_model, allow_empty=True, same_dtype=False) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + pg = ProcessGroup() + state_dict = model.state_dict(only_rank_0=only_rank_0) + if not only_rank_0 or pg.dp_local_rank() == 0: + torch_model.load_state_dict(state_dict, strict=False) + check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_ddp_state_dict() + run_zero_state_dict() @pytest.mark.dist diff --git a/tests/test_gemini/test_stateful_tensor_container.py b/tests/test_gemini/test_stateful_tensor_container.py new file mode 100644 index 000000000..60ac2a69b --- /dev/null +++ b/tests/test_gemini/test_stateful_tensor_container.py @@ -0,0 +1,74 @@ +import pytest +import torch + +from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor +from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer + + +@pytest.mark.dist +def test_stateful_tensor_container(): + st1 = StatefulTensor(torch.randn(1, device='cuda')) + st2 = StatefulTensor(torch.randn(2, device='cuda')) + st3 = StatefulTensor(torch.randn(3, device='cuda')) + stateful_tensor_list = [st1, st2, st3] + step_list = [st1, st2, st3, st3, st2, st1] + + compute_step_dict = dict() + compute_step_dict[st1] = [0, 5] + compute_step_dict[st2] = [1, 4] + compute_step_dict[st3] = [2, 3] + + def run_queue_test(): + # test queue container + queue_container = QueueSTContainer(compute_step_dict, 6) + queue_container.create(stateful_tensor_list) + + res_list = [] + + for i in range(6): + stateful_tensor = step_list[i] + stateful_tensor.trans_state(TensorState.COMPUTE) + st_out = queue_container.pop() + st_out.move_to(torch.device('cpu')) + + res_list.append(st_out.payload.size(0)) + + 
stateful_tensor.move_to(torch.device('cuda')) + queue_container.push(stateful_tensor, i) + stateful_tensor.trans_state(TensorState.HOLD) + + assert res_list == [2, 3, 1, 2, 3, 2] + + run_queue_test() + + def run_heap_test(): + # test heap container + st1.move_to(torch.device('cuda')) + st2.move_to(torch.device('cuda')) + st3.move_to(torch.device('cuda')) + + heap_container = HeapSTContainer(compute_step_dict, 6) + heap_container.create(stateful_tensor_list) + + res_list = [] + + for i in range(6): + stateful_tensor = step_list[i] + stateful_tensor.trans_state(TensorState.COMPUTE) + st_out = heap_container.pop() + + if st_out is not None: + res_list.append(st_out.payload.size(0)) + st_out.move_to(torch.device('cpu')) + + stateful_tensor.move_to(torch.device('cuda')) + heap_container.push(stateful_tensor, i) + stateful_tensor.trans_state(TensorState.HOLD) + + assert res_list == [3, 1, 2, 3, 2] + + run_heap_test() + + +if __name__ == '__main__': + test_stateful_tensor_container() diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py index fa7a9b1b5..c4df217e1 100644 --- a/tests/test_gemini/update/test_chunk_mgrv2.py +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -3,7 +3,7 @@ import colossalai import pytest import torch.multiprocessing as mp from functools import partial -from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.update import ChunkManagerV2 from colossalai.testing import rerun_if_address_is_in_use, parameterize from colossalai.utils import free_port from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec @@ -19,17 +19,23 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} def exam_chunk_memory(keep_gathered, pin_memory): pg = ProcessGroup() - debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) + debug_print([0], "keep_gathered: {}, pin_memory: {}".format( + keep_gathered, pin_memory)) params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] - config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} + config = { + 2: dict( + chunk_size=128, + keep_gathered=keep_gathered + ) + } - chunk_manager = ChunkManager(config) + chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory) assert chunk_manager.total_mem['cpu'] == 0 assert chunk_manager.total_mem['cuda'] == 0 for p in params: - chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory) + chunk_manager.append_tensor(p, 'param', 2) chunk_manager.close_all_groups() assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index 57a49314f..deea46acb 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ColoParameter from colossalai.gemini import TensorState -from colossalai.gemini.chunk import Chunk +from colossalai.gemini.update import ChunkV2 def dist_sum(x): @@ -38,12 +38,14 @@ def check_euqal(param, param_cp): def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = ColoProcessGroup() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - 
keep_gathered=keep_gathered, - pin_memory=pin_memory) + my_chunk = ChunkV2( + chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + keep_gathered=keep_gathered, + pin_memory=pin_memory + ) param_list = [] param_cp_list = [] diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py deleted file mode 100644 index 6bd25c0be..000000000 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal -from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ZeroDDP -from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize -from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import debug_print - -from time import time -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): - chunk_manager = model.chunk_manager - param_list = [p for p in model.parameters()] - chunk_list = chunk_manager.get_chunks(param_list) - for chunk in chunk_list: - chunk_manager.access_chunk(chunk) - - for (p0, p1) in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -def exam_gpt_fwd_bwd(placement_policy): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - pg = ProcessGroup() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[pg.rank()], 
process_group=pg.dp_process_group()) - - model.eval() - torch_model.eval() - - set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 0: - break - - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - model.backward(loss) - - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( - torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) - - check_grad(model, torch_model) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_gpt_fwd_bwd() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_gpt(1) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py deleted file mode 100644 index cefda045d..000000000 --- a/tests/test_gemini/update/test_optim.py +++ /dev/null @@ -1,118 +0,0 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal -from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ZeroDDP -from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize -from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from time import time -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() - - for key, value in torch_dict.items(): - # key is 'module.model.PARAMETER', so we truncate it - key = key[7:] - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -def exam_gpt_fwd_bwd(placement_policy): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - 
with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - - model.eval() - torch_model.eval() - - set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 2: - break - - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) - # debug_print([0], zero_logits, torch_logits) - - zero_optim.step() - torch_optim.step() - - check_param(model, torch_model) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_gpt_fwd_bwd() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_gpt(1) diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index 6655c3e39..fcc6bcf0e 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -8,7 +8,7 @@ import torch.distributed as dist import colossalai from colossalai.testing import rerun_if_address_is_in_use -from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.gemini.update import search_chunk_configuration from colossalai.utils import free_port, get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup @@ -35,11 +35,12 @@ def exam_search_chunk_size(): with ColoInitContext(device=get_current_device()): model = model_builder() init_1d_row_spec(model, pg_tp) - config_dict = search_chunk_configuration(model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True) + config_dict = search_chunk_configuration( + model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) for key in config_dict: chunk_size = config_dict[key]['chunk_size'] diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py deleted file mode 100644 index 
827280dc1..000000000 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ /dev/null @@ -1,114 +0,0 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - model.train() - - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() - - for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_load_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - optimizer = torch.optim.Adam(model.parameters()) - optim = ZeroOptimizer(optimizer, model) # initialize the link 
between chunk16 and chunk32 - - torch_dict = torch_model.state_dict() - model.load_state_dict(torch_dict, strict=False) - zero_dict = model.state_dict(only_rank_0=False) - - for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_state_dict() - exam_load_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_ddp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_zero_ddp(1) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py deleted file mode 100644 index 169c0effc..000000000 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ /dev/null @@ -1,81 +0,0 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_zero_optim_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - optimizer = torch.optim.Adam(model.parameters()) - optim = ZeroOptimizer(optimizer, model) # initialize the link between chunk16 and chunk32 - - set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 0: - break - optim.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = 
criterion(logits, input_ids) - optim.backward(loss) - - optim_state_dict = optim.state_dict() - optim.load_state_dict(optim_state_dict) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_zero_optim_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_zero_optim(1) diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py new file mode 100644 index 000000000..1f1b6e44b --- /dev/null +++ b/tests/test_tensor/test_chunk.py @@ -0,0 +1,86 @@ +import torch +import colossalai +import pytest +import torch.multiprocessing as mp +from typing import List +from functools import partial +from colossalai.gemini import ChunkManager +from colossalai.testing import rerun_if_address_is_in_use, parameterize +from colossalai.utils import free_port +from colossalai.tensor import ProcessGroup as ColoProcessGroup + + +def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]): + for p, has_tensor in zip(params, has_tensors): + if has_tensor: + assert p.storage().size() > 0 + assert p.device.type == 'cuda' + else: + assert p.storage().size() == 0 + + +# HAS_TENSORS[use_chunk][use_zero] +HAS_TENSORS = { + True: { + True: [[True, True, False], [False, False, True]], + False: [[True, True, True], [True, True, True]] + }, + False: { + True: [[True, False, True], [False, True, False]], + False: [[True, True, True], [True, True, True]] + } +} + +TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}} + + +@parameterize('use_chunk', [False, True]) +@parameterize('use_zero', [False, True]) +def run_chunk_zero(use_chunk, use_zero): + pg = ColoProcessGroup() + rank = pg.rank() + if rank == 0: + print(f'use_chunk={use_chunk}, use_zero={use_zero}') + params = [torch.rand(8, 8) for _ in range(3)] + chunk_size = 128 if use_chunk else None + chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero) + chunk_manager.create_group('param') + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == 0 + for p in params: + chunk_manager.append_tensor(p, 'param') + check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank]) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank] + chunks = chunk_manager.get_chunks(params) + for chunk in chunks: + chunk_manager.access_chunk(chunk) + check_has_params(params, [True, True, True]) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank] + for chunk in chunks: + chunk_manager.release_chunk(chunk) + check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank]) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda'] + for chunk in chunks: + chunk_manager.move_chunk(chunk, torch.device('cpu')) + assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda'] + assert chunk_manager.total_mem['cuda'] == 0 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, 
world_size=world_size, host='localhost', port=port, backend='nccl') + run_chunk_zero() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_chunk_mapping(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_mapping(2) diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index 70cb837d8..b08ceed32 100644 --- a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini import ChunkManager from functools import partial from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal from tests.components_to_test.registry import non_distributed_component_funcs @@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute from tests.test_tensor.model.test_gpt2 import init_megatron_spec -def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() +def check_param_equal(model, torch_model, pg: ProcessGroup): + for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): + if p.storage().size() > 0: + assert p.dtype == torch.float16 + assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(), + pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}' - for key, value in torch_dict.items(): - # key is 'module.model.PARAMETER', so we truncate it - key = key[7:] - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \ - "parameter '{}' has problem.".format(key) + +def check_grad_equal(model, torch_model, pg: ProcessGroup): + for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): + if p.grad is not None: + assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad, + pg.tp_local_rank(), pg.tp_world_size()), \ + f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}' def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): @@ -62,8 +62,10 @@ def init_1d_col_spec(model, pg: ProcessGroup): p.set_tensor_spec(*spec) +@parameterize('use_chunk', [False, True]) +@parameterize('use_zero', [False, True]) @parameterize('placement_policy', ['cuda', 'cpu']) -def run_gpt(placement_policy, tp_init_spec_func=None): +def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -87,20 +89,15 @@ def run_gpt(placement_policy, tp_init_spec_func=None): if tp_init_spec_func: tp_init_spec_func(model, pg) - dp_world_size = 
diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py
index 70cb837d8..b08ceed32 100644
--- a/tests/test_tensor/test_zero_optim.py
+++ b/tests/test_tensor/test_zero_optim.py
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.gemini import ChunkManager
 from functools import partial
 from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
 
 
-def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
-    zero_dict = model.state_dict(only_rank_0=False)
-    torch_dict = torch_model.state_dict()
+def check_param_equal(model, torch_model, pg: ProcessGroup):
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        if p.storage().size() > 0:
+            assert p.dtype == torch.float16
+            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
+                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
 
-    for key, value in torch_dict.items():
-        # key is 'module.model.PARAMETER', so we truncate it
-        key = key[7:]
-        if key == 'model.lm_head.weight':
-            continue
-        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
-        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
-        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
-            "parameter '{}' has problem.".format(key)
+
+def check_grad_equal(model, torch_model, pg: ProcessGroup):
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        if p.grad is not None:
+            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
+                                      pg.tp_local_rank(), pg.tp_world_size()), \
+                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
 
 
 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -62,8 +62,10 @@ def init_1d_col_spec(model, pg: ProcessGroup):
         p.set_tensor_spec(*spec)
 
 
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu'])
-def run_gpt(placement_policy, tp_init_spec_func=None):
+def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -87,20 +89,15 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     if tp_init_spec_func:
         tp_init_spec_func(model, pg)
 
-    dp_world_size = pg.dp_world_size()
-    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
-    config_dict[dp_world_size]['chunk_size'] = 5000
-    config_dict[dp_world_size]['keep_gathered'] = False
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size,
+                                 pg,
+                                 enable_distributed_storage=use_zero,
+                                 init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
-
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
+    model = ZeroDDP(model, gemini_manager)
+    optim = HybridAdam(model.parameters(), lr=1e-3)
+    optim = ZeroOptimizer(optim, model, initial_scale=1)
 
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
@@ -108,7 +105,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
 
     print(chunk_manager)
-    check_param(model, torch_model, pg)
+    check_param_equal(model, torch_model, pg)
 
     model.eval()
     torch_model.eval()
@@ -118,13 +115,13 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         if i > 2:
             break
         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
+        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
         torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
-        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
-
-        zero_optim.step()
+        assert tensor_equal(logits, torch_logits)
+        check_grad_equal(model, torch_model, pg)
+        optim.step()
         torch_optim.step()
-        check_param(model, torch_model, pg)
+        check_param_equal(model, torch_model, pg)
 
 
 def run_dist(rank, world_size, port):
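
The rewritten check_param_equal and check_grad_equal above deliberately skip any parameter whose storage has zero elements, mirroring check_has_params in test_chunk.py: with enable_distributed_storage=True a rank only materializes the chunks it owns, so an empty storage simply means the data lives on another rank. A minimal stand-alone illustration of that ownership probe, using plain PyTorch and independent of ChunkManager:

import torch

def is_materialized(t: torch.Tensor) -> bool:
    # Same test the helpers above rely on: zero-sized storage means this
    # rank does not hold the underlying data.
    return t.storage().size() > 0

t = torch.empty(4)
assert is_materialized(t)
t.storage().resize_(0)    # emulate a tensor whose chunk is owned by another rank
assert not is_materialized(t)
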
diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py
new file mode 100644
index 000000000..cc67242c9
--- /dev/null
+++ b/tests/test_zero/test_zero_optim_state_dict.py
@@ -0,0 +1,100 @@
+import pytest
+import colossalai
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.gemini import ChunkManager
+from functools import partial
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero import ZeroOptimizer
+from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.tensor import ProcessGroup
+
+
+def check_state(s1, s2):
+    for v1, v2 in zip(s1.values(), s2.values()):
+        if isinstance(v1, torch.Tensor):
+            v1 = v1.to(v2.device)
+            assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
+        else:
+            assert v1 == v2
+
+
+def check_load_state_dict(optim, torch_optim):
+    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
+        for p, torch_p in zip(group['params'], torch_group['params']):
+            state = optim.optim.state[p]
+            torch_state = torch_optim.state[torch_p]
+            if p.storage().size() == 0:
+                assert len(state) == 0
+            check_state(state, torch_state)
+
+
+def check_state_dict(state_dict, torch_state_dict):
+    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
+        assert k1 == k2
+        check_state(s1, s2)
+
+
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('only_rank_0', [False, True])
+def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = model.cuda()
+    torch_model = model_builder().cuda()
+
+    pg = ProcessGroup()
+
+    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size,
+                                 pg,
+                                 enable_distributed_storage=use_zero,
+                                 init_device=GeminiManager.get_default_device(placement_policy))
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    optim = HybridAdam(model.parameters(), lr=1e-3)
+    optim = ZeroOptimizer(optim, model, initial_scale=1)
+
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
+
+    for p in torch_model.parameters():
+        p.grad = torch.rand_like(p)
+
+    torch_optim.step()
+    torch_state_dict = torch_optim.state_dict()
+    optim.load_state_dict(torch_state_dict)
+    check_load_state_dict(optim, torch_optim)
+
+    state_dict = optim.state_dict(only_rank_0)
+    if not only_rank_0 or pg.rank() == 0:
+        check_state_dict(state_dict, torch_state_dict)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_zero_optim_state_dict()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_zero_optim_state_dict(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_optim_state_dict(2)
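
Both new tests rely on the same wiring of ChunkManager, GeminiManager, ZeroDDP and ZeroOptimizer. The sketch below restates the setup from run_zero_optim_state_dict above as a helper; the constants 8192 and 8 for ChunkManager.search_chunk_size and the lr/initial_scale values are simply the ones the tests use, and a distributed context (colossalai.launch, as in run_dist) is assumed to be initialized already.

from colossalai.gemini import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer


def build_zero_model_and_optim(model_builder, placement_policy='cuda', use_chunk=True, use_zero=True):
    # Build the model under ColoInitContext so its parameters are ColoTensors.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    pg = ProcessGroup()
    # Pick a chunk size (or fall back to per-tensor management) and let the
    # ChunkManager optionally distribute chunk storage across the DP ranks.
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # Same pattern as the tests above: ZeroDDP ties the model to the managers,
    # ZeroOptimizer wraps the base optimizer for the sharded fp16 parameters.
    model = ZeroDDP(model, gemini_manager)
    optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=1)
    return model, optim
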