From 64169f3e8f6066ed62cf0a155e56db1424123a67 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 6 Sep 2022 10:41:20 +0800
Subject: [PATCH] [embedding] polish parallel embedding tablewise (#1545)

---
 .../layers/cache_embedding/__init__.py        |   8 +-
 .../layers/cache_embedding/cache_mgr.py       |  16 +-
 .../cache_embedding/embedding_config.py       |  27 +++
 ...parallel_freq_aware_embedding_tablewise.py | 200 +++-----------------
 ...q_aware_embedding_tablewise_split_cache.py | 138 ++++++++++++
 tests/test_layers/test_cache_embedding.py     |  47 ++--
 6 files changed, 232 insertions(+), 204 deletions(-)
 create mode 100644 colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
 create mode 100644 colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py

diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
index 15f921968..7f1c72588 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -2,8 +2,12 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .embedding_config import TablewiseEmbeddingBagConfig
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
+from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+
 __all__ = [
     'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index ee7ce0607..fdb120134 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -293,7 +293,7 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(zhg) get unique indices"):
+        with record_function("(pre-id) get unique indices"):
             ids = ids.to(self._cache_dev)
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
 
@@ -303,7 +303,7 @@ class CachedParamMgr(torch.nn.Module):
                 f"Please increase cuda_row_num or decrease the training batch size."
             self.evict_backlist = cpu_row_idxs
 
-        with record_function("(zhg) get cpu row idxs"):
+        with record_function("(pre-id) get cpu row idxs"):
             comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
 
         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
@@ -311,18 +311,18 @@ class CachedParamMgr(torch.nn.Module):
         self.num_write_back_history.append(0)
 
         # move sure the cuda rows will not be evicted!
-        with record_function("(zhg) cache update"):
+        with record_function("(pre-id) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+            self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
 
-        self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
-
-        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
+        with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
 
         # update for LFU.
         if self._evict_strategy == EvictionStrategy.LFU:
-            unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
-            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
+            with record_function("(pre-id) lfu cnter updates"):
+                unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
+                self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
 
         return gpu_row_idxs
diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
new file mode 100644
index 000000000..36e04c833
--- /dev/null
+++ b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
@@ -0,0 +1,27 @@
+import torch
+
+
+class TablewiseEmbeddingBagConfig:
+    '''
+    example:
+    def prepare_tablewise_config(args, cache_ratio, ...):
+        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+        ...
+        return embedding_bag_config_list
+    '''
+
+    def __init__(self,
+                 num_embeddings: int,
+                 cuda_row_num: int,
+                 assigned_rank: int = 0,
+                 buffer_size=50_000,
+                 ids_freq_mapping=None,
+                 initial_weight: torch.tensor = None,
+                 name: str = ""):
+        self.num_embeddings = num_embeddings
+        self.cuda_row_num = cuda_row_num
+        self.assigned_rank = assigned_rank
+        self.buffer_size = buffer_size
+        self.ids_freq_mapping = ids_freq_mapping
+        self.initial_weight = initial_weight
+        self.name = name
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index 9df793828..c0d72fbfc 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -1,42 +1,14 @@
 import torch
 import torch.distributed as dist
-import torch.nn as nn
-from torch.profiler import record_function
-from typing import List
-import abc
 import torch.nn.functional as F
+
 from .freq_aware_embedding import FreqAwareEmbeddingBag
-
-from colossalai.tensor import ProcessGroup
 from .cache_mgr import EvictionStrategy
-
+from .embedding_config import TablewiseEmbeddingBagConfig
+from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
-
-
-class TablewiseEmbeddingBagConfig:
-    '''
-    example:
-    def prepare_tablewise_config(args, cache_ratio, ...):
-        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
-        ...
-            return embedding_bag_config_list
-    '''
-
-    def __init__(self,
-                 num_embeddings: int,
-                 cuda_row_num: int,
-                 assigned_rank: int = 0,
-                 buffer_size=50_000,
-                 ids_freq_mapping=None,
-                 initial_weight: torch.tensor = None,
-                 name: str = ""):
-        self.num_embeddings = num_embeddings
-        self.cuda_row_num = cuda_row_num
-        self.assigned_rank = assigned_rank
-        self.buffer_size = buffer_size
-        self.ids_freq_mapping = ids_freq_mapping
-        self.initial_weight = initial_weight
-        self.name = name
+from typing import List
 
 
 class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@@ -44,6 +16,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
     all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
+
     def __init__(self,
                  embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
                  embedding_dim: int,
@@ -98,9 +71,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
             if self.rank_of_tables[table_i] == self.rank:
                 self.idx_offset_list.append(offset_cumsum)
-            else :
+            else:
                 offset_cumsum += table_num_embeddings
-
+
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
         for rank in self.rank_of_tables:
@@ -112,8 +85,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         local_offsets_list: List(torch.Tensor) = []
         if per_sample_weights != None:
             local_per_sample_weights_list: List(torch.Tensor) = []
-
-        offset_pre_end = 0 # local_offsets trick
+
+        offset_pre_end = 0    # local_offsets trick
         for i, handle_table in enumerate(self.assigned_table_list):
             indices_start_position = offsets[batch_size * handle_table]
             if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@@ -122,27 +95,29 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             else:
                 indices_end_position = offsets[batch_size * (handle_table + 1)]
             # 1. local_indices_list:
-            local_indices_list.append(indices.narrow(0, indices_start_position, indices_end_position -
-                                                     indices_start_position).sub(self.idx_offset_list[i]))
+            local_indices_list.append(
+                indices.narrow(0, indices_start_position,
+                               indices_end_position - indices_start_position).sub(self.idx_offset_list[i]))
             # 2. local_offsets_list:
             if i + 1 == len(self.assigned_table_list):
                 # till-the-end special case
                 if not self.include_last_offset:
                     local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size).add(offset_pre_end - offsets[batch_size * (handle_table)])
-                else :
-                    local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                                                   batch_size).add(offset_pre_end - offsets[batch_size *
+                                                                                            (handle_table)])
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
+                                                   1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 local_offsets_list.append(local_offsets)
             else:
-                local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                               batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
+                                               1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 offset_pre_end = local_offsets[-1]
                 local_offsets_list.append(local_offsets[:-1])
             # 3. local_per_sample_weights_list:
             if per_sample_weights != None:
                 local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
-
+
         local_indices = torch.cat(local_indices_list, 0)
         local_offsets = torch.cat(local_offsets_list, 0)
         local_per_sample_weights = None
@@ -150,148 +125,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         if local_per_sample_weights_list:
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-
+
         local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
-        local_output = torch.cat(local_output.split(batch_size),1)
-
+        local_output = torch.cat(local_output.split(batch_size), 1)
+
         remains = batch_size % self.world_size
         scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
         output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
         if shape_hook is not None:
             output_full = shape_hook(output_full)
         return output_full
-
+
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
 
     def element_size(self):
         return self.weight.element_size()
-
-
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
-    """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
-    """
-
-    def __init__(self,
-                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
-                 embedding_dim: int,
-                 padding_idx=None,
-                 max_norm=None,
-                 norm_type=2.,
-                 scale_grad_by_freq=False,
-                 sparse=False,
-                 mode='mean',
-                 include_last_offset=False,
-                 dtype=None,
-                 device=None,
-                 warmup_ratio=0.7,
-                 pin_weight=False,
-                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
-        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
-        self.global_tables_num = len(embedding_bag_config_list)
-        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
-
-        self.assigned_table_list: List[int] = []
-        for i, rank in enumerate(self.rank_of_tables):
-            if rank == self.rank:
-                self.assigned_table_list.append(i)
-        self.include_last_offset = include_last_offset
-        self.pg = ProcessGroup(tp_degree=self.world_size)
-
-        # prepare FreqAwareEmbeddingBag list
-
-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
-        for config in embedding_bag_config_list:
-            if config.assigned_rank != self.rank:
-                continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
-                                      embedding_dim=embedding_dim,
-                                      padding_idx=padding_idx,
-                                      max_norm=max_norm,
-                                      norm_type=norm_type,
-                                      scale_grad_by_freq=scale_grad_by_freq,
-                                      sparse=sparse,
-                                      _weight=config.initial_weight,
-                                      mode=mode,
-                                      include_last_offset=include_last_offset,
-                                      dtype=dtype,
-                                      device=device,
-                                      cuda_row_num=config.cuda_row_num,
-                                      ids_freq_mapping=config.ids_freq_mapping,
-                                      warmup_ratio=warmup_ratio,
-                                      buffer_size=config.buffer_size,
-                                      pin_weight=pin_weight,
-                                      evict_strategy=evict_strategy))
-
-        # prepare list shape for all_to_all output
-        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
-        for rank in self.rank_of_tables:
-            self.embedding_dim_per_rank[rank] += embedding_dim
-
-    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
-        # determine indices to handle
-        batch_size = (offsets.shape[0]) // self.global_tables_num
-        local_output_list = []
-        for i, handle_table in enumerate(self.assigned_table_list):
-            with record_function("(tablewise) prepare indices and offsets"):
-                with record_function("part 1"):
-                    indices_start_position = offsets[batch_size * handle_table]
-                    if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
-                        # till the end special case
-                        indices_end_position = indices.shape[0]
-                    else:
-                        indices_end_position = offsets[batch_size * (handle_table + 1)]
-                with record_function("part 2"):
-                    # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
-                    local_indices = indices.narrow(0, indices_start_position, indices_end_position -
-                                                   indices_start_position).sub(self.global_tables_offsets[handle_table])
-                    if self.include_last_offset:
-                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
-                        local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                       batch_size + 1).sub(offsets[batch_size * (handle_table)])
-                    else:
-                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
-                        local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                       batch_size).sub(offsets[batch_size * (handle_table)])
-            local_per_sample_weights = None
-            if per_sample_weights != None:
-                local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
-            with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
-
-        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
-        local_output = torch.cat(local_output_list, 1)
-        # then concatenate those local_output on the second demension.
-        # use all_to_all
-        remains = batch_size % self.world_size
-        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
-        if shape_hook is not None:
-            output_full = shape_hook(output_full)
-        return output_full
-
-    def element_size(self):
-        if len(self.assigned_table_list) == 0:
-            return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
-
-    def print_comm_stats_(self):
-        cuda_to_cpu_elem_num = 0
-        cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
-        print(
-            f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem"
-        )
-        print(
-            f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem"
-        )
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py
new file mode 100644
index 000000000..807ab389a
--- /dev/null
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py
@@ -0,0 +1,138 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.profiler import record_function
+
+from .freq_aware_embedding import FreqAwareEmbeddingBag
+
+from colossalai.tensor import ProcessGroup
+from colossalai.nn._ops._utils import dual_all_to_all_tablewise
+from .embedding_config import TablewiseEmbeddingBagConfig
+from .cache_mgr import EvictionStrategy
+
+from typing import List
+import abc
+
+
+class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+    """
+    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    """
+
+    def __init__(self,
+                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
+                 embedding_dim: int,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 warmup_ratio=0.7,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
+        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
+        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
+        self.global_tables_num = len(embedding_bag_config_list)
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
+
+        self.assigned_table_list: List[int] = []
+        for i, rank in enumerate(self.rank_of_tables):
+            if rank == self.rank:
+                self.assigned_table_list.append(i)
+        self.include_last_offset = include_last_offset
+        self.pg = ProcessGroup(tp_degree=self.world_size)
+
+        # prepare FreqAwareEmbeddingBag list
+
+        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        for config in embedding_bag_config_list:
+            if config.assigned_rank != self.rank:
+                continue
+            self.freq_aware_embedding_bag_list.append(
+                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
+                                      embedding_dim=embedding_dim,
+                                      padding_idx=padding_idx,
+                                      max_norm=max_norm,
+                                      norm_type=norm_type,
+                                      scale_grad_by_freq=scale_grad_by_freq,
+                                      sparse=sparse,
+                                      _weight=config.initial_weight,
+                                      mode=mode,
+                                      include_last_offset=include_last_offset,
+                                      dtype=dtype,
+                                      device=device,
+                                      cuda_row_num=config.cuda_row_num,
+                                      ids_freq_mapping=config.ids_freq_mapping,
+                                      warmup_ratio=warmup_ratio,
+                                      buffer_size=config.buffer_size,
+                                      pin_weight=pin_weight,
+                                      evict_strategy=evict_strategy))
+
+        # prepare list shape for all_to_all output
+        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
+        for rank in self.rank_of_tables:
+            self.embedding_dim_per_rank[rank] += embedding_dim
+
+    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
+        # determine indices to handle
+        batch_size = (offsets.shape[0]) // self.global_tables_num
+        local_output_list = []
+        for i, handle_table in enumerate(self.assigned_table_list):
+            with record_function("(tablewise) prepare indices and offsets"):
+                with record_function("part 1"):
+                    indices_start_position = offsets[batch_size * handle_table]
+                    if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
+                        # till the end special case
+                        indices_end_position = indices.shape[0]
+                    else:
+                        indices_end_position = offsets[batch_size * (handle_table + 1)]
+                with record_function("part 2"):
+                    # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
+                    local_indices = indices.narrow(0, indices_start_position, indices_end_position -
+                                                   indices_start_position).sub(self.global_tables_offsets[handle_table])
+                    if self.include_last_offset:
+                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
+                        local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                       batch_size + 1).sub(offsets[batch_size * (handle_table)])
+                    else:
+                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
+                        local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                       batch_size).sub(offsets[batch_size * (handle_table)])
+            local_per_sample_weights = None
+            if per_sample_weights != None:
+                local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
+            with record_function("(tablewise) tablewise forward"):
+                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
+                                                                               local_per_sample_weights))
+
+        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
+        local_output = torch.cat(local_output_list, 1)
+        # then concatenate those local_output on the second demension.
+        # use all_to_all
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
+
+    def element_size(self):
+        if len(self.assigned_table_list) == 0:
+            return 0
+        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+
+    def print_comm_stats_(self):
+        cuda_to_cpu_elem_num = 0
+        cpu_to_cuda_elem_num = 0
+        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
+            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
+        print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 5d92d7820..3f4dcb0d1 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,19 +209,28 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     # initialize weight
     # 3 feature tables. idx: 0~5, 6~10, 11~17
-    weight_tables = torch.rand(18,5)
+    weight_tables = torch.rand(18, 5)
     weight_table1 = weight_tables[0:6]
     weight_table2 = weight_tables[6:11]
     weight_table3 = weight_tables[11:18]
     embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=6,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=5,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=7,
+                                    cuda_row_num=4,
+                                    assigned_rank=1,
+                                    initial_weight=weight_table3.clone().detach().cpu()))
     if rank == 0:
-        _weight = torch.cat([weight_table1, weight_table2],0)
+        _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
     model = ParallelFreqAwareEmbeddingBagTablewise(
@@ -249,30 +258,31 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
         fake_grad = rand_grad[0:2]
-    else :
+    else:
         fake_grad = rand_grad[2:]
     res.backward(fake_grad)
     optimizer.step()
     optimizer.zero_grad()
 
-    # check correctness
+    # check correctness
     if rank == 0:
         ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
                                                           include_last_offset=True,
                                                           freeze=False).to(device)
         ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
-        ref_fake_grad = torch.cat(rand_grad.split(5,1),0)
+        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
         ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
                             torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
         ref_res.backward(ref_fake_grad)
         ref_optimizer.step()
         ref_optimizer.zero_grad()
-
+
         model.cache_weight_mgr.flush()
         recover_weight = model.cache_weight_mgr.weight.to(device)
         ref_weight = ref_model.weight.detach()[:11]
         assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
 
+
 def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
@@ -289,11 +299,12 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
-                                                          include_last_offset=True,
-                                                          freeze=False,
-                                                          cuda_row_num=batch_size * 2,
-                                                          )
+    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+        coloweight,
+        include_last_offset=True,
+        freeze=False,
+        cuda_row_num=batch_size * 2,
+    )
 
     assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
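
A minimal usage sketch of the tablewise API touched by this patch, assuming torch.distributed is already initialized (e.g. via colossalai.launch). Only the import path and the constructor fields shown in the diffs above are taken from the patch; the table sizes, cuda_row_num values, rank assignment, embedding_dim, and the helper name build_tablewise_embedding are illustrative assumptions loosely modeled on run_parallel_freq_aware_embed_tablewise in the test file.

    # Hypothetical wiring of the tablewise embedding API; values are illustrative.
    from typing import List

    import torch.distributed as dist

    from colossalai.nn.parallel.layers import (
        ParallelFreqAwareEmbeddingBagTablewise,
        TablewiseEmbeddingBagConfig,
    )


    def build_tablewise_embedding(embedding_dim: int = 8) -> ParallelFreqAwareEmbeddingBagTablewise:
        # Three hypothetical feature tables, assigned to ranks round-robin.
        table_sizes = [6, 5, 7]
        world_size = dist.get_world_size()
        config_list: List[TablewiseEmbeddingBagConfig] = [
            TablewiseEmbeddingBagConfig(
                num_embeddings=num_embeddings,
                # per-table GPU cache rows: honored by the split-cache variant, but
                # ignored by the shared-cache class below (see its docstring).
                cuda_row_num=4,
                assigned_rank=i % world_size,
            )
            for i, num_embeddings in enumerate(table_sizes)
        ]
        # All tables assigned to the local rank share one FreqAwareEmbeddingBag cache;
        # ParallelFreqAwareEmbeddingBagTablewiseSpiltCache keeps one cache per table instead.
        return ParallelFreqAwareEmbeddingBagTablewise(config_list, embedding_dim)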