[embedding] cache_embedding small improvement (#1564)

This commit is contained in:
CsRic 2022-09-08 16:41:19 +08:00 committed by GitHub
parent 10dd8226b1
commit a389ac4ec9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 13 deletions

View File

@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
"""reorder """reorder
reorder the weight according to ids' frequency in dataset before training. reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase. Execute only once before training, also known as warmup phase.
Note: Note:
If you would like to use the DATASET as the eviction strategy, you must call this function. If you would like to use the DATASET as the eviction strategy, you must call this function.
@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
self.evict_backlist = cpu_row_idxs self.evict_backlist = cpu_row_idxs
with record_function("(pre-id) get cpu row idxs"): with record_function("(pre-id) get cpu row idxs"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs))
@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
if evict_num > 0: if evict_num > 0:
with Timer() as timer: with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET: if self._evict_strategy == EvictionStrategy.DATASET:
# mask method. # mask method.

View File

@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1): def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
with torch.no_grad(): with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(indices) reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx) per_sample_weights, self.include_last_offset, self.padding_idx)
@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def print_comm_stats_(self): def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats() self.cache_weight_mgr.print_comm_stats()
def element_size(self): def element_size(self):
return self.weight.element_size() return self.weight.element_size()

View File

@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights_list: List(torch.Tensor) = [] local_per_sample_weights_list: List(torch.Tensor) = []
offset_pre_end = 0 # local_offsets trick offset_pre_end = 0 # local_offsets trick
for i, handle_table in enumerate(self.assigned_table_list): for i, handle_table in enumerate(self.assigned_table_list):
indices_start_position = offsets[batch_size * handle_table] indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
indices_end_position = indices.shape[0] indices_end_position = indices.shape[0]
else: else:
indices_end_position = offsets[batch_size * (handle_table + 1)] indices_end_position = offsets[batch_size * (handle_table + 1)]
# alternative approach: reduce malloc
'''
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
local_indices_list.append(local_indices)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
local_offsets_list.append(local_offsets)
else:
temp_holder = offsets[batch_size * handle_table].item()
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
# 1. local_indices_list: # 1. local_indices_list:
local_indices_list.append( local_indices_list.append(
indices.narrow(0, indices_start_position, indices.narrow(0, indices_start_position,
@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# till-the-end special case # till-the-end special case
if not self.include_last_offset: if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).add(offset_pre_end - offsets[batch_size * batch_size).add(offset_pre_end - offsets[batch_size
(handle_table)]) * (handle_table)])
else: else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets_list.append(local_offsets) local_offsets_list.append(local_offsets)
else: else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
offset_pre_end = local_offsets[-1] offset_pre_end = local_offsets[-1]
local_offsets_list.append(local_offsets[:-1]) local_offsets_list.append(local_offsets[:-1])
# 3. local_per_sample_weights_list: # 3. local_per_sample_weights_list:
if per_sample_weights != None: if per_sample_weights != None:
local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position]) local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
local_indices = torch.cat(local_indices_list, 0) local_indices = torch.cat(local_indices_list, 0)
local_offsets = torch.cat(local_offsets_list, 0) local_offsets = torch.cat(local_offsets_list, 0)
local_per_sample_weights = None local_per_sample_weights = None