diff --git a/colossalai/nn/_ops/cache_embedding/__init__.py b/colossalai/nn/_ops/cache_embedding/__init__.py new file mode 100644 index 000000000..4693e9055 --- /dev/null +++ b/colossalai/nn/_ops/cache_embedding/__init__.py @@ -0,0 +1,4 @@ +from .cache_mgr import CachedParamMgr +from .copyer import LimitBuffIndexCopyer + +__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer'] \ No newline at end of file diff --git a/colossalai/nn/_ops/cache_embedding/base_embedding.py b/colossalai/nn/_ops/cache_embedding/base_embedding.py new file mode 100644 index 000000000..705835a0e --- /dev/null +++ b/colossalai/nn/_ops/cache_embedding/base_embedding.py @@ -0,0 +1,36 @@ +import abc +import torch.nn as nn + + +class BaseEmbeddingBag(abc.ABC, nn.Module): + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + mode='mean', + include_last_offset=False, + ): + super(BaseEmbeddingBag, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + # Specific to embedding bag + self.mode = mode + self.include_last_offset = include_last_offset diff --git a/colossalai/nn/_ops/cache_embedding/cache_mgr.py b/colossalai/nn/_ops/cache_embedding/cache_mgr.py new file mode 100644 index 000000000..79e188b07 --- /dev/null +++ b/colossalai/nn/_ops/cache_embedding/cache_mgr.py @@ -0,0 +1,348 @@ +import numpy as np +import torch +from torch.profiler import record_function +from typing import List, Optional +from contexttimer import Timer +from .copyer import LimitBuffIndexCopyer + + +class CachedParamMgr(torch.nn.Module): + """ + Manage Embedding Weights in Cache on CPU and CUDA memory. + CPU maintains entire original weight. + CUDA maintains a fraction of weights used in the upcomming computation. + During training, GPU needs to transmit rows between CPU and GPU. + """ + + def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None: + super(CachedParamMgr, self).__init__() + self.buffer_size = buffer_size + self.num_embeddings, self.embedding_dim = weight.shape + self.cuda_row_num = cuda_row_num + self._cuda_available_row_num = self.cuda_row_num + + self.elem_size_in_byte = weight.element_size() + + self.cuda_cached_weight = torch.nn.Parameter( + torch.zeros(self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype)) + + if weight.device.type == 'cuda': + weight = weight.cpu() + + # pin memory cpu for higher CPU-GPU copy bandwidth + self.cpu_weight = weight.contiguous().pin_memory() + + # map original id to new id with respect to frequency + # id -> cpu_row_idx + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), + persistent=False, + ) + + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + # cpu_row_id -> gpu_row_idx. 
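+        # (inverse of cached_idx_map, so the cache slot of a cpu row can be looked up in O(1)).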
+ # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) + + # index copy buffer size should less than 10% of cuda weight. + if self.buffer_size > 0: + self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size) + + self.num_hits_history = [] + self.num_miss_history = [] + self.num_write_back_history = [] + self.input_id_percent_in_load_chunk = [] + self._reset_comm_stats() + + def cpu_weight_data(self, chunk_id: int) -> torch.Tensor: + """ + access a chunk of CPU weight. + + Args: + chunk_id (int): chunk id + + Returns: + torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D. + """ + + return self.cpu_weight.data.view(-1).narrow(0, + int(chunk_id) * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + + @property + def cuda_available_chunk_num(self): + return self._cuda_available_row_num + + @torch.no_grad() + def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7): + """reorder the cpu_weight according to ids' frequency in dataset before training. + Also Build the IndexMappingTable, aka index_mapping_table. + Execute only once before training. + Args: + ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder + warmup_ratio (float): the amount of chunks preloaded in cuda cache + """ + if ids_freq_mapping is not None: + tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True) + sorted_idx = torch.argsort(tmp_idx) + self.idx_map.data.copy_(sorted_idx) + + # TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks. + # As cuda_cached_weight is very big. You may not have that much available memory! + # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda + preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) + if preload_row_num > 0: + with Timer() as timer: + # extract chunks from cpu weight + preload_row_ids = torch.arange(preload_row_num) + preload_slot_ids = preload_row_ids.cuda() + + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=preload_row_ids, + tgt_index=preload_slot_ids, + src=self.cpu_weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + preload_chunks = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, + preload_row_ids).cuda() + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks) + + # update auxiliary info + slot_offsets = preload_slot_ids + self.cached_idx_map[preload_slot_ids] = preload_slot_ids + self.inverted_cached_idx[preload_slot_ids] = slot_offsets + self._cuda_available_row_num -= preload_row_num + print(f'Cache warmup finished cost {timer.elapsed} sec.') + + def flush(self): + """flush all CUDA chunks to CPU. + The function is usually called after training finished. 
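+
+        Example (a minimal usage sketch; assumes `mgr` is a CachedParamMgr that has
+        already served some training steps, and the checkpoint path is illustrative)::
+
+            mgr.flush()    # write every cached CUDA row back to mgr.cpu_weight
+            torch.save(mgr.cpu_weight, 'embedding_weight.pt')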
+        """
+        slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
+        chunk_ids = self.cached_idx_map[slots]
+        chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
+        self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
+        self.cached_idx_map.index_fill_(0, slots, -1)
+        self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
+        self._cuda_available_row_num += slots.numel()
+
+        assert self._cuda_available_row_num == self.cuda_row_num
+        assert torch.all(self.inverted_cached_idx == -1).item()
+        assert torch.all(self.cached_idx_map == -1).item()
+
+    def print_comm_stats(self):
+        if self._cuda_to_cpu_numel > 0:
+            print(
+                f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
+            )
+        if self._cpu_to_cuda_numel > 0:
+            print(
+                f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
+            )
+
+    @torch.no_grad()
+    def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
+        """
+        convert ids to indices in self.cuda_cached_weight.
+        Implemented with parallel operations on GPU.
+
+        Args:
+            ids (torch.Tensor): ids from the dataset
+
+        Returns:
+            torch.Tensor: contains indices in self.cuda_cached_weight
+        """
+        ids = self.idx_map.index_select(0, ids.view(-1))
+        ret = self.inverted_cached_idx.index_select(0, ids)
+        return ret
+
+    @torch.no_grad()
+    def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
+        """
+        move the cpu embedding rows w.r.t. ids into CUDA memory
+
+        Args:
+            ids (torch.Tensor): the ids to be computed
+        Returns:
+            torch.Tensor: indices on the cuda_cached_weight.
+        """
+        with record_function("(zhg) get unique indices"):
+            cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
+
+            assert len(cpu_row_idxs) <= self.cuda_row_num, \
+                f"the input indices pull {len(cpu_row_idxs)} chunks, " \
+                f"which is larger than the cuda cache capacity {self.cuda_row_num}, " \
+                f"please increase cuda_row_num or shrink the batch size"
+            self.evict_backlist = cpu_row_idxs
+
+        with record_function("(zhg) get cpu chunk indices"):
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
+
+        self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
+        self.num_miss_history.append(len(comm_cpu_row_idxs))
+        self.num_write_back_history.append(0)
+
+        # make sure the cuda chunks needed by this batch will not be evicted!
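+        # (rows in evict_backlist are exactly the rows required by this batch, so
+        # _prepare_rows_on_cuda below skips them when choosing eviction victims)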
+ with record_function("(zhg) cache update"): + self._prepare_rows_on_cuda(comm_cpu_row_idxs) + + self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) + # new ids chunk_offset + offset_in_chunk + with record_function("(zhg) embed idx -> cache chunk id"): + gpu_row_idxs = self._id_to_cached_cuda_id(ids) + return gpu_row_idxs + + def _reset_comm_stats(self): + self._cpu_to_cuda_numel = 0 + self._cpu_to_cuda_elpase = 0 + self._cuda_to_cpu_elapse = 0 + self._cuda_to_cpu_numel = 0 + + def _chunk_in_cuda(self, chunk_id: int) -> bool: + return self.inverted_cached_idx[chunk_id] != -1 + + @torch.no_grad() + def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: + """prepare rows in cpu_row_idxs on CUDA memory + Args: + cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA + """ + evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num + if evict_num > 0: + with Timer() as timer: + mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) + backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() + invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) + + self.cached_idx_map.index_fill_(0, invalid_idxs, -2) + evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num] + self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) + + evict_info = self.cached_idx_map[evict_gpu_row_idxs] + + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.cpu_weight.view(self.num_embeddings, -1)) + else: + # allocate tmp memory on CPU and copy rows on CUDA to CPU. + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu() + self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows) + + self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) + self.inverted_cached_idx.index_fill_(0, evict_info, -1) + self._cuda_available_row_num += evict_num + + weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim + self._cuda_to_cpu_elapse += timer.elapsed + self._cuda_to_cpu_numel += weight_size + # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + with Timer() as timer: + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + # Here also allocate extra memory on CUDA. #cpu_row_idxs + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=cpu_row_idxs.cpu(), + tgt_index=slots, + src=self.cpu_weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda() + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows) + slot_offsets = slots + self.cached_idx_map[slots] = cpu_row_idxs + self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) + self._cuda_available_row_num -= cpu_row_idxs.numel() + self._cpu_to_cuda_elpase += timer.elapsed + weight_size = cpu_row_idxs.numel() * self.embedding_dim + self._cpu_to_cuda_numel += weight_size + # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + def _evict(self) -> int: + """ + evict one chunk from cuda to cpu. + Returns: + (int) : the slot id be evicted. 
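+
+        Note:
+            the victim is the cached row whose (mapped) cpu row id is the largest among
+            rows not in the evict_backlist; after reorder() smaller mapped ids correspond
+            to more frequently accessed ids, so this evicts the coldest cached row.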
+ """ + mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1) + buf = self.cached_idx_map[mask].clone() + idx = torch.nonzero(mask).squeeze(1) + self.cached_idx_map.index_fill_(0, idx, -1) + max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0) + max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx] + + if max_gpu_row_idx == -1: + raise RuntimeError("Can not evict a row") + + max_gpu_row_idx = max_gpu_row_idx.item() + max_offset = self.inverted_cached_idx[max_gpu_row_idx] + # recover + self.cached_idx_map.index_copy_(0, idx, buf) + + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) + + # update inverted_cached_idx, min_slot_id is evicted from cuda + self.cached_idx_map[max_cpu_row_idx] = -1 + + self.inverted_cached_idx[max_gpu_row_idx] = -1 + + self._cuda_available_row_num += 1 + + self._cuda_to_cpu_numel += self.embedding_dim + self._cuda_to_cpu_elapse += timer.elapsed + # self.num_write_back_history[-1] += 1 + return max_cpu_row_idx + + def _find_free_cuda_row(self) -> int: + if self._cuda_available_row_num == 0: + return -1 + candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1) + return candidates[0].item() + + @torch.no_grad() + def _admit(self, row_id: int): + """ + move in row_id to CUDA + + Args: + row_id (int): the id of row to be moved in + """ + # find a free slot in partial cuda weight + slot_id = self._find_free_cuda_row() + + if slot_id == -1: + # evict one row + slot_id = self._evict() + slot_offset = slot_id + # copy payload from cpu to cuda + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) + + # update the inverted_cached_idx + self.cached_idx_map[slot_id] = row_id + self.inverted_cached_idx[row_id] = slot_offset + + self._cuda_available_row_num -= 1 + + self._cpu_to_cuda_numel += self.embedding_dim + self._cpu_to_cuda_elpase += timer.elapsed diff --git a/colossalai/nn/_ops/cache_embedding/copyer.py b/colossalai/nn/_ops/cache_embedding/copyer.py new file mode 100644 index 000000000..a8e8a819d --- /dev/null +++ b/colossalai/nn/_ops/cache_embedding/copyer.py @@ -0,0 +1,48 @@ +import torch +from torch import LongTensor + + +class LimitBuffIndexCopyer(object): + """LimitBuffIndexCopyer + Index Copy using limited temp buffer on CUDA. + + Args: + size (int): buffer size + """ + + def __init__(self, size: int) -> None: + self._buff_size = size + + @torch.no_grad() + def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor): + """copy + src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index] + The valid part in src is continous, while in tgt is scatter. 
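+        The copy runs piece by piece: at most `size` rows are staged in a temporary buffer
+        at a time (pinned on the CPU side for host-to-device transfers), which bounds the
+        peak extra memory used by the copy.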
+ Args: + dim (int): dimension along which to index + src_index (int): indices of src tensor to select from + tgt_index (int): indices of tgt tensor to select from + src (torch.Tensor): the tensor containing values to copy + tgt (torch.Tensor): the tensor to be copied + """ + # tgt.index_copy_(dim, index, src) + assert dim == 0, "only support index_copy on dim 0" + assert tgt.dim() == 2 + assert src.dim() == 2 + tgt_device = tgt.device + src_device = src.device + + assert src_index.numel() == tgt_index.numel() + dim_size = src_index.numel() + src_index = src_index.to(src_device) + for begin_pos in range(0, dim_size, self._buff_size): + cur_len = min(self._buff_size, dim_size - begin_pos) + src_idx_piece = src_index.narrow(0, begin_pos, cur_len) + if src_device.type == 'cpu' and tgt_device.type == 'cuda': + cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() + tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) + tmp_buffer.copy_(cpu_tmp_buffer) + else: + tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device) + tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len) + tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e69bcc244..ae3ff2fe9 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -5,3 +5,4 @@ timm titans torchaudio torchrec +contexttimer diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 572ee77dd..528bc6f25 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,3 +7,4 @@ pre-commit rich click fabric +contexttimer \ No newline at end of file diff --git a/tests/test_tensor/ops/test_cache_embedding.py b/tests/test_tensor/ops/test_cache_embedding.py new file mode 100644 index 000000000..6546af361 --- /dev/null +++ b/tests/test_tensor/ops/test_cache_embedding.py @@ -0,0 +1,76 @@ +import pytest +from functools import partial +import torch +import torch.multiprocessing as mp +import numpy as np + +from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use + +from colossalai.nn._ops.cache_embedding import CachedParamMgr + +NUM_EMBED, EMBED_DIM = 100, 8 +BATCH_SIZE = 8 + + +def test_cachemgr(): + model = torch.nn.EmbeddingBag(10000, 128) + # 10 chunks, 5 in cuda + mgr = CachedParamMgr(model.weight, 5) + assert mgr.cuda_row_num == 5 + + mgr._admit(1) + assert not mgr._chunk_in_cuda(2) + assert mgr._chunk_in_cuda(1) + + # print(mgr.cached_chunk_table) + mgr._admit(8) + + # now 3 chunk is available + assert mgr.cuda_available_chunk_num == 3 + + mgr._evict() + assert mgr.cuda_available_chunk_num == 4 + + mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0)) + mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0)) + # print(mgr.cached_chunk_table) + # mgr.print_comm_stats() + + mgr.flush() + assert mgr.cuda_available_chunk_num == 5 + + +def test_reorder_with_freq(): + num_embed = 100 + chunk_size = 1 + num_chunk = 5 + + idx_map = np.random.randint(10000, size=(num_embed,)) + sorted_idx = np.flipud(np.argsort(idx_map)).tolist() + chunkid, offset_in_chunk = [], [] + for i in range(num_embed): + idx = sorted_idx.index(i) + chunkid.append(idx // chunk_size) + offset_in_chunk.append(idx % chunk_size) + + chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device()) + offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device()) 
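+    # chunkid / offset_in_chunk above are the expected mapping obtained by ranking ids by
+    # descending frequency; mgr.reorder(idx_map) below should reproduce the same mapping.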
+
+    weight = torch.rand(num_embed, 2)
+    mgr = CachedParamMgr(weight, num_chunk)
+
+    mgr.reorder(idx_map)
+
+    indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device()))
+    mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
+    mgr_offsets = torch.remainder(indices, chunk_size)
+    assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
+    assert torch.allclose(offset_in_chunk, mgr_offsets), \
+        f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
+
+
+if __name__ == '__main__':
+    test_cachemgr()
+    test_reorder_with_freq()