mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640) * [legacy] refactor logger and clean up legacy codes (#4654) * [legacy] make logger independent to gpc * [legacy] make optim independent to registry * [legacy] move test engine to legacy * [legacy] move nn to legacy (#4656) * [legacy] move nn to legacy * [checkpointio] fix save hf config * [test] remove useledd rpc pp test * [legacy] fix nn init * [example] skip tutorial hybriad parallel example * [devops] test doc check * [devops] test doc check
This commit is contained in:
21
colossalai/legacy/nn/parallel/layers/__init__.py
Normal file
21
colossalai/legacy/nn/parallel/layers/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .cache_embedding import (
|
||||
CachedEmbeddingBag,
|
||||
CachedParamMgr,
|
||||
EvictionStrategy,
|
||||
LimitBuffIndexCopyer,
|
||||
ParallelCachedEmbeddingBag,
|
||||
ParallelCachedEmbeddingBagTablewise,
|
||||
ParallelCachedEmbeddingBagTablewiseSpiltCache,
|
||||
TablewiseEmbeddingBagConfig,
|
||||
)
|
||||
from .colo_module import ColoModule
|
||||
from .embedding import ColoEmbedding
|
||||
from .linear import ColoLinear
|
||||
from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module
|
||||
|
||||
__all__ = [
|
||||
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
||||
'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
|
||||
'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
|
||||
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
|
||||
]
|
@@ -0,0 +1,13 @@
|
||||
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
||||
from .cached_embedding import CachedEmbeddingBag
|
||||
from .copyer import LimitBuffIndexCopyer
|
||||
from .embedding_config import TablewiseEmbeddingBagConfig
|
||||
from .parallel_cached_embedding import ParallelCachedEmbeddingBag
|
||||
from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
|
||||
from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
|
||||
|
||||
__all__ = [
|
||||
'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
|
||||
'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
|
||||
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
|
||||
]
|
@@ -0,0 +1,37 @@
|
||||
import abc
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class BaseEmbeddingBag(abc.ABC, nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
):
|
||||
super(BaseEmbeddingBag, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
elif padding_idx < 0:
|
||||
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
|
||||
# Specific to embedding bag
|
||||
self.mode = mode
|
||||
self.include_last_offset = include_last_offset
|
@@ -0,0 +1,585 @@
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from contexttimer import Timer
|
||||
from torch.profiler import record_function
|
||||
|
||||
from .copyer import LimitBuffIndexCopyer
|
||||
|
||||
|
||||
class EvictionStrategy(Enum):
|
||||
LFU = 1
|
||||
# dataset aware eviction strategy
|
||||
DATASET = 2
|
||||
|
||||
|
||||
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
|
||||
# PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
|
||||
# freed, its memory is likely to be reused by newly constructed tensors. By default,
|
||||
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
|
||||
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
|
||||
# to tell the allocator about all these streams. Otherwise, the allocator might free the
|
||||
# underlying memory of the tensor once it is no longer used by the creator stream. This is
|
||||
# a notable programming trick when we write programs using multi CUDA streams.
|
||||
cur_stream = torch.cuda.current_stream()
|
||||
assert isinstance(t, torch.Tensor)
|
||||
t.record_stream(cur_stream)
|
||||
|
||||
|
||||
class CachedParamMgr(torch.nn.Module):
|
||||
"""
|
||||
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||
CPU maintains the entire original weight.
|
||||
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
|
||||
During training, GPU needs to transmit embedding rows between CPU and GPU.
|
||||
Args:
|
||||
weight (torch.Tensor): the weight of the Embedding layer.
|
||||
cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
|
||||
buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
|
||||
pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
|
||||
evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
|
||||
`EvictionStrategy.LFU`: use the least frequently used cache.
|
||||
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
|
||||
Defaults to EvictionStrategy.DATASET.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
cuda_row_num: int = 0,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = True,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||
async_copy: bool = False,
|
||||
) -> None:
|
||||
super(CachedParamMgr, self).__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.num_embeddings, self.embedding_dim = weight.shape
|
||||
self.cuda_row_num = cuda_row_num
|
||||
self._cuda_available_row_num = self.cuda_row_num
|
||||
self.pin_weight = pin_weight
|
||||
self.elem_size_in_byte = weight.element_size()
|
||||
|
||||
# weight configure
|
||||
self._init_weight(weight)
|
||||
|
||||
# Perf log
|
||||
self.num_hits_history = []
|
||||
self.num_miss_history = []
|
||||
self.num_write_back_history = []
|
||||
|
||||
self._evict_strategy = evict_strategy
|
||||
|
||||
self._async_copy = async_copy
|
||||
|
||||
if self._async_copy:
|
||||
self._memcpy_stream = torch.cuda.Stream()
|
||||
|
||||
print('use async copy')
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# cache_row_idx -> frequency, freq of the cache rows.
|
||||
# classic lfu cache. evict the minimal freq value row in cuda cache.
|
||||
self.register_buffer("freq_cnter",
|
||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(sys.maxsize),
|
||||
persistent=False)
|
||||
self._elapsed_dict = {}
|
||||
self._show_cache_miss = True
|
||||
self._reset_comm_stats()
|
||||
|
||||
def _reset_comm_stats(self):
|
||||
for k in self._elapsed_dict.keys():
|
||||
self._elapsed_dict[k] = 0
|
||||
|
||||
self._cpu_to_cuda_numel = 0
|
||||
self._cuda_to_cpu_numel = 0
|
||||
if self._show_cache_miss:
|
||||
self._cache_miss = 0
|
||||
self._total_cache = 0
|
||||
|
||||
@contextmanager
|
||||
def timer(self, name):
|
||||
with Timer() as t:
|
||||
yield
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if name not in self._elapsed_dict.keys():
|
||||
self._elapsed_dict[name] = 0
|
||||
self._elapsed_dict[name] += t.elapsed
|
||||
|
||||
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
|
||||
"""_find_evict_gpu_idxs
|
||||
Find the gpu idxs to be evicted, according to their freq.
|
||||
Args:
|
||||
evict_num (int): how many rows has to be evicted
|
||||
Returns:
|
||||
torch.Tensor: a list tensor (1D), contains the gpu_row_idxs.
|
||||
"""
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# find the minimal evict_num freq entries in cached_idx_map
|
||||
_, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False)
|
||||
return evict_gpu_row_idxs
|
||||
elif self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# cached_idx_map itself implies the priority of eviction.
|
||||
# The value of self.cached_idx_map represents cpu_row_idx.
|
||||
# The larger it is, the less frequently it will appear in the dataset,
|
||||
# and the higher its eviction priority will be.
|
||||
_, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
|
||||
return evict_gpu_row_idxs
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def _init_weight(self, weight):
|
||||
if self.cuda_row_num > 0:
|
||||
# Enable cache with introducing auxiliary data structures
|
||||
self.cuda_cached_weight = torch.nn.Parameter(
|
||||
torch.zeros(self.cuda_row_num,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=weight.dtype))
|
||||
|
||||
# pin memory cpu for higher CPU-GPU copy bandwidth
|
||||
self.weight = weight.pin_memory() if self.pin_weight else weight
|
||||
# map original id to new id with respect to frequency
|
||||
# id -> cpu_row_idx
|
||||
self.register_buffer(
|
||||
"idx_map",
|
||||
torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
# cached_idx_map: gpu_row_idx -> cpu_row_idx
|
||||
self.register_buffer("cached_idx_map",
|
||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(-1),
|
||||
persistent=False)
|
||||
|
||||
# cpu_row_id -> gpu_row_idx.
|
||||
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
|
||||
self.register_buffer("inverted_cached_idx",
|
||||
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(-1),
|
||||
persistent=False)
|
||||
|
||||
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
|
||||
|
||||
# index copy buffer size should less than 10% of cuda weight.
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
|
||||
|
||||
else:
|
||||
# Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag
|
||||
# self.weight = torch.nn.Parameter(weight)
|
||||
# self.cuda_cached_weight = self.weight
|
||||
raise NotImplementedError()
|
||||
|
||||
def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
|
||||
"""
|
||||
access a row of CPU weight.
|
||||
Args:
|
||||
row_idx (int): the idx of rows
|
||||
Returns:
|
||||
torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
|
||||
"""
|
||||
|
||||
return self.weight.data.view(-1).narrow(0,
|
||||
int(row_idx) * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def cuda_available_row_num(self):
|
||||
return self._cuda_available_row_num
|
||||
|
||||
@torch.no_grad()
|
||||
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
|
||||
"""reorder
|
||||
reorder the weight according to ids' frequency in dataset before training.
|
||||
Execute only once before training, also known as warmup phase.
|
||||
|
||||
Note:
|
||||
If you would like to use the DATASET as the eviction strategy, you must call this function.
|
||||
Note:
|
||||
If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
|
||||
The frequency in LFU cache using the dataset statistics.
|
||||
Args:
|
||||
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
|
||||
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
||||
"""
|
||||
# reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
|
||||
# reorder only works for DATASET eviction strategy.
|
||||
|
||||
if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor):
|
||||
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||
if ids_freq_mapping is not None:
|
||||
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
||||
sorted_idx = torch.argsort(tmp_idx)
|
||||
self.idx_map.data.copy_(sorted_idx)
|
||||
|
||||
# warmup phase: copy #preload_row_num rows from cpu to gpu.
|
||||
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
||||
if preload_row_num > 0:
|
||||
with Timer() as timer:
|
||||
# extract rows from cpu weight
|
||||
if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
|
||||
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
|
||||
preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
|
||||
else:
|
||||
preload_cpu_ids = torch.arange(preload_row_num)
|
||||
preload_cuda_row_idxs = preload_cpu_ids.cuda()
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=preload_cpu_ids,
|
||||
tgt_index=preload_cuda_row_idxs,
|
||||
src=self.weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
|
||||
preload_rows)
|
||||
|
||||
# update auxiliary info
|
||||
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
|
||||
self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
|
||||
self._cuda_available_row_num -= preload_row_num
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
|
||||
if ids_freq_mapping is None:
|
||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
|
||||
else:
|
||||
self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
|
||||
|
||||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||
|
||||
def flush(self):
|
||||
"""flush all CUDA rows to CPU.
|
||||
The function is usually called after training finished.
|
||||
"""
|
||||
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
||||
row_ids = self.cached_idx_map[slots]
|
||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
|
||||
self.cached_idx_map.index_fill_(0, slots, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
||||
self._cuda_available_row_num += slots.numel()
|
||||
|
||||
if self._show_cache_miss:
|
||||
self._cache_miss = 0
|
||||
self._total_cache = 0
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter.fill_(sys.maxsize)
|
||||
assert self._cuda_available_row_num == self.cuda_row_num
|
||||
assert torch.all(self.inverted_cached_idx == -1).item()
|
||||
assert torch.all(self.cached_idx_map == -1).item()
|
||||
|
||||
def print_comm_stats(self):
|
||||
if self._cuda_to_cpu_numel > 0 and "3_evict_out" in self._elapsed_dict:
|
||||
elapsed = self._elapsed_dict["3_evict_out"]
|
||||
print(
|
||||
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
|
||||
)
|
||||
print(f'cuda_to_cpu_elapse {elapsed} sec')
|
||||
if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict:
|
||||
elapsed = self._elapsed_dict["5_evict_in"]
|
||||
print(
|
||||
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
|
||||
)
|
||||
print(f'cpu_to_cuda_elapse {elapsed} sec')
|
||||
|
||||
for k, v in self._elapsed_dict.items():
|
||||
print(f'{k}: {v}')
|
||||
|
||||
print(f'cache miss ratio {self._cache_miss / self._total_cache}')
|
||||
|
||||
@torch.no_grad()
|
||||
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
convert ids to indices in self.cuda_cached_weight.
|
||||
Implemented with parallel operations on GPU.
|
||||
Args:
|
||||
ids (torch.Tensor): ids from the dataset
|
||||
Returns:
|
||||
torch.Tensor: contains indices in self.cuda_cached_weight
|
||||
"""
|
||||
ids = self.idx_map.index_select(0, ids.view(-1))
|
||||
ret = self.inverted_cached_idx.index_select(0, ids)
|
||||
return ret
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
move the cpu embedding rows w.r.t. ids into CUDA memory
|
||||
Args:
|
||||
ids (torch.Tensor): the ids to be computed
|
||||
Returns:
|
||||
torch.Tensor: indices on the cuda_cached_weight.
|
||||
"""
|
||||
torch.cuda.synchronize()
|
||||
with self.timer("cache_op") as gtimer:
|
||||
# identify cpu rows to cache
|
||||
with self.timer("1_identify_cpu_row_idxs") as timer:
|
||||
with record_function("(cache) get unique indices"):
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)
|
||||
else:
|
||||
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
|
||||
|
||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
|
||||
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
|
||||
f"Please increase cuda_row_num or decrease the training batch size."
|
||||
self.evict_backlist = cpu_row_idxs
|
||||
tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)
|
||||
comm_cpu_row_idxs = cpu_row_idxs[tmp]
|
||||
|
||||
if self._show_cache_miss:
|
||||
self._cache_miss += torch.sum(repeat_times[tmp])
|
||||
self._total_cache += ids.numel()
|
||||
|
||||
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
|
||||
self.num_miss_history.append(len(comm_cpu_row_idxs))
|
||||
self.num_write_back_history.append(0)
|
||||
|
||||
# move sure the cuda rows will not be evicted!
|
||||
with record_function("(cache) prepare_rows_on_cuda"):
|
||||
with self.timer("prepare_rows_on_cuda") as timer:
|
||||
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
|
||||
|
||||
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
|
||||
|
||||
with self.timer("6_update_cache") as timer:
|
||||
with record_function("6_update_cache"):
|
||||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||
|
||||
# update for LFU.
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
|
||||
self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
|
||||
|
||||
return gpu_row_idxs
|
||||
|
||||
def _row_in_cuda(self, row_id: int) -> bool:
|
||||
return self.inverted_cached_idx[row_id] != -1
|
||||
|
||||
@torch.no_grad()
|
||||
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
|
||||
"""prepare rows in cpu_row_idxs on CUDA memory
|
||||
Args:
|
||||
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
|
||||
"""
|
||||
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
|
||||
|
||||
cpu_row_idxs_copy = cpu_row_idxs.cpu()
|
||||
|
||||
# move evict in rows to gpu
|
||||
if self._async_copy:
|
||||
if self.buffer_size == 0:
|
||||
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
if evict_num > 0:
|
||||
with self.timer("2_identify_cuda_row_idxs") as timer:
|
||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# mask method.
|
||||
# set cached_idx_map[invalid_idxs] to -2.
|
||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
|
||||
with self.timer("2_1_find_evict_gpu_idxs") as timer:
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
|
||||
# move evict out rows to cpu
|
||||
if self._async_copy:
|
||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||
with self.timer("2_1_backup_freqs") as timer:
|
||||
backup_freqs = self.freq_cnter[invalid_idxs].clone()
|
||||
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
|
||||
|
||||
with self.timer("2_2_find_evict_gpu_idxs") as timer:
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
|
||||
if self._async_copy:
|
||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
|
||||
with self.timer("2_3_revert_freqs") as timer:
|
||||
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
||||
|
||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||
|
||||
with self.timer("3_evict_out") as timer:
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=evict_gpu_row_idxs,
|
||||
tgt_index=evict_info.cpu(),
|
||||
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
|
||||
tgt=self.weight.view(self.num_embeddings, -1))
|
||||
else:
|
||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||
# TODO async gpu -> cpu
|
||||
if self._async_copy:
|
||||
_wait_for_data(evict_out_rows_cpu, None)
|
||||
else:
|
||||
with self.timer("3_1_evict_out_index_select") as timer:
|
||||
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer:
|
||||
evict_out_rows_cpu = evict_out_rows_cpu.cpu()
|
||||
|
||||
with self.timer("3_2_evict_out_cpu_copy") as timer:
|
||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
|
||||
|
||||
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
|
||||
# self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
|
||||
self._cuda_available_row_num += evict_num
|
||||
|
||||
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
|
||||
self._cuda_to_cpu_numel += weight_size
|
||||
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
|
||||
|
||||
# slots of cuda weight to evict in
|
||||
with self.timer("4_identify_cuda_slot") as timer:
|
||||
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
|
||||
|
||||
# TODO wait for optimize
|
||||
with self.timer("5_evict_in") as timer:
|
||||
# Here also allocate extra memory on CUDA. #cpu_row_idxs
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=cpu_row_idxs_copy,
|
||||
tgt_index=slots,
|
||||
src=self.weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
if self._async_copy:
|
||||
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
|
||||
else:
|
||||
with self.timer("5_1_evict_in_index_select") as timer:
|
||||
# narrow index select to a subset of self.weight
|
||||
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
|
||||
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
|
||||
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
|
||||
with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer:
|
||||
evict_in_rows_gpu = evict_in_rows_gpu.cuda()
|
||||
|
||||
with self.timer("5_3_evict_in_index_copy") as timer:
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
|
||||
|
||||
with self.timer("6_update_cache") as timer:
|
||||
self.cached_idx_map[slots] = cpu_row_idxs
|
||||
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter.index_fill_(0, slots, 0)
|
||||
self._cuda_available_row_num -= cpu_row_idxs.numel()
|
||||
|
||||
weight_size = cpu_row_idxs.numel() * self.embedding_dim
|
||||
self._cpu_to_cuda_numel += weight_size
|
||||
# print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
|
||||
|
||||
def _find_free_cuda_row(self) -> int:
|
||||
if self._cuda_available_row_num == 0:
|
||||
return -1
|
||||
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
|
||||
return candidates[0].item()
|
||||
|
||||
def _evict(self) -> int:
|
||||
"""
|
||||
deprecated
|
||||
evict one row from cuda to cpu.
|
||||
Returns:
|
||||
(int) : the slot id be evicted.
|
||||
"""
|
||||
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
|
||||
buf = self.cached_idx_map[mask].clone()
|
||||
idx = torch.nonzero(mask).squeeze(1)
|
||||
self.cached_idx_map.index_fill_(0, idx, -1)
|
||||
max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0)
|
||||
max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx]
|
||||
|
||||
if max_gpu_row_idx == -1:
|
||||
raise RuntimeError("Can not evict a row")
|
||||
|
||||
max_gpu_row_idx = max_gpu_row_idx.item()
|
||||
max_offset = self.inverted_cached_idx[max_gpu_row_idx]
|
||||
# recover
|
||||
self.cached_idx_map.index_copy_(0, idx, buf)
|
||||
|
||||
with Timer() as timer:
|
||||
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)
|
||||
|
||||
# update inverted_cached_idx, min_slot_id is evicted from cuda
|
||||
self.cached_idx_map[max_cpu_row_idx] = -1
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter[max_cpu_row_idx] = sys.maxsize
|
||||
self.inverted_cached_idx[max_gpu_row_idx] = -1
|
||||
|
||||
self._cuda_available_row_num += 1
|
||||
|
||||
self._cuda_to_cpu_numel += self.embedding_dim
|
||||
# self.num_write_back_history[-1] += 1
|
||||
return max_cpu_row_idx
|
||||
|
||||
@torch.no_grad()
|
||||
def _admit(self, row_id: int):
|
||||
"""
|
||||
deprecated
|
||||
move in row_id to CUDA
|
||||
Args:
|
||||
row_id (int): the id of row to be moved in
|
||||
"""
|
||||
# find a free slot in partial cuda weight
|
||||
slot_id = self._find_free_cuda_row()
|
||||
|
||||
if slot_id == -1:
|
||||
# evict one row
|
||||
slot_id = self._evict()
|
||||
slot_offset = slot_id
|
||||
# copy payload from cpu to cuda
|
||||
with Timer() as timer:
|
||||
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
cuda_tensor.data.copy_(self.cpu_weight_data(row_id))
|
||||
|
||||
# update the inverted_cached_idx
|
||||
self.cached_idx_map[slot_id] = row_id
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter[slot_id] = 0
|
||||
self.inverted_cached_idx[row_id] = slot_offset
|
||||
|
||||
self._cuda_available_row_num -= 1
|
||||
|
||||
self._cpu_to_cuda_numel += self.embedding_dim
|
@@ -0,0 +1,158 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .base_embedding import BaseEmbeddingBag
|
||||
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
||||
|
||||
|
||||
class CachedEmbeddingBag(BaseEmbeddingBag):
|
||||
"""CachedEmbeddingBag
|
||||
|
||||
Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
|
||||
It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`.
|
||||
You can also apply a naive LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm
|
||||
norm_type (str, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.
|
||||
scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False.
|
||||
sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".. Defaults to False.
|
||||
_weight (torch.Tensor, optional): an embedding weight tensor. Concatenate multiple tables in a embedding bag as a single one. Defaults to None.
|
||||
mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'.
|
||||
include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False.
|
||||
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
|
||||
device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
|
||||
cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
|
||||
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occurs in dataset. Defaults to None.
|
||||
warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
|
||||
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
|
||||
pin_weight (bool, optional): pin the cpu weight. Defaults to False.
|
||||
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
max_norm: float = None,
|
||||
norm_type: float = 2.,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
_weight: Optional[torch.Tensor] = None,
|
||||
mode: str = 'mean',
|
||||
include_last_offset: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
cache_ratio: float = 0.01,
|
||||
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
|
||||
warmup_ratio: float = 0.7,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
|
||||
super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
|
||||
scale_grad_by_freq, sparse, mode, include_last_offset)
|
||||
|
||||
assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
|
||||
self.evict_strategy = evict_strategy
|
||||
if _weight is None:
|
||||
_weight = self._weight_alloc(dtype, device)
|
||||
cuda_row_num = int(num_embeddings * cache_ratio)
|
||||
# configure weight & cache
|
||||
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
||||
self.cache_op = True
|
||||
|
||||
def set_cache_mgr_async_copy(self, flag):
|
||||
self.cache_weight_mgr._async_copy = flag
|
||||
|
||||
def _weight_alloc(self, dtype, device):
|
||||
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
||||
with torch.no_grad():
|
||||
weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
|
||||
if self.padding_idx is not None:
|
||||
weight[self.padding_idx].fill_(0)
|
||||
return weight
|
||||
|
||||
def _preprocess(self,
|
||||
weight,
|
||||
cuda_row_num: int,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio=0.7,
|
||||
buffer_size=50_000,
|
||||
pin_weight=False):
|
||||
"""
|
||||
Called after initialized.
|
||||
Reorder the weight rows according to the ids_freq_mapping.
|
||||
Then, let the weights of the Module be managed by a CachedParamMgr.
|
||||
|
||||
Args:
|
||||
cuda_row_num (int): number of rows can be hosted in CUDA memory
|
||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
|
||||
warmup_ratio (float): the amount of rows preloaded in cuda cache
|
||||
"""
|
||||
self.cache_weight_mgr = CachedParamMgr(weight,
|
||||
cuda_row_num,
|
||||
buffer_size,
|
||||
pin_weight,
|
||||
evict_strategy=self.evict_strategy)
|
||||
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
||||
|
||||
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
|
||||
if self.cache_op:
|
||||
with torch.no_grad():
|
||||
input = self.cache_weight_mgr.prepare_ids(input)
|
||||
|
||||
embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
if shape_hook is not None:
|
||||
embeddings = shape_hook(embeddings)
|
||||
return embeddings
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.cache_weight_mgr.weight
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
|
||||
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
yield self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
def set_cache_op(self, cache_op: bool = True):
|
||||
self.cache_op = cache_op
|
||||
|
||||
|
||||
############################# Perf Log ###################################
|
||||
|
||||
@property
|
||||
def num_hits_history(self):
|
||||
return self.cache_weight_mgr.num_hits_history
|
||||
|
||||
@property
|
||||
def num_miss_history(self):
|
||||
return self.cache_weight_mgr.num_miss_history
|
||||
|
||||
@property
|
||||
def num_write_back_history(self):
|
||||
return self.cache_weight_mgr.num_write_back_history
|
||||
|
||||
@property
|
||||
def swap_in_bandwidth(self):
|
||||
if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
|
||||
return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
|
||||
self.cache_weight_mgr._cpu_to_cuda_elapse
|
||||
else:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def swap_out_bandwidth(self):
|
||||
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
|
||||
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
|
||||
self.cache_weight_mgr._cuda_to_cpu_elapse
|
||||
return 0
|
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
from torch import LongTensor
|
||||
|
||||
|
||||
class LimitBuffIndexCopyer(object):
|
||||
"""LimitBuffIndexCopyer
|
||||
Index Copy using limited temp buffer on CUDA.
|
||||
|
||||
Args:
|
||||
size (int): buffer size
|
||||
"""
|
||||
|
||||
def __init__(self, size: int) -> None:
|
||||
self._buff_size = size
|
||||
|
||||
@torch.no_grad()
|
||||
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
|
||||
"""copy
|
||||
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
|
||||
The valid rows in the src tensor are continuous, while rows in tgt tensor is scattered.
|
||||
|
||||
Args:
|
||||
dim (int): dimension along which to index
|
||||
src_index (int): indices of src tensor to select from
|
||||
tgt_index (int): indices of tgt tensor to select from
|
||||
src (torch.Tensor): the tensor containing values to copy
|
||||
tgt (torch.Tensor): the tensor to be copied
|
||||
"""
|
||||
# tgt.index_copy_(dim, index, src)
|
||||
assert dim == 0, "only support index_copy on dim 0"
|
||||
assert tgt.dim() == 2
|
||||
assert src.dim() == 2
|
||||
tgt_device = tgt.device
|
||||
src_device = src.device
|
||||
|
||||
assert src_index.numel() == tgt_index.numel()
|
||||
dim_size = src_index.numel()
|
||||
src_index = src_index.to(src_device)
|
||||
for begin_pos in range(0, dim_size, self._buff_size):
|
||||
cur_len = min(self._buff_size, dim_size - begin_pos)
|
||||
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
|
||||
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
|
||||
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
|
||||
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
|
||||
tmp_buffer.copy_(cpu_tmp_buffer)
|
||||
else:
|
||||
tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
|
||||
tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
|
||||
tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
|
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TablewiseEmbeddingBagConfig:
|
||||
'''
|
||||
example:
|
||||
def prepare_tablewise_config(args, cache_ratio, ...):
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
|
||||
...
|
||||
return embedding_bag_config_list
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
cuda_row_num: int,
|
||||
assigned_rank: int = 0,
|
||||
buffer_size=50_000,
|
||||
ids_freq_mapping=None,
|
||||
initial_weight: torch.tensor = None,
|
||||
name: str = ""):
|
||||
self.num_embeddings = num_embeddings
|
||||
self.cuda_row_num = cuda_row_num
|
||||
self.assigned_rank = assigned_rank
|
||||
self.buffer_size = buffer_size
|
||||
self.ids_freq_mapping = ids_freq_mapping
|
||||
self.initial_weight = initial_weight
|
||||
self.name = name
|
@@ -0,0 +1,142 @@
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.legacy.nn._ops._utils import dual_all_to_all
|
||||
from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
|
||||
|
||||
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
||||
from .cached_embedding import CachedEmbeddingBag
|
||||
|
||||
|
||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||
if world_size == 1:
|
||||
return 0, embedding_dim, True
|
||||
|
||||
assert embedding_dim >= world_size, \
|
||||
f"Embedding dimension {embedding_dim} must be larger than the world size " \
|
||||
f"{world_size} of the process group"
|
||||
chunk_size = embedding_dim // world_size
|
||||
threshold = embedding_dim % world_size
|
||||
# if embedding dim is divisible by world size
|
||||
if threshold == 0:
|
||||
return rank * chunk_size, (rank + 1) * chunk_size, True
|
||||
|
||||
# align with the split strategy of torch.tensor_split
|
||||
size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
|
||||
offset = sum(size_list[:rank])
|
||||
return offset, offset + size_list[rank], False
|
||||
|
||||
|
||||
class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
_weight=None,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
cache_ratio=0.01,
|
||||
ids_freq_mapping=None,
|
||||
warmup_ratio=0.7,
|
||||
buffer_size=50_000,
|
||||
pin_weight=False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.world_size = torch.distributed.get_world_size()
|
||||
|
||||
self.partition_start_index, self.partition_end_index, divisible = get_partition(
|
||||
embedding_dim, self.rank, self.world_size)
|
||||
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
|
||||
|
||||
super(ParallelCachedEmbeddingBag,
|
||||
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
|
||||
warmup_ratio, buffer_size, pin_weight, evict_strategy)
|
||||
self.cache_op = True
|
||||
|
||||
def _weight_alloc(self, dtype, device):
|
||||
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
||||
with torch.no_grad():
|
||||
weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
|
||||
if self.padding_idx is not None:
|
||||
weight[self.padding_idx].fill_(0)
|
||||
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
|
||||
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
|
||||
compute_attr=ComputePattern.TP1D)
|
||||
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
indices,
|
||||
offsets=None,
|
||||
per_sample_weights=None,
|
||||
shape_hook=None,
|
||||
scatter_dim=0,
|
||||
gather_dim=-1,
|
||||
):
|
||||
if self.cache_op:
|
||||
with torch.no_grad():
|
||||
indices = self.cache_weight_mgr.prepare_ids(indices)
|
||||
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
if shape_hook is not None:
|
||||
output_shard = shape_hook(output_shard)
|
||||
output_full = dual_all_to_all(output_shard,
|
||||
self.weight.get_process_group(),
|
||||
scatter_dim=scatter_dim,
|
||||
gather_dim=gather_dim)
|
||||
return output_full
|
||||
|
||||
def set_cache_op(self, cache_op: bool = True):
|
||||
self.cache_op = cache_op
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
embedding: torch.Tensor,
|
||||
freeze: bool = True,
|
||||
padding_idx: Optional[int] = None,
|
||||
max_norm: Optional[float] = None,
|
||||
norm_type: float = 2.,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
mode: str = 'mean',
|
||||
include_last_offset: bool = False,
|
||||
cuda_row_num: int = 100_000,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio: float = 0.7,
|
||||
buffer_size: int = 0,
|
||||
) -> 'ParallelCachedEmbeddingBag':
|
||||
rows, cols = embedding.shape
|
||||
embedding_bag = cls(rows,
|
||||
cols,
|
||||
padding_idx,
|
||||
max_norm,
|
||||
norm_type,
|
||||
scale_grad_by_freq,
|
||||
sparse,
|
||||
embedding,
|
||||
mode,
|
||||
include_last_offset,
|
||||
cuda_row_num=cuda_row_num,
|
||||
ids_freq_mapping=ids_freq_mapping,
|
||||
warmup_ratio=warmup_ratio,
|
||||
buffer_size=buffer_size)
|
||||
embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
|
||||
return embedding_bag
|
||||
|
||||
def print_comm_stats_(self):
|
||||
self.cache_weight_mgr.print_comm_stats()
|
||||
|
||||
def element_size(self):
|
||||
return self.weight.element_size()
|
@@ -0,0 +1,199 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
|
||||
from colossalai.tensor import ProcessGroup
|
||||
|
||||
from .cache_mgr import EvictionStrategy
|
||||
from .cached_embedding import CachedEmbeddingBag
|
||||
from .embedding_config import TablewiseEmbeddingBagConfig
|
||||
|
||||
|
||||
class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
|
||||
"""
|
||||
all tables assigned to this class instance are managed by a single CachedEmbeddingBag.
|
||||
Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
|
||||
embedding_dim: int,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
_weight=None,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
cache_ratio=0.01,
|
||||
warmup_ratio=0.7,
|
||||
buffer_size=50_000,
|
||||
pin_weight=False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
|
||||
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
|
||||
self.global_tables_num = len(embedding_bag_config_list)
|
||||
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
|
||||
self.assigned_table_list: List[int] = []
|
||||
self.pg = ProcessGroup(tp_degree=self.world_size)
|
||||
self.num_embeddings = 0
|
||||
for i, rank in enumerate(self.rank_of_tables):
|
||||
if rank == self.rank:
|
||||
self.assigned_table_list.append(i)
|
||||
self.num_embeddings += self.global_table_num_embeddings_list[i]
|
||||
self.include_last_offset = include_last_offset
|
||||
|
||||
ids_freq_mapping = []
|
||||
for config in embedding_bag_config_list:
|
||||
if config.assigned_rank == self.rank:
|
||||
if config.ids_freq_mapping != None:
|
||||
ids_freq_mapping.extend(config.ids_freq_mapping)
|
||||
else:
|
||||
ids_freq_mapping = None
|
||||
break
|
||||
self.cache_ratio = cache_ratio
|
||||
# table-associate cache
|
||||
cuda_row_num = int(cache_ratio * self.num_embeddings)
|
||||
super(ParallelCachedEmbeddingBagTablewise,
|
||||
self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
|
||||
warmup_ratio, buffer_size, pin_weight, evict_strategy)
|
||||
|
||||
# for assigned tables reconnection:
|
||||
self.idx_offset_list = []
|
||||
offset_cumsum = 0
|
||||
for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
|
||||
if self.rank_of_tables[table_i] == self.rank:
|
||||
self.idx_offset_list.append(offset_cumsum)
|
||||
else:
|
||||
offset_cumsum += table_num_embeddings
|
||||
|
||||
# prepare list shape for all_to_all output
|
||||
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
|
||||
for rank in self.rank_of_tables:
|
||||
self.embedding_dim_per_rank[rank] += embedding_dim
|
||||
|
||||
self.cache_op = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
indices: torch.Tensor,
|
||||
offsets: torch.Tensor = None,
|
||||
per_sample_weights=None,
|
||||
shape_hook=None,
|
||||
already_split_along_rank=True,
|
||||
):
|
||||
if not already_split_along_rank:
|
||||
# not recommanded. it takes time.
|
||||
batch_size = (offsets.shape[0]) // self.global_tables_num
|
||||
local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
|
||||
batch_size, indices, offsets, per_sample_weights)
|
||||
else:
|
||||
# recommanded.
|
||||
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
|
||||
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
|
||||
if self.cache_op:
|
||||
with torch.no_grad():
|
||||
indices = self.cache_weight_mgr.prepare_ids(local_indices)
|
||||
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
||||
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
local_per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
local_output = torch.cat(local_output.split(batch_size), 1)
|
||||
remains = batch_size % self.world_size
|
||||
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
||||
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
||||
if shape_hook is not None:
|
||||
output_full = shape_hook(output_full)
|
||||
return output_full
|
||||
|
||||
def split_along_rank(self,
|
||||
batch_size,
|
||||
indices: torch.Tensor,
|
||||
offsets: torch.Tensor = None,
|
||||
per_sample_weights=None):
|
||||
'''
|
||||
if input indices and offsets haven't been splitted along assigned rank, this function will do it.
|
||||
it takes time. please consider splitting data during batch loading.
|
||||
'''
|
||||
local_indices_list: List(torch.Tensor) = []
|
||||
local_offsets_list: List(torch.Tensor) = []
|
||||
if per_sample_weights != None:
|
||||
local_per_sample_weights_list: List(torch.Tensor) = []
|
||||
|
||||
offset_pre_end = 0 # local_offsets trick
|
||||
for i, handle_table in enumerate(self.assigned_table_list):
|
||||
indices_start_position = offsets[batch_size * handle_table]
|
||||
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
|
||||
# till-the-end special case
|
||||
indices_end_position = indices.shape[0]
|
||||
else:
|
||||
indices_end_position = offsets[batch_size * (handle_table + 1)]
|
||||
# alternative approach: reduce malloc
|
||||
'''
|
||||
# 1. local_indices_list:
|
||||
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
|
||||
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
|
||||
local_indices_list.append(local_indices)
|
||||
# 2. local_offsets_list:
|
||||
if i + 1 == len(self.assigned_table_list):
|
||||
# till-the-end special case
|
||||
if not self.include_last_offset:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
|
||||
else:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
|
||||
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
|
||||
local_offsets_list.append(local_offsets)
|
||||
else:
|
||||
temp_holder = offsets[batch_size * handle_table].item()
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
|
||||
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
|
||||
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
|
||||
local_offsets_list.append(local_offsets)
|
||||
'''
|
||||
# 1. local_indices_list:
|
||||
local_indices_list.append(
|
||||
indices.narrow(0, indices_start_position,
|
||||
indices_end_position - indices_start_position).sub(self.idx_offset_list[i]))
|
||||
# 2. local_offsets_list:
|
||||
if i + 1 == len(self.assigned_table_list):
|
||||
# till-the-end special case
|
||||
if not self.include_last_offset:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table,
|
||||
batch_size).add(offset_pre_end - offsets[batch_size *
|
||||
(handle_table)])
|
||||
else:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
|
||||
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
local_offsets_list.append(local_offsets)
|
||||
else:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
|
||||
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
offset_pre_end = local_offsets[-1]
|
||||
local_offsets_list.append(local_offsets[:-1])
|
||||
# 3. local_per_sample_weights_list:
|
||||
if per_sample_weights != None:
|
||||
local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
|
||||
local_indices = torch.cat(local_indices_list, 0)
|
||||
local_offsets = torch.cat(local_offsets_list, 0)
|
||||
local_per_sample_weights = None
|
||||
if per_sample_weights != None:
|
||||
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
||||
return local_indices, local_offsets, local_per_sample_weights
|
||||
|
||||
def set_cache_op(self, cache_op: bool = True):
|
||||
self.cache_op = cache_op
|
||||
|
||||
def print_comm_stats_(self):
|
||||
self.cache_weight_mgr.print_comm_stats()
|
||||
|
||||
def element_size(self):
|
||||
return self.weight.element_size()
|
@@ -0,0 +1,138 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.profiler import record_function
|
||||
|
||||
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
|
||||
from colossalai.tensor import ProcessGroup
|
||||
|
||||
from .cache_mgr import EvictionStrategy
|
||||
from .cached_embedding import CachedEmbeddingBag
|
||||
from .embedding_config import TablewiseEmbeddingBagConfig
|
||||
|
||||
|
||||
class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
|
||||
"""
|
||||
every table assigned to this class instance is managed by a CachedEmbeddingBag.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
|
||||
embedding_dim: int,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
warmup_ratio=0.7,
|
||||
pin_weight=False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
|
||||
super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
|
||||
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
|
||||
self.global_tables_num = len(embedding_bag_config_list)
|
||||
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
|
||||
|
||||
self.assigned_table_list: List[int] = []
|
||||
for i, rank in enumerate(self.rank_of_tables):
|
||||
if rank == self.rank:
|
||||
self.assigned_table_list.append(i)
|
||||
self.include_last_offset = include_last_offset
|
||||
self.pg = ProcessGroup(tp_degree=self.world_size)
|
||||
|
||||
# prepare CachedEmbeddingBag list
|
||||
|
||||
self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
|
||||
for config in embedding_bag_config_list:
|
||||
if config.assigned_rank != self.rank:
|
||||
continue
|
||||
self.cached_embedding_bag_list.append(
|
||||
CachedEmbeddingBag(num_embeddings=config.num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
max_norm=max_norm,
|
||||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse,
|
||||
_weight=config.initial_weight,
|
||||
mode=mode,
|
||||
include_last_offset=include_last_offset,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cuda_row_num=config.cuda_row_num,
|
||||
ids_freq_mapping=config.ids_freq_mapping,
|
||||
warmup_ratio=warmup_ratio,
|
||||
buffer_size=config.buffer_size,
|
||||
pin_weight=pin_weight,
|
||||
evict_strategy=evict_strategy))
|
||||
|
||||
# prepare list shape for all_to_all output
|
||||
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
|
||||
for rank in self.rank_of_tables:
|
||||
self.embedding_dim_per_rank[rank] += embedding_dim
|
||||
|
||||
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
|
||||
# determine indices to handle
|
||||
batch_size = (offsets.shape[0]) // self.global_tables_num
|
||||
local_output_list = []
|
||||
for i, handle_table in enumerate(self.assigned_table_list):
|
||||
with record_function("(tablewise) prepare indices and offsets"):
|
||||
with record_function("part 1"):
|
||||
indices_start_position = offsets[batch_size * handle_table]
|
||||
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
|
||||
# till the end special case
|
||||
indices_end_position = indices.shape[0]
|
||||
else:
|
||||
indices_end_position = offsets[batch_size * (handle_table + 1)]
|
||||
with record_function("part 2"):
|
||||
# local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
|
||||
local_indices = indices.narrow(0, indices_start_position, indices_end_position -
|
||||
indices_start_position).sub(self.global_tables_offsets[handle_table])
|
||||
if self.include_last_offset:
|
||||
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table,
|
||||
batch_size + 1).sub(offsets[batch_size * (handle_table)])
|
||||
else:
|
||||
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table,
|
||||
batch_size).sub(offsets[batch_size * (handle_table)])
|
||||
local_per_sample_weights = None
|
||||
if per_sample_weights != None:
|
||||
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
|
||||
with record_function("(tablewise) tablewise forward"):
|
||||
local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
|
||||
local_per_sample_weights))
|
||||
|
||||
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
|
||||
local_output = torch.cat(local_output_list, 1)
|
||||
# then concatenate those local_output on the second dimension.
|
||||
# use all_to_all
|
||||
remains = batch_size % self.world_size
|
||||
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
||||
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
||||
if shape_hook is not None:
|
||||
output_full = shape_hook(output_full)
|
||||
return output_full
|
||||
|
||||
def element_size(self):
|
||||
if len(self.assigned_table_list) == 0:
|
||||
return 0
|
||||
return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
|
||||
|
||||
def print_comm_stats_(self):
|
||||
cuda_to_cpu_elem_num = 0
|
||||
cpu_to_cuda_elem_num = 0
|
||||
for cached_embedding_bag in self.cached_embedding_bag_list:
|
||||
cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
|
||||
cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
|
||||
print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
|
||||
print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
|
47
colossalai/legacy/nn/parallel/layers/colo_module.py
Normal file
47
colossalai/legacy/nn/parallel/layers/colo_module.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from colossalai.tensor import ComputePattern
|
||||
from colossalai.tensor.distspec import _DistSpec
|
||||
|
||||
|
||||
class ColoModule(object):
|
||||
|
||||
def __init__(self):
|
||||
self._shard_params: List[str] = []
|
||||
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
|
||||
|
||||
def _register_shard_params(self, params: List[str]):
|
||||
self._shard_params = params
|
||||
|
||||
def _register_allowed_patterns(self,
|
||||
compute_pattern: ComputePattern,
|
||||
dist_specs: Dict[str, _DistSpec],
|
||||
mode='default'):
|
||||
assert list(
|
||||
dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
self._allowed_patterns[compute_pattern] = {}
|
||||
self._allowed_patterns[compute_pattern][mode] = dist_specs
|
||||
|
||||
def _set_default(self, compute_pattern: ComputePattern, target_mode):
|
||||
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
|
||||
|
||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
return compute_pattern in self._allowed_patterns
|
||||
|
||||
def get_dist_specs(self, compute_pattern: ComputePattern):
|
||||
assert self.has_compute_pattern(compute_pattern)
|
||||
return self._allowed_patterns[compute_pattern]
|
||||
|
||||
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
|
||||
|
||||
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
|
||||
return self._allowed_patterns[compute_pattern][mode]
|
||||
|
||||
def get_param_names(self):
|
||||
return self._shard_params
|
||||
|
||||
def register(self, compute_pattern, pg):
|
||||
raise NotImplementedError
|
37
colossalai/legacy/nn/parallel/layers/embedding.py
Normal file
37
colossalai/legacy/nn/parallel/layers/embedding.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
|
||||
|
||||
from .colo_module import ColoModule
|
||||
|
||||
|
||||
class ColoEmbedding(ColoModule):
|
||||
|
||||
def __init__(self):
|
||||
super(ColoEmbedding, self).__init__()
|
||||
self._register_shard_params(['weight'])
|
||||
|
||||
def register(self, compute_pattern, pg: ProcessGroup):
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
if ComputePattern.TP1D == compute_pattern:
|
||||
self._set_TP1D(pg)
|
||||
|
||||
def _set_TP1D(self, pg: ProcessGroup):
|
||||
# TP1D Row Linear
|
||||
_compute_pattern = ComputePattern.TP1D
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='row',
|
||||
)
|
||||
|
||||
# TP1D Col Linear
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
||||
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
|
39
colossalai/legacy/nn/parallel/layers/linear.py
Normal file
39
colossalai/legacy/nn/parallel/layers/linear.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
|
||||
|
||||
from .colo_module import ColoModule
|
||||
|
||||
|
||||
class ColoLinear(ColoModule):
|
||||
|
||||
def __init__(self):
|
||||
super(ColoLinear, self).__init__()
|
||||
self._register_shard_params(['weight', 'bias'])
|
||||
|
||||
def register(self, compute_pattern, pg: ProcessGroup):
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
if ComputePattern.TP1D == compute_pattern:
|
||||
self._set_TP1D(pg)
|
||||
|
||||
def _set_TP1D(self, pg):
|
||||
# TP1D Row Linear
|
||||
_compute_pattern = ComputePattern.TP1D
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
'bias': None
|
||||
},
|
||||
mode='row',
|
||||
)
|
||||
|
||||
# TP1D Col Linear
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
'bias': ShardSpec([0], [pg.tp_world_size()])
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
||||
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
|
115
colossalai/legacy/nn/parallel/layers/module_utils.py
Normal file
115
colossalai/legacy/nn/parallel/layers/module_utils.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
|
||||
|
||||
from . import ColoModule
|
||||
|
||||
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
|
||||
|
||||
|
||||
def register_colo_module(module_type: type, colo_module: ColoModule):
|
||||
global _COLOSSAL_MODULES
|
||||
_COLOSSAL_MODULES[module_type] = colo_module
|
||||
|
||||
|
||||
def is_colo_module(module: torch.nn.Module):
|
||||
global _COLOSSAL_MODULES
|
||||
for module_type in _COLOSSAL_MODULES.keys():
|
||||
if isinstance(module, module_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_colo_module(module: torch.nn.Module):
|
||||
global _COLOSSAL_MODULES
|
||||
if is_colo_module(module):
|
||||
for module_type, colo_module in _COLOSSAL_MODULES.items():
|
||||
if isinstance(module, module_type):
|
||||
return colo_module
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
|
||||
if is_colo_module(module):
|
||||
colo_module = get_colo_module(module)
|
||||
param_names = colo_module.get_param_names()
|
||||
compute_pattern = None
|
||||
for param_name in param_names:
|
||||
param = module.get_parameter(param_name)
|
||||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_compute_spec():
|
||||
cur_compute_pattern = param.compute_spec.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
if cur_compute_pattern != compute_pattern:
|
||||
raise Exception(
|
||||
f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
|
||||
else:
|
||||
continue
|
||||
|
||||
if compute_pattern is not None:
|
||||
colo_module.register(compute_pattern, pg)
|
||||
if not colo_module.has_compute_pattern(compute_pattern):
|
||||
raise Exception(
|
||||
f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
|
||||
|
||||
match_specs = False
|
||||
allowed_specs = colo_module.get_dist_specs(compute_pattern)
|
||||
for _, param_specs in allowed_specs.items():
|
||||
cur_match = True
|
||||
for param_name, dist_spec in param_specs.items():
|
||||
param = module.get_parameter(param_name)
|
||||
if param.has_compute_spec():
|
||||
if dist_spec != param.dist_spec:
|
||||
cur_match = False
|
||||
break
|
||||
else:
|
||||
if dist_spec is not None:
|
||||
cur_match = False
|
||||
break
|
||||
if cur_match == True:
|
||||
match_specs = True
|
||||
break
|
||||
if match_specs == False:
|
||||
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
|
||||
if recursive == True:
|
||||
for submodule in module.children():
|
||||
check_colo_module(submodule, pg=pg, recursive=True)
|
||||
|
||||
|
||||
def init_colo_module(module: torch.nn.Module,
|
||||
compute_spec: ComputeSpec,
|
||||
pg: ProcessGroup,
|
||||
recursive=True,
|
||||
mode='default'):
|
||||
compute_pattern = compute_spec.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set its process_group, dist_spec and compute_spec
|
||||
colo_module = get_colo_module(module)
|
||||
colo_module.register(compute_pattern, pg)
|
||||
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||
raise NotImplementedError
|
||||
# a set for modules which update at least one param in the init process.
|
||||
# these modules need to be checked whether all params still match one of the valid compute pattern.
|
||||
modules_update_param = {module}
|
||||
for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items():
|
||||
if dist_spec is None:
|
||||
continue
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
param.set_process_group(pg)
|
||||
param.set_dist_spec(dist_spec)
|
||||
param.compute_spec = compute_spec
|
||||
for mod in param.shared_param_modules:
|
||||
modules_update_param.add(mod)
|
||||
for mod in modules_update_param:
|
||||
check_colo_module(mod, pg, recursive=False)
|
||||
if recursive == True:
|
||||
for submodule in module.children():
|
||||
init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
|
Reference in New Issue
Block a user