[legacy] move communication and nn to legacy and refactor logger (#4671)

* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent of gpc

* [legacy] make optim independent of registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
Author: Hongxin Liu
Date: 2023-09-11 16:24:28 +08:00
Committed by: GitHub
Parent: 536397cc95
Commit: 554aa9592e
170 changed files with 781 additions and 758 deletions


@@ -0,0 +1,21 @@
from .cache_embedding import (
CachedEmbeddingBag,
CachedParamMgr,
EvictionStrategy,
LimitBuffIndexCopyer,
ParallelCachedEmbeddingBag,
ParallelCachedEmbeddingBagTablewise,
ParallelCachedEmbeddingBagTablewiseSpiltCache,
TablewiseEmbeddingBagConfig,
)
from .colo_module import ColoModule
from .embedding import ColoEmbedding
from .linear import ColoLinear
from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module
__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
]


@@ -0,0 +1,13 @@
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
from .copyer import LimitBuffIndexCopyer
from .embedding_config import TablewiseEmbeddingBagConfig
from .parallel_cached_embedding import ParallelCachedEmbeddingBag
from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
__all__ = [
'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
]


@@ -0,0 +1,37 @@
import abc
import torch.nn as nn
class BaseEmbeddingBag(abc.ABC, nn.Module):
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
mode='mean',
include_last_offset=False,
):
super(BaseEmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Specific to embedding bag
self.mode = mode
self.include_last_offset = include_last_offset
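
BaseEmbeddingBag only stores the EmbeddingBag hyperparameters (and normalizes a negative padding_idx); subclasses supply the weight and the forward pass. A minimal, hypothetical subclass (not part of this diff) that feeds the stored attributes into F.embedding_bag could look like:

import torch
import torch.nn.functional as F

class DenseEmbeddingBag(BaseEmbeddingBag):  # hypothetical example, not in the diff
    def __init__(self, num_embeddings, embedding_dim, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.weight = torch.nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def forward(self, input, offsets=None, per_sample_weights=None):
        # reuse the hyperparameters collected by BaseEmbeddingBag
        return F.embedding_bag(input, self.weight, offsets, self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset, self.padding_idx)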


@@ -0,0 +1,585 @@
import sys
from contextlib import contextmanager
from enum import Enum
from typing import List, Optional
import numpy as np
import torch
from contexttimer import Timer
from torch.profiler import record_function
from .copyer import LimitBuffIndexCopyer
class EvictionStrategy(Enum):
LFU = 1
# dataset aware eviction strategy
DATASET = 2
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
# PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
# freed, its memory is likely to be reused by newly constructed tensors. By default,
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
# to tell the allocator about all these streams. Otherwise, the allocator might free the
# underlying memory of the tensor once it is no longer used by the creator stream. This is
# a notable programming trick when we write programs using multi CUDA streams.
cur_stream = torch.cuda.current_stream()
assert isinstance(t, torch.Tensor)
t.record_stream(cur_stream)
class CachedParamMgr(torch.nn.Module):
"""
Manages embedding weights on CPU and CUDA memory using a software cache.
CPU maintains the entire original weight.
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
During training, embedding rows are transferred between CPU and GPU as needed.
Args:
weight (torch.Tensor): the weight of the Embedding layer.
cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase significantly. Defaults to False.
evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
`EvictionStrategy.LFU`: use the least frequently used cache.
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
Defaults to EvictionStrategy.DATASET.
"""
def __init__(
self,
weight: torch.Tensor,
cuda_row_num: int = 0,
buffer_size: int = 0,
pin_weight: bool = True,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
async_copy: bool = False,
) -> None:
super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size()
# weight configure
self._init_weight(weight)
# Perf log
self.num_hits_history = []
self.num_miss_history = []
self.num_write_back_history = []
self._evict_strategy = evict_strategy
self._async_copy = async_copy
if self._async_copy:
self._memcpy_stream = torch.cuda.Stream()
print('use async copy')
if self._evict_strategy == EvictionStrategy.LFU:
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(sys.maxsize),
persistent=False)
self._elapsed_dict = {}
self._show_cache_miss = True
self._reset_comm_stats()
def _reset_comm_stats(self):
for k in self._elapsed_dict.keys():
self._elapsed_dict[k] = 0
self._cpu_to_cuda_numel = 0
self._cuda_to_cpu_numel = 0
if self._show_cache_miss:
self._cache_miss = 0
self._total_cache = 0
@contextmanager
def timer(self, name):
with Timer() as t:
yield
torch.cuda.synchronize()
if name not in self._elapsed_dict.keys():
self._elapsed_dict[name] = 0
self._elapsed_dict[name] += t.elapsed
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
"""_find_evict_gpu_idxs
Find the gpu idxs to be evicted, according to their freq.
Args:
evict_num (int): how many rows have to be evicted
Returns:
torch.Tensor: a 1-D tensor containing the gpu_row_idxs.
"""
if self._evict_strategy == EvictionStrategy.LFU:
# find the minimal evict_num freq entries in cached_idx_map
_, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False)
return evict_gpu_row_idxs
elif self._evict_strategy == EvictionStrategy.DATASET:
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
# The larger it is, the less frequently it will appear in the dataset,
# and the higher its eviction priority will be.
_, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
return evict_gpu_row_idxs
else:
raise TypeError
def _init_weight(self, weight):
if self.cuda_row_num > 0:
# Enable cache with introducing auxiliary data structures
self.cuda_cached_weight = torch.nn.Parameter(
torch.zeros(self.cuda_row_num,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=weight.dtype))
# pin memory cpu for higher CPU-GPU copy bandwidth
self.weight = weight.pin_memory() if self.pin_weight else weight
# map original id to new id with respect to frequency
# id -> cpu_row_idx
self.register_buffer(
"idx_map",
torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
persistent=False,
)
# cached_idx_map: gpu_row_idx -> cpu_row_idx
self.register_buffer("cached_idx_map",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
# cpu_row_id -> gpu_row_idx.
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
self.register_buffer("inverted_cached_idx",
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
# index copy buffer size should be less than 10% of cuda weight.
if self.buffer_size > 0:
self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
else:
# Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag
# self.weight = torch.nn.Parameter(weight)
# self.cuda_cached_weight = self.weight
raise NotImplementedError()
def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
"""
Access a row of the CPU weight.
Args:
row_idx (int): the index of the row
Returns:
torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
"""
return self.weight.data.view(-1).narrow(0,
int(row_idx) * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
@property
def cuda_available_row_num(self):
return self._cuda_available_row_num
@torch.no_grad()
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
If you use LFU as the eviction strategy, you can skip this function. If you still call it, it will initialize
the frequency in the LFU cache using the dataset statistics.
Args:
ids_freq_mapping (List[int]): a list whose index is the id and whose value is the id's frequency. If None, the cpu weight is not reordered.
warmup_ratio (float): the ratio of rows preloaded into the cuda cache
"""
# reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
# reorder only works for DATASET eviction strategy.
if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor):
ids_freq_mapping = torch.tensor(ids_freq_mapping)
if self._evict_strategy == EvictionStrategy.DATASET:
if ids_freq_mapping is not None:
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx)
# warmup phase: copy #preload_row_num rows from cpu to gpu.
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0:
with Timer() as timer:
# extract rows from cpu weight
if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
else:
preload_cpu_ids = torch.arange(preload_row_num)
preload_cuda_row_idxs = preload_cpu_ids.cuda()
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=preload_cpu_ids,
tgt_index=preload_cuda_row_idxs,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else:
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
preload_rows)
# update auxiliary info
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
self._cuda_available_row_num -= preload_row_num
if self._evict_strategy == EvictionStrategy.LFU:
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
if ids_freq_mapping is None:
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
else:
self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
print(f'Cache warmup finished cost {timer.elapsed} sec.')
def flush(self):
"""flush all CUDA rows to CPU.
The function is usually called after training finished.
"""
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
row_ids = self.cached_idx_map[slots]
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
self.cached_idx_map.index_fill_(0, slots, -1)
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
self._cuda_available_row_num += slots.numel()
if self._show_cache_miss:
self._cache_miss = 0
self._total_cache = 0
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter.fill_(sys.maxsize)
assert self._cuda_available_row_num == self.cuda_row_num
assert torch.all(self.inverted_cached_idx == -1).item()
assert torch.all(self.cached_idx_map == -1).item()
def print_comm_stats(self):
if self._cuda_to_cpu_numel > 0 and "3_evict_out" in self._elapsed_dict:
elapsed = self._elapsed_dict["3_evict_out"]
print(
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
)
print(f'cuda_to_cpu_elapse {elapsed} sec')
if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict:
elapsed = self._elapsed_dict["5_evict_in"]
print(
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
)
print(f'cpu_to_cuda_elapse {elapsed} sec')
for k, v in self._elapsed_dict.items():
print(f'{k}: {v}')
print(f'cache miss ratio {self._cache_miss / self._total_cache}')
@torch.no_grad()
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
"""
convert ids to indices in self.cuda_cached_weight.
Implemented with parallel operations on GPU.
Args:
ids (torch.Tensor): ids from the dataset
Returns:
torch.Tensor: contains indices in self.cuda_cached_weight
"""
ids = self.idx_map.index_select(0, ids.view(-1))
ret = self.inverted_cached_idx.index_select(0, ids)
return ret
@torch.no_grad()
def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
"""
move the cpu embedding rows w.r.t. ids into CUDA memory
Args:
ids (torch.Tensor): the ids to be computed
Returns:
torch.Tensor: indices on the cuda_cached_weight.
"""
torch.cuda.synchronize()
with self.timer("cache_op") as gtimer:
# identify cpu rows to cache
with self.timer("1_identify_cpu_row_idxs") as timer:
with record_function("(cache) get unique indices"):
if self._evict_strategy == EvictionStrategy.LFU:
cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)
else:
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
f"Please increase cuda_row_num or decrease the training batch size."
self.evict_backlist = cpu_row_idxs
tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)
comm_cpu_row_idxs = cpu_row_idxs[tmp]
if self._show_cache_miss:
self._cache_miss += torch.sum(repeat_times[tmp])
self._total_cache += ids.numel()
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0)
# make sure the cuda rows will not be evicted!
with record_function("(cache) prepare_rows_on_cuda"):
with self.timer("prepare_rows_on_cuda") as timer:
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
with self.timer("6_update_cache") as timer:
with record_function("6_update_cache"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
if self._evict_strategy == EvictionStrategy.LFU:
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
return gpu_row_idxs
def _row_in_cuda(self, row_id: int) -> bool:
return self.inverted_cached_idx[row_id] != -1
@torch.no_grad()
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
"""prepare rows in cpu_row_idxs on CUDA memory
Args:
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
"""
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
cpu_row_idxs_copy = cpu_row_idxs.cpu()
# move evict in rows to gpu
if self._async_copy:
if self.buffer_size == 0:
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
with torch.cuda.stream(self._memcpy_stream):
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
else:
raise NotImplementedError
if evict_num > 0:
with self.timer("2_identify_cuda_row_idxs") as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
with self.timer("2_1_find_evict_gpu_idxs") as timer:
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
# move evict out rows to cpu
if self._async_copy:
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
with self.timer("2_1_backup_freqs") as timer:
backup_freqs = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
with self.timer("2_2_find_evict_gpu_idxs") as timer:
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
if self._async_copy:
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
with self.timer("2_3_revert_freqs") as timer:
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
with self.timer("3_evict_out") as timer:
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=evict_gpu_row_idxs,
tgt_index=evict_info.cpu(),
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
tgt=self.weight.view(self.num_embeddings, -1))
else:
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
# TODO async gpu -> cpu
if self._async_copy:
_wait_for_data(evict_out_rows_cpu, None)
else:
with self.timer("3_1_evict_out_index_select") as timer:
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer:
evict_out_rows_cpu = evict_out_rows_cpu.cpu()
with self.timer("3_2_evict_out_cpu_copy") as timer:
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
# self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
self._cuda_available_row_num += evict_num
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
self._cuda_to_cpu_numel += weight_size
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
# slots of cuda weight to evict in
with self.timer("4_identify_cuda_slot") as timer:
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
# TODO wait for optimize
with self.timer("5_evict_in") as timer:
# Here also allocate extra memory on CUDA. #cpu_row_idxs
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=cpu_row_idxs_copy,
tgt_index=slots,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else:
if self._async_copy:
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
else:
with self.timer("5_1_evict_in_index_select") as timer:
# narrow index select to a subset of self.weight
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer:
evict_in_rows_gpu = evict_in_rows_gpu.cuda()
with self.timer("5_3_evict_in_index_copy") as timer:
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
with self.timer("6_update_cache") as timer:
self.cached_idx_map[slots] = cpu_row_idxs
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter.index_fill_(0, slots, 0)
self._cuda_available_row_num -= cpu_row_idxs.numel()
weight_size = cpu_row_idxs.numel() * self.embedding_dim
self._cpu_to_cuda_numel += weight_size
# print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
def _find_free_cuda_row(self) -> int:
if self._cuda_available_row_num == 0:
return -1
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
return candidates[0].item()
def _evict(self) -> int:
"""
deprecated
evict one row from cuda to cpu.
Returns:
(int) : the slot id be evicted.
"""
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
buf = self.cached_idx_map[mask].clone()
idx = torch.nonzero(mask).squeeze(1)
self.cached_idx_map.index_fill_(0, idx, -1)
max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0)
max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx]
if max_gpu_row_idx == -1:
raise RuntimeError("Can not evict a row")
max_gpu_row_idx = max_gpu_row_idx.item()
max_offset = self.inverted_cached_idx[max_gpu_row_idx]
# recover
self.cached_idx_map.index_copy_(0, idx, buf)
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)
# update inverted_cached_idx, min_slot_id is evicted from cuda
self.cached_idx_map[max_cpu_row_idx] = -1
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[max_cpu_row_idx] = sys.maxsize
self.inverted_cached_idx[max_gpu_row_idx] = -1
self._cuda_available_row_num += 1
self._cuda_to_cpu_numel += self.embedding_dim
# self.num_write_back_history[-1] += 1
return max_cpu_row_idx
@torch.no_grad()
def _admit(self, row_id: int):
"""
deprecated
move in row_id to CUDA
Args:
row_id (int): the id of row to be moved in
"""
# find a free slot in partial cuda weight
slot_id = self._find_free_cuda_row()
if slot_id == -1:
# evict one row
slot_id = self._evict()
slot_offset = slot_id
# copy payload from cpu to cuda
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
cuda_tensor.data.copy_(self.cpu_weight_data(row_id))
# update the inverted_cached_idx
self.cached_idx_map[slot_id] = row_id
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[slot_id] = 0
self.inverted_cached_idx[row_id] = slot_offset
self._cuda_available_row_num -= 1
self._cpu_to_cuda_numel += self.embedding_dim
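
For context, a minimal usage sketch of the CachedParamMgr defined above (assumes a CUDA device; the shapes and frequency statistics are made up for illustration):

import torch

weight = torch.randn(10_000, 128)                      # full weight stays on CPU
mgr = CachedParamMgr(weight, cuda_row_num=1_000, evict_strategy=EvictionStrategy.DATASET)

freq = torch.randint(1, 100, (10_000,))                # made-up dataset frequency stats
mgr.reorder(freq, warmup_ratio=0.7)                    # reorder rows and preload the hottest ones

ids = torch.randint(0, 10_000, (512,), device='cuda')
cache_rows = mgr.prepare_ids(ids)                      # rows are admitted/evicted as needed
emb = mgr.cuda_cached_weight[cache_rows]               # gather from the CUDA cache

mgr.flush()                                            # write cached rows back to the CPU weight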


@@ -0,0 +1,158 @@
from typing import Iterator, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr, EvictionStrategy
class CachedEmbeddingBag(BaseEmbeddingBag):
"""CachedEmbeddingBag
Cached Embedding. Applies a GPU-based software cache approach to dynamically manage the embedding table across CPU and GPU memory.
It can leverage the ids' frequency statistics of the target dataset by passing a frequency list to the `ids_freq_mapping` parameter.
You can also apply a naive LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction.
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.
scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False.
sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max". Defaults to False.
_weight (torch.Tensor, optional): an embedding weight tensor. Concatenates multiple tables of an embedding bag into a single one. Defaults to None.
mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'.
include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format. Defaults to False.
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
cache_ratio (float, optional): cache ratio of the #cuda_weight_row / #cpu_weight_row
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency with which each embedding vector occurs in the dataset. Defaults to None.
warmup_ratio (float, optional): the ratio of the cuda cache that is warmed up. Defaults to 0.7.
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
pin_weight (bool, optional): pin the cpu weight. Defaults to False.
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
max_norm: float = None,
norm_type: float = 2.,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[torch.Tensor] = None,
mode: str = 'mean',
include_last_offset: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
cache_ratio: float = 0.01,
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
scale_grad_by_freq, sparse, mode, include_last_offset)
assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must not be larger than 1.0"
self.evict_strategy = evict_strategy
if _weight is None:
_weight = self._weight_alloc(dtype, device)
cuda_row_num = int(num_embeddings * cache_ratio)
# configure weight & cache
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
self.cache_op = True
def set_cache_mgr_async_copy(self, flag):
self.cache_weight_mgr._async_copy = flag
def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
with torch.no_grad():
weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
if self.padding_idx is not None:
weight[self.padding_idx].fill_(0)
return weight
def _preprocess(self,
weight,
cuda_row_num: int,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False):
"""
Called after initialization.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
Args:
cuda_row_num (int): number of rows can be hosted in CUDA memory
ids_freq_mapping (List[int]): a list whose index is the id and whose value is the id's frequency
warmup_ratio (float): the ratio of rows preloaded into the cuda cache
"""
self.cache_weight_mgr = CachedParamMgr(weight,
cuda_row_num,
buffer_size,
pin_weight,
evict_strategy=self.evict_strategy)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
if self.cache_op:
with torch.no_grad():
input = self.cache_weight_mgr.prepare_ids(input)
embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None:
embeddings = shape_hook(embeddings)
return embeddings
@property
def weight(self):
return self.cache_weight_mgr.weight
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
yield self.cache_weight_mgr.cuda_cached_weight
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
############################# Perf Log ###################################
@property
def num_hits_history(self):
return self.cache_weight_mgr.num_hits_history
@property
def num_miss_history(self):
return self.cache_weight_mgr.num_miss_history
@property
def num_write_back_history(self):
return self.cache_weight_mgr.num_write_back_history
@property
def swap_in_bandwidth(self):
if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cpu_to_cuda_elapse
else:
return 0
@property
def swap_out_bandwidth(self):
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cuda_to_cpu_elapse
return 0
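
A minimal usage sketch of the CachedEmbeddingBag defined above (assumes a CUDA device; the cache ratio and shapes are illustrative):

import torch

bag = CachedEmbeddingBag(num_embeddings=100_000,
                         embedding_dim=64,
                         cache_ratio=0.02,             # keep ~2% of rows in CUDA memory
                         evict_strategy=EvictionStrategy.LFU)

indices = torch.randint(0, 100_000, (256,), device='cuda')
offsets = torch.arange(0, 256, 8, device='cuda')       # 32 bags of 8 ids each
out = bag(indices, offsets)                            # shape: (32, 64), lives on CUDA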


@@ -0,0 +1,49 @@
import torch
from torch import LongTensor
class LimitBuffIndexCopyer(object):
"""LimitBuffIndexCopyer
Index Copy using limited temp buffer on CUDA.
Args:
size (int): buffer size
"""
def __init__(self, size: int) -> None:
self._buff_size = size
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
"""copy
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
The valid rows in the src tensor are continuous, while rows in the tgt tensor are scattered.
Args:
dim (int): dimension along which to index
src_index (LongTensor): indices of the src tensor to select from
tgt_index (LongTensor): indices of the tgt tensor to copy into
src (torch.Tensor): the tensor containing values to copy
tgt (torch.Tensor): the tensor to copy into
"""
# tgt.index_copy_(dim, index, src)
assert dim == 0, "only support index_copy on dim 0"
assert tgt.dim() == 2
assert src.dim() == 2
tgt_device = tgt.device
src_device = src.device
assert src_index.numel() == tgt_index.numel()
dim_size = src_index.numel()
src_index = src_index.to(src_device)
for begin_pos in range(0, dim_size, self._buff_size):
cur_len = min(self._buff_size, dim_size - begin_pos)
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
tmp_buffer.copy_(cpu_tmp_buffer)
else:
tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
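
LimitBuffIndexCopyer bounds the temporary memory of a large index_copy by staging at most buffer-size rows per step. A small sketch using the class above (assumes a CUDA device; shapes are illustrative):

import torch

src = torch.randn(100_000, 64)                  # e.g. the CPU-resident weight
tgt = torch.zeros(1_000, 64, device='cuda')     # e.g. the CUDA cache slots
copyer = LimitBuffIndexCopyer(256)              # stage at most 256 rows at a time

src_index = torch.arange(1_000)                 # rows to read from src
tgt_index = torch.arange(1_000, device='cuda')  # slots to write in tgt
copyer.index_copy(0, src_index=src_index, tgt_index=tgt_index, src=src, tgt=tgt)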


@@ -0,0 +1,27 @@
import torch
class TablewiseEmbeddingBagConfig:
'''
example:
def prepare_tablewise_config(args, cache_ratio, ...):
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
...
return embedding_bag_config_list
'''
def __init__(self,
num_embeddings: int,
cuda_row_num: int,
assigned_rank: int = 0,
buffer_size=50_000,
ids_freq_mapping=None,
initial_weight: torch.tensor = None,
name: str = ""):
self.num_embeddings = num_embeddings
self.cuda_row_num = cuda_row_num
self.assigned_rank = assigned_rank
self.buffer_size = buffer_size
self.ids_freq_mapping = ids_freq_mapping
self.initial_weight = initial_weight
self.name = name
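
Following the prepare_tablewise_config example in the docstring, a hypothetical config list for two tables assigned to two ranks could look like (table names and sizes are made up):

from typing import List

embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [
    TablewiseEmbeddingBagConfig(num_embeddings=1_000_000, cuda_row_num=10_000,
                                assigned_rank=0, name="user_id"),
    TablewiseEmbeddingBagConfig(num_embeddings=200_000, cuda_row_num=5_000,
                                assigned_rank=1, name="item_id"),
]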


@@ -0,0 +1,142 @@
from typing import Iterator, List, Optional, Tuple
import torch
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
if world_size == 1:
return 0, embedding_dim, True
assert embedding_dim >= world_size, \
f"Embedding dimension {embedding_dim} must be larger than the world size " \
f"{world_size} of the process group"
chunk_size = embedding_dim // world_size
threshold = embedding_dim % world_size
# if embedding dim is divisible by world size
if threshold == 0:
return rank * chunk_size, (rank + 1) * chunk_size, True
# align with the split strategy of torch.tensor_split
size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
offset = sum(size_list[:rank])
return offset, offset + size_list[rank], False
class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
def __init__(self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
ids_freq_mapping=None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
self.partition_start_index, self.partition_end_index, divisible = get_partition(
embedding_dim, self.rank, self.world_size)
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
super(ParallelCachedEmbeddingBag,
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
self.cache_op = True
def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
with torch.no_grad():
weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
if self.padding_idx is not None:
weight[self.padding_idx].fill_(0)
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
compute_attr=ComputePattern.TP1D)
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
def forward(
self,
indices,
offsets=None,
per_sample_weights=None,
shape_hook=None,
scatter_dim=0,
gather_dim=-1,
):
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None:
output_shard = shape_hook(output_shard)
output_full = dual_all_to_all(output_shard,
self.weight.get_process_group(),
scatter_dim=scatter_dim,
gather_dim=gather_dim)
return output_full
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
@classmethod
def from_pretrained(
cls,
embedding: torch.Tensor,
freeze: bool = True,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.,
scale_grad_by_freq: bool = False,
sparse: bool = False,
mode: str = 'mean',
include_last_offset: bool = False,
cuda_row_num: int = 100_000,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
) -> 'ParallelCachedEmbeddingBag':
rows, cols = embedding.shape
embedding_bag = cls(rows,
cols,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
embedding,
mode,
include_last_offset,
cuda_row_num=cuda_row_num,
ids_freq_mapping=ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=buffer_size)
embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
return embedding_bag
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()
def element_size(self):
return self.weight.element_size()
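
get_partition splits the embedding dimension column-wise across ranks, matching torch.tensor_split: the first embedding_dim % world_size ranks get one extra column. A quick worked check of the returned (start, end, divisible) tuples:

# embedding_dim=10, world_size=4 -> chunk_size=2, threshold=2, size_list=[3, 3, 2, 2]
assert get_partition(10, 0, 4) == (0, 3, False)
assert get_partition(10, 1, 4) == (3, 6, False)
assert get_partition(10, 2, 4) == (6, 8, False)
assert get_partition(10, 3, 4) == (8, 10, False)
assert get_partition(8, 1, 4) == (2, 4, True)   # evenly divisible case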


@@ -0,0 +1,199 @@
import time
from typing import List
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
"""
All tables assigned to this class instance are managed by a single CachedEmbeddingBag.
The following parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
"""
def __init__(self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
self.global_tables_num = len(embedding_bag_config_list)
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
self.assigned_table_list: List[int] = []
self.pg = ProcessGroup(tp_degree=self.world_size)
self.num_embeddings = 0
for i, rank in enumerate(self.rank_of_tables):
if rank == self.rank:
self.assigned_table_list.append(i)
self.num_embeddings += self.global_table_num_embeddings_list[i]
self.include_last_offset = include_last_offset
ids_freq_mapping = []
for config in embedding_bag_config_list:
if config.assigned_rank == self.rank:
if config.ids_freq_mapping != None:
ids_freq_mapping.extend(config.ids_freq_mapping)
else:
ids_freq_mapping = None
break
self.cache_ratio = cache_ratio
# table-associate cache
cuda_row_num = int(cache_ratio * self.num_embeddings)
super(ParallelCachedEmbeddingBagTablewise,
self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
# for assigned tables reconnection:
self.idx_offset_list = []
offset_cumsum = 0
for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
if self.rank_of_tables[table_i] == self.rank:
self.idx_offset_list.append(offset_cumsum)
else:
offset_cumsum += table_num_embeddings
# prepare list shape for all_to_all output
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
for rank in self.rank_of_tables:
self.embedding_dim_per_rank[rank] += embedding_dim
self.cache_op = True
def forward(
self,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None,
shape_hook=None,
already_split_along_rank=True,
):
if not already_split_along_rank:
# not recommended: it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num
local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
batch_size, indices, offsets, per_sample_weights)
else:
# recommended.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = torch.cat(local_output.split(batch_size), 1)
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def split_along_rank(self,
batch_size,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None):
'''
If the input indices and offsets haven't been split along the assigned rank, this function will do it.
It takes time; please consider splitting the data during batch loading.
'''
local_indices_list: List[torch.Tensor] = []
local_offsets_list: List[torch.Tensor] = []
if per_sample_weights is not None:
local_per_sample_weights_list: List[torch.Tensor] = []
offset_pre_end = 0 # local_offsets trick
for i, handle_table in enumerate(self.assigned_table_list):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
# till-the-end special case
indices_end_position = indices.shape[0]
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
# alternative approach: reduce malloc
'''
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
local_indices_list.append(local_indices)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
local_offsets_list.append(local_offsets)
else:
temp_holder = offsets[batch_size * handle_table].item()
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
# 1. local_indices_list:
local_indices_list.append(
indices.narrow(0, indices_start_position,
indices_end_position - indices_start_position).sub(self.idx_offset_list[i]))
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).add(offset_pre_end - offsets[batch_size *
(handle_table)])
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets_list.append(local_offsets)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
offset_pre_end = local_offsets[-1]
local_offsets_list.append(local_offsets[:-1])
# 3. local_per_sample_weights_list:
if per_sample_weights != None:
local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
local_indices = torch.cat(local_indices_list, 0)
local_offsets = torch.cat(local_offsets_list, 0)
local_per_sample_weights = None
if per_sample_weights != None:
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
return local_indices, local_offsets, local_per_sample_weights
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()
def element_size(self):
return self.weight.element_size()
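
For illustration, the scatter strides used by the final all-to-all assign the remainder of the batch to the lowest ranks; e.g. a batch of 10 samples on 4 ranks:

batch_size, world_size = 10, 4
remains = batch_size % world_size                  # 2
scatter_strides = [batch_size // world_size + int(i < remains) for i in range(world_size)]
assert scatter_strides == [3, 3, 2, 2]             # ranks 0 and 1 each take one extra sample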


@@ -0,0 +1,138 @@
import abc
from typing import List
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.profiler import record_function
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
"""
Every table assigned to this class instance is managed by its own CachedEmbeddingBag.
"""
def __init__(self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
warmup_ratio=0.7,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
self.global_tables_num = len(embedding_bag_config_list)
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
self.assigned_table_list: List[int] = []
for i, rank in enumerate(self.rank_of_tables):
if rank == self.rank:
self.assigned_table_list.append(i)
self.include_last_offset = include_last_offset
self.pg = ProcessGroup(tp_degree=self.world_size)
# prepare CachedEmbeddingBag list
self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
for config in embedding_bag_config_list:
if config.assigned_rank != self.rank:
continue
self.cached_embedding_bag_list.append(
CachedEmbeddingBag(num_embeddings=config.num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=config.initial_weight,
mode=mode,
include_last_offset=include_last_offset,
dtype=dtype,
device=device,
cuda_row_num=config.cuda_row_num,
ids_freq_mapping=config.ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=config.buffer_size,
pin_weight=pin_weight,
evict_strategy=evict_strategy))
# prepare list shape for all_to_all output
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
for rank in self.rank_of_tables:
self.embedding_dim_per_rank[rank] += embedding_dim
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
# determine indices to handle
batch_size = (offsets.shape[0]) // self.global_tables_num
local_output_list = []
for i, handle_table in enumerate(self.assigned_table_list):
with record_function("(tablewise) prepare indices and offsets"):
with record_function("part 1"):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
# till the end special case
indices_end_position = indices.shape[0]
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
with record_function("part 2"):
# local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
local_indices = indices.narrow(0, indices_start_position, indices_end_position -
indices_start_position).sub(self.global_tables_offsets[handle_table])
if self.include_last_offset:
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size + 1).sub(offsets[batch_size * (handle_table)])
else:
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).sub(offsets[batch_size * (handle_table)])
local_per_sample_weights = None
if per_sample_weights != None:
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
with record_function("(tablewise) tablewise forward"):
local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
local_per_sample_weights))
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
local_output = torch.cat(local_output_list, 1)
# then concatenate those local_output on the second dimension.
# use all_to_all
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def element_size(self):
if len(self.assigned_table_list) == 0:
return 0
return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
def print_comm_stats_(self):
cuda_to_cpu_elem_num = 0
cpu_to_cuda_elem_num = 0
for cached_embedding_bag in self.cached_embedding_bag_list:
cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")


@@ -0,0 +1,47 @@
from typing import Dict, List
from colossalai.tensor import ComputePattern
from colossalai.tensor.distspec import _DistSpec
class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]):
self._shard_params = params
def _register_allowed_patterns(self,
compute_pattern: ComputePattern,
dist_specs: Dict[str, _DistSpec],
mode='default'):
assert sorted(dist_specs.keys()) == sorted(self._shard_params), \
'Every registered param should have a dist_spec.'
if not compute_pattern in self._allowed_patterns:
self._allowed_patterns[compute_pattern] = {}
self._allowed_patterns[compute_pattern][mode] = dist_specs
def _set_default(self, compute_pattern: ComputePattern, target_mode):
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
def has_compute_pattern(self, compute_pattern: ComputePattern):
return compute_pattern in self._allowed_patterns
def get_dist_specs(self, compute_pattern: ComputePattern):
assert self.has_compute_pattern(compute_pattern)
return self._allowed_patterns[compute_pattern]
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
return self._allowed_patterns[compute_pattern][mode]
def get_param_names(self):
return self._shard_params
def register(self, compute_pattern, pg):
raise NotImplementedError


@@ -0,0 +1,37 @@
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
class ColoEmbedding(ColoModule):
def __init__(self):
super(ColoEmbedding, self).__init__()
self._register_shard_params(['weight'])
def register(self, compute_pattern, pg: ProcessGroup):
if not compute_pattern in self._allowed_patterns:
if ComputePattern.TP1D == compute_pattern:
self._set_TP1D(pg)
def _set_TP1D(self, pg: ProcessGroup):
# TP1D Row Linear
_compute_pattern = ComputePattern.TP1D
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([0], [pg.tp_world_size()]),
},
mode='row',
)
# TP1D Col Linear
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([-1], [pg.tp_world_size()]),
},
mode='col',
)
self._set_default(compute_pattern=_compute_pattern, target_mode='row')


@@ -0,0 +1,39 @@
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
class ColoLinear(ColoModule):
def __init__(self):
super(ColoLinear, self).__init__()
self._register_shard_params(['weight', 'bias'])
def register(self, compute_pattern, pg: ProcessGroup):
if not compute_pattern in self._allowed_patterns:
if ComputePattern.TP1D == compute_pattern:
self._set_TP1D(pg)
def _set_TP1D(self, pg):
# TP1D Row Linear
_compute_pattern = ComputePattern.TP1D
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
)
# TP1D Col Linear
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([0], [pg.tp_world_size()]),
'bias': ShardSpec([0], [pg.tp_world_size()])
},
mode='col',
)
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
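
Once a compute pattern is registered for a process group, the allowed shard specs can be queried back from the ColoModule. A small sketch using the ColoLinear defined above (assumes `pg` is an already constructed ProcessGroup):

colo_linear = ColoLinear()
colo_linear.register(ComputePattern.TP1D, pg)
assert colo_linear.has_compute_pattern_with_mode(ComputePattern.TP1D, mode='col')
specs = colo_linear.get_dist_specs_with_mode(ComputePattern.TP1D, mode='col')
# specs['weight'] and specs['bias'] both shard dim 0 across pg.tp_world_size() ranks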


@@ -0,0 +1,115 @@
from typing import Dict
import torch
from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
from . import ColoModule
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
def register_colo_module(module_type: type, colo_module: ColoModule):
global _COLOSSAL_MODULES
_COLOSSAL_MODULES[module_type] = colo_module
def is_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES
for module_type in _COLOSSAL_MODULES.keys():
if isinstance(module, module_type):
return True
return False
def get_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES
if is_colo_module(module):
for module_type, colo_module in _COLOSSAL_MODULES.items():
if isinstance(module, module_type):
return colo_module
else:
return None
def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
if is_colo_module(module):
colo_module = get_colo_module(module)
param_names = colo_module.get_param_names()
compute_pattern = None
for param_name in param_names:
param = module.get_parameter(param_name)
if not isinstance(param, ColoParameter):
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
if param.has_compute_spec():
cur_compute_pattern = param.compute_spec.compute_pattern
if compute_pattern is None:
compute_pattern = cur_compute_pattern
else:
if cur_compute_pattern != compute_pattern:
raise Exception(
f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
else:
continue
if compute_pattern is not None:
colo_module.register(compute_pattern, pg)
if not colo_module.has_compute_pattern(compute_pattern):
raise Exception(
f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
match_specs = False
allowed_specs = colo_module.get_dist_specs(compute_pattern)
for _, param_specs in allowed_specs.items():
cur_match = True
for param_name, dist_spec in param_specs.items():
param = module.get_parameter(param_name)
if param.has_compute_spec():
if dist_spec != param.dist_spec:
cur_match = False
break
else:
if dist_spec is not None:
cur_match = False
break
if cur_match == True:
match_specs = True
break
if match_specs == False:
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
if recursive == True:
for submodule in module.children():
check_colo_module(submodule, pg=pg, recursive=True)
def init_colo_module(module: torch.nn.Module,
compute_spec: ComputeSpec,
pg: ProcessGroup,
recursive=True,
mode='default'):
compute_pattern = compute_spec.compute_pattern
if is_colo_module(module):
# for each param
# set its process_group, dist_spec and compute_spec
colo_module = get_colo_module(module)
colo_module.register(compute_pattern, pg)
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
raise NotImplementedError
# a set for modules which update at least one param in the init process.
# these modules need to be checked whether all params still match one of the valid compute pattern.
modules_update_param = {module}
for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items():
if dist_spec is None:
continue
param = module.get_parameter(param_name)
if isinstance(param, ColoParameter):
param.set_process_group(pg)
param.set_dist_spec(dist_spec)
param.compute_spec = compute_spec
for mod in param.shared_param_modules:
modules_update_param.add(mod)
for mod in modules_update_param:
check_colo_module(mod, pg, recursive=False)
if recursive == True:
for submodule in module.children():
init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
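
A hedged end-to-end sketch of how the helpers defined above fit together (assumes torch.distributed is initialized, that the model's parameters are already ColoParameters, and that ComputeSpec is constructed from a ComputePattern as its usage above suggests):

import torch
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup

# make nn.Linear and nn.Embedding shardable
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = build_model_with_colo_parameters()   # hypothetical helper returning a model with ColoParameters

# shard every registered submodule with the TP1D row layout, then verify the layout
init_colo_module(model, ComputeSpec(ComputePattern.TP1D), pg, recursive=True, mode='row')
check_colo_module(model, pg, recursive=True)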