From 964123ae0f729262081c2a6b53bba1e1090e1016 Mon Sep 17 00:00:00 2001
From: CsRic <59389055+CsRic@users.noreply.github.com>
Date: Mon, 5 Sep 2022 15:12:53 +0800
Subject: [PATCH] [embedding] freq_aware_embedding: add small functions for
 caller application (#1537)

---
 colossalai/nn/parallel/layers/__init__.py     |   5 +-
 .../layers/cache_embedding/__init__.py        |   4 +-
 .../parallel_freq_aware_embedding.py          |   6 +
 ...parallel_freq_aware_embedding_tablewise.py | 201 ++++++++++++++++--
 tests/test_layers/test_cache_embedding.py     |  44 ++--
 5 files changed, 214 insertions(+), 46 deletions(-)

diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py
index ee20fc65b..9e1777fa4 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/nn/parallel/layers/__init__.py
@@ -4,10 +4,11 @@ from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
 from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
index 1622f848c..15f921968 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index 5c2f65b76..e53b126b7 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -121,3 +121,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                          buffer_size=buffer_size)
         embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
         return embedding_bag
+
+    def print_comm_stats_(self):
+        self.cache_weight_mgr.print_comm_stats()
+
+    def element_size(self):
+        return self.weight.element_size()
\ No newline at end of file
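A minimal usage sketch of the two helpers added above (illustrative only, not part of the patch; `bag`, the row count and the embedding dimension are assumed values). The runnable part uses plain PyTorch:

import torch

# element_size() simply forwards to the weight tensor, so a caller can size the CUDA cache up front.
weight = torch.rand(18, 5)                      # stand-in for the CPU-side embedding weight
bytes_per_elem = weight.element_size()          # 4 bytes for float32
cuda_row_num, embedding_dim = 8, 5
print(f"CUDA cache footprint ~ {bytes_per_elem * cuda_row_num * embedding_dim} bytes")

# With this patch the same queries go through the bag itself instead of cache_weight_mgr:
#   bag.element_size()        # == bag.weight.element_size()
#   bag.print_comm_stats_()   # prints the CPU<->CUDA element traffic counted by CachedParamMgr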
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index 35faa67b5..9df793828 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -1,9 +1,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.profiler import record_function
 from typing import List
 import abc
-
+import torch.nn.functional as F
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from colossalai.tensor import ProcessGroup
 
@@ -38,7 +39,137 @@ class TablewiseEmbeddingBagConfig:
         self.name = name
 
 
-class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
+class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
+    """
+    All tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
+    The following TablewiseEmbeddingBagConfig parameters are ignored: cuda_row_num, buffer_size, initial_weight.
+    """
+    def __init__(self,
+                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
+                 embedding_dim: int,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 _weight=None,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 cuda_row_num=0,
+                 warmup_ratio=0.7,
+                 buffer_size=50_000,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
+        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
+        self.global_tables_num = len(embedding_bag_config_list)
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
+        self.assigned_table_list: List[int] = []
+        self.pg = ProcessGroup(tp_degree=self.world_size)
+        self.num_embeddings = 0
+        for i, rank in enumerate(self.rank_of_tables):
+            if rank == self.rank:
+                self.assigned_table_list.append(i)
+                self.num_embeddings += self.global_table_num_embeddings_list[i]
+        self.include_last_offset = include_last_offset
+
+        ids_freq_mapping = []
+        for config in embedding_bag_config_list:
+            if config.assigned_rank == self.rank:
+                if config.ids_freq_mapping is not None:
+                    ids_freq_mapping.extend(config.ids_freq_mapping)
+                else:
+                    ids_freq_mapping = None
+                    break
+
+        # table-associate cache
+        super(ParallelFreqAwareEmbeddingBagTablewise,
+              self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+                             sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
+                             warmup_ratio, buffer_size, pin_weight, evict_strategy)
+
+        # for assigned tables reconnection:
+        self.idx_offset_list = []
+        offset_cumsum = 0
+        for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
+            if self.rank_of_tables[table_i] == self.rank:
+                self.idx_offset_list.append(offset_cumsum)
+            else:
+                offset_cumsum += table_num_embeddings
+
+        # prepare list shape for all_to_all output
+        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
+        for rank in self.rank_of_tables:
+            self.embedding_dim_per_rank[rank] += embedding_dim
+
+    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
+        batch_size = (offsets.shape[0]) // self.global_tables_num
+        local_indices_list: List[torch.Tensor] = []
+        local_offsets_list: List[torch.Tensor] = []
+        if per_sample_weights is not None:
+            local_per_sample_weights_list: List[torch.Tensor] = []
+
+        offset_pre_end = 0    # local_offsets trick
+        for i, handle_table in enumerate(self.assigned_table_list):
+            indices_start_position = offsets[batch_size * handle_table]
+            if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
+                # till-the-end special case
+                indices_end_position = indices.shape[0]
+            else:
+                indices_end_position = offsets[batch_size * (handle_table + 1)]
+            # 1. local_indices_list:
+            local_indices_list.append(indices.narrow(0, indices_start_position, indices_end_position
+                                                     - indices_start_position).sub(self.idx_offset_list[i]))
+            # 2. local_offsets_list:
+            if i + 1 == len(self.assigned_table_list):
+                # till-the-end special case
+                if not self.include_last_offset:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                   batch_size).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                   batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets_list.append(local_offsets)
+            else:
+                local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                               batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                offset_pre_end = local_offsets[-1]
+                local_offsets_list.append(local_offsets[:-1])
+            # 3. local_per_sample_weights_list:
+            if per_sample_weights is not None:
+                local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
+
+        local_indices = torch.cat(local_indices_list, 0)
+        local_offsets = torch.cat(local_offsets_list, 0)
+        local_per_sample_weights = None
+        if per_sample_weights is not None:
+            local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
+
+        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
+        local_output = torch.cat(local_output.split(batch_size), 1)
+
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
+
+    def print_comm_stats_(self):
+        self.cache_weight_mgr.print_comm_stats()
+
+    def element_size(self):
+        return self.weight.element_size()
+
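A small, self-contained sketch of the per-table slicing that the forward pass above performs on its flat, table-major inputs. It uses plain PyTorch, skips the process group, the till-the-end special case and the offset stitching across several assigned tables, and borrows the tensors from the unit test later in this patch; `idx_offset` is the cumulative size of the tables not assigned to the local rank:

import torch

# 3 tables x 3 samples, include_last_offset=True (same batch as the test below)
indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11])
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13])
batch_size = 3

def slice_table(table_id: int, idx_offset: int):
    # bag j of table t starts at offsets[t * batch_size + j]
    start = offsets[batch_size * table_id]
    end = offsets[batch_size * (table_id + 1)]
    local_indices = indices[start:end] - idx_offset
    local_offsets = offsets[batch_size * table_id:batch_size * (table_id + 1) + 1] - start
    return local_indices, local_offsets

# Table 2 (ids 11~17) is assigned to rank 1; tables 0 and 1 (6 + 5 rows) are not,
# so its idx_offset is 11 and its ids map to local rows 0~6.
print(slice_table(2, idx_offset=11))
# -> (tensor([2, 4, 0]), tensor([0, 0, 2, 3]))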
""" @@ -58,7 +189,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module): warmup_ratio=0.7, pin_weight=False, evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__() + super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__() self.rank = dist.get_rank() self.world_size = dist.get_world_size() self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] @@ -109,26 +240,32 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module): batch_size = (offsets.shape[0]) // self.global_tables_num local_output_list = [] for i, handle_table in enumerate(self.assigned_table_list): - indices_start_position = offsets[batch_size * handle_table] - if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): - # till the end special case - indices_end_position = indices.shape[0] - else: - indices_end_position = offsets[batch_size * (handle_table + 1)] - - local_indices = indices[indices_start_position:indices_end_position] - \ - self.global_tables_offsets[handle_table] - if self.include_last_offset: - local_offsets = offsets[batch_size * handle_table:batch_size * - (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] - else: - local_offsets = offsets[batch_size * handle_table:batch_size * - (handle_table + 1)] - offsets[batch_size * (handle_table)] - local_per_sample_weights = None - if per_sample_weights != None: - local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] - local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets, - local_per_sample_weights)) + with record_function("(tablewise) prepare indices and offsets"): + with record_function("part 1"): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till the end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + with record_function("part 2"): + # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] + local_indices = indices.narrow(0, indices_start_position, indices_end_position + - indices_start_position).sub(self.global_tables_offsets[handle_table]) + if self.include_last_offset: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size + 1).sub(offsets[batch_size * (handle_table)]) + else: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size).sub(offsets[batch_size * (handle_table)]) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] + with record_function("(tablewise) tablewise forward"): + local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets, + local_per_sample_weights)) # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) local_output = torch.cat(local_output_list, 1) @@ -140,3 +277,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module): if shape_hook is not None: output_full = 
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 2a398719e..5d92d7820 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,9 +209,10 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
 
     # initialize weight
     # 3 feature tables. idx: 0~5, 6~10, 11~17
-    weight_table1 = torch.rand(6, 5)
-    weight_table2 = torch.rand(5, 5)
-    weight_table3 = torch.rand(7, 5)
+    weight_tables = torch.rand(18, 5)
+    weight_table1 = weight_tables[0:6]
+    weight_table2 = weight_tables[6:11]
+    weight_table3 = weight_tables[11:18]
     embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
     embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
         num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
@@ -219,14 +220,20 @@
         num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
     embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
         num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
-
+    if rank == 0:
+        _weight = torch.cat([weight_table1, weight_table2], 0)
+    else:
+        _weight = weight_table3
     model = ParallelFreqAwareEmbeddingBagTablewise(
         embedding_bag_config_list,
         embedding_dim=5,
+        _weight=_weight,
+        include_last_offset=True,
+        cuda_row_num=8,
+        buffer_size=0,
         evict_strategy=EvictionStrategy.LFU,
-        include_last_offset=True
     )
-    # demo explain:
+    # explanation of the input layout:
     '''
        batch       feature 1       feature 2       feature 3
       input0      [1,2,3]         [6,7]           []
         fake_grad = rand_grad[0:2]
     else :
         fake_grad = rand_grad[2:]
-
     res.backward(fake_grad)
     optimizer.step()
     optimizer.zero_grad()
-    # check correctness on weight_table2
+    # check correctness
     if rank == 0:
-        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
+        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
                                                           include_last_offset=True,
                                                           freeze=False).to(device)
         ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
-        ref_grad = rand_grad[:, 5:10]
-        ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
-        ref_res.backward(ref_grad)
+        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
+        ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
+                            torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+        ref_res.backward(ref_fake_grad)
         ref_optimizer.step()
         ref_optimizer.zero_grad()
-
-        model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush()    # update cpu weight
-        recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
-        assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
-                              ), f"{recover_weight - ref_model.weight.detach().cpu()}"
-
+
+        model.cache_weight_mgr.flush()
+        recover_weight = model.cache_weight_mgr.weight.to(device)
+        ref_weight = ref_model.weight.detach()[:11]
+        assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
 
 
 def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
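The least obvious line in the updated test is ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0). The tablewise model produces one row per sample of shape (batch, num_tables * dim), while the reference nn.EmbeddingBag sees the same bags flattened table-major, i.e. (num_tables * batch, dim), so the upstream gradient has to be split per table and restacked. A short standalone check of that layout argument (shapes only, no ColossalAI required):

import torch

batch, num_tables, dim = 3, 3, 5
rand_grad = torch.arange(batch * num_tables * dim, dtype=torch.float32).reshape(batch, num_tables * dim)

ref_fake_grad = torch.cat(rand_grad.split(dim, 1), 0)
assert ref_fake_grad.shape == (num_tables * batch, dim)
# sample k of table t lands at row t * batch + k of the reference gradient
assert torch.equal(ref_fake_grad[1 * batch + 0], rand_grad[0, dim:2 * dim])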