mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[FAW] export FAW in _ops (#1438)
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
from .cache_mgr import CachedParamMgr
|
||||
from .copyer import LimitBuffIndexCopyer
|
||||
from .freq_aware_embedding import FreqAwareEmbeddingBag
|
||||
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
|
||||
|
||||
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
|
@@ -0,0 +1,36 @@
|
||||
import abc
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class BaseEmbeddingBag(abc.ABC, nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
):
|
||||
super(BaseEmbeddingBag, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
elif padding_idx < 0:
|
||||
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
|
||||
# Specific to embedding bag
|
||||
self.mode = mode
|
||||
self.include_last_offset = include_last_offset
|
348
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
Normal file
348
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.profiler import record_function
|
||||
from typing import List, Optional
|
||||
from contexttimer import Timer
|
||||
from .copyer import LimitBuffIndexCopyer
|
||||
|
||||
|
||||
class CachedParamMgr(torch.nn.Module):
|
||||
"""
|
||||
Manage Embedding Weights in Cache on CPU and CUDA memory.
|
||||
CPU maintains entire original weight.
|
||||
CUDA maintains a fraction of weights used in the upcomming computation.
|
||||
During training, GPU needs to transmit rows between CPU and GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
|
||||
super(CachedParamMgr, self).__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.num_embeddings, self.embedding_dim = weight.shape
|
||||
self.cuda_row_num = cuda_row_num
|
||||
self._cuda_available_row_num = self.cuda_row_num
|
||||
|
||||
self.elem_size_in_byte = weight.element_size()
|
||||
|
||||
self.cuda_cached_weight = torch.nn.Parameter(
|
||||
torch.zeros(self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype))
|
||||
|
||||
if weight.device.type == 'cuda':
|
||||
weight = weight.cpu()
|
||||
|
||||
# pin memory cpu for higher CPU-GPU copy bandwidth
|
||||
self.cpu_weight = weight.contiguous().pin_memory()
|
||||
|
||||
# map original id to new id with respect to frequency
|
||||
# id -> cpu_row_idx
|
||||
self.register_buffer(
|
||||
"idx_map",
|
||||
torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
# cached_idx_map: gpu_row_idx -> cpu_row_idx
|
||||
self.register_buffer("cached_idx_map",
|
||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(-1),
|
||||
persistent=False)
|
||||
|
||||
# cpu_row_id -> gpu_row_idx.
|
||||
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
|
||||
self.register_buffer("inverted_cached_idx",
|
||||
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(-1),
|
||||
persistent=False)
|
||||
|
||||
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
|
||||
|
||||
# index copy buffer size should less than 10% of cuda weight.
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
|
||||
|
||||
self.num_hits_history = []
|
||||
self.num_miss_history = []
|
||||
self.num_write_back_history = []
|
||||
self.input_id_percent_in_load_chunk = []
|
||||
self._reset_comm_stats()
|
||||
|
||||
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
|
||||
"""
|
||||
access a chunk of CPU weight.
|
||||
|
||||
Args:
|
||||
chunk_id (int): chunk id
|
||||
|
||||
Returns:
|
||||
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
|
||||
"""
|
||||
|
||||
return self.cpu_weight.data.view(-1).narrow(0,
|
||||
int(chunk_id) * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def cuda_available_chunk_num(self):
|
||||
return self._cuda_available_row_num
|
||||
|
||||
@torch.no_grad()
|
||||
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
|
||||
"""reorder the cpu_weight according to ids' frequency in dataset before training.
|
||||
Also Build the IndexMappingTable, aka index_mapping_table.
|
||||
Execute only once before training.
|
||||
Args:
|
||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
|
||||
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
||||
"""
|
||||
if ids_freq_mapping is not None:
|
||||
tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True)
|
||||
sorted_idx = torch.argsort(tmp_idx)
|
||||
self.idx_map.data.copy_(sorted_idx)
|
||||
|
||||
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
|
||||
# As cuda_cached_weight is very big. You may not have that much available memory!
|
||||
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
|
||||
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
||||
if preload_row_num > 0:
|
||||
with Timer() as timer:
|
||||
# extract chunks from cpu weight
|
||||
preload_row_ids = torch.arange(preload_row_num)
|
||||
preload_slot_ids = preload_row_ids.cuda()
|
||||
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=preload_row_ids,
|
||||
tgt_index=preload_slot_ids,
|
||||
src=self.cpu_weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
preload_chunks = self.cpu_weight.view(self.num_embeddings, -1).index_select(0,
|
||||
preload_row_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
|
||||
|
||||
# update auxiliary info
|
||||
slot_offsets = preload_slot_ids
|
||||
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
|
||||
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
|
||||
self._cuda_available_row_num -= preload_row_num
|
||||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||
|
||||
def flush(self):
|
||||
"""flush all CUDA chunks to CPU.
|
||||
The function is usually called after training finished.
|
||||
"""
|
||||
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
||||
chunk_ids = self.cached_idx_map[slots]
|
||||
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
||||
self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
|
||||
self.cached_idx_map.index_fill_(0, slots, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
|
||||
self._cuda_available_row_num += slots.numel()
|
||||
|
||||
assert self._cuda_available_row_num == self.cuda_row_num
|
||||
assert torch.all(self.inverted_cached_idx == -1).item()
|
||||
assert torch.all(self.cached_idx_map == -1).item()
|
||||
|
||||
def print_comm_stats(self):
|
||||
if self._cuda_to_cpu_numel > 0:
|
||||
print(
|
||||
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
|
||||
)
|
||||
if self._cpu_to_cuda_numel > 0:
|
||||
print(
|
||||
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
convert ids to indices in self.cuda_cached_weight.
|
||||
Implemented with parallel operations on GPU.
|
||||
|
||||
Args:
|
||||
ids (torch.Tensor): ids from the dataset
|
||||
|
||||
Returns:
|
||||
torch.Tensor: contains indices in self.cuda_cached_weight
|
||||
"""
|
||||
ids = self.idx_map.index_select(0, ids.view(-1))
|
||||
ret = self.inverted_cached_idx.index_select(0, ids)
|
||||
return ret
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
move the cpu embedding rows w.r.t. ids into CUDA memory
|
||||
|
||||
Args:
|
||||
ids (torch.Tensor): the ids to be computed
|
||||
Returns:
|
||||
torch.Tensor: indices on the cuda_cached_weight.
|
||||
"""
|
||||
with record_function("(zhg) get unique indices"):
|
||||
cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
|
||||
|
||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
||||
f"which is larger than the presented {self.cuda_row_num}, " \
|
||||
f"please increase cuda_row_num shrink batch size"
|
||||
self.evict_backlist = cpu_row_idxs
|
||||
|
||||
with record_function("(zhg) get cpu chunk indices"):
|
||||
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
|
||||
|
||||
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
|
||||
self.num_miss_history.append(len(comm_cpu_row_idxs))
|
||||
self.num_write_back_history.append(0)
|
||||
|
||||
# move sure the cuda chunk will not be evicted!
|
||||
with record_function("(zhg) cache update"):
|
||||
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
|
||||
|
||||
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
|
||||
# new ids chunk_offset + offset_in_chunk
|
||||
with record_function("(zhg) embed idx -> cache chunk id"):
|
||||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||
return gpu_row_idxs
|
||||
|
||||
def _reset_comm_stats(self):
|
||||
self._cpu_to_cuda_numel = 0
|
||||
self._cpu_to_cuda_elpase = 0
|
||||
self._cuda_to_cpu_elapse = 0
|
||||
self._cuda_to_cpu_numel = 0
|
||||
|
||||
def _chunk_in_cuda(self, chunk_id: int) -> bool:
|
||||
return self.inverted_cached_idx[chunk_id] != -1
|
||||
|
||||
@torch.no_grad()
|
||||
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
|
||||
"""prepare rows in cpu_row_idxs on CUDA memory
|
||||
Args:
|
||||
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
|
||||
"""
|
||||
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
|
||||
if evict_num > 0:
|
||||
with Timer() as timer:
|
||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=evict_gpu_row_idxs,
|
||||
tgt_index=evict_info.cpu(),
|
||||
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
|
||||
tgt=self.cpu_weight.view(self.num_embeddings, -1))
|
||||
else:
|
||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
|
||||
self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
|
||||
|
||||
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
|
||||
self._cuda_available_row_num += evict_num
|
||||
|
||||
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
|
||||
self._cuda_to_cpu_elapse += timer.elapsed
|
||||
self._cuda_to_cpu_numel += weight_size
|
||||
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
|
||||
|
||||
with Timer() as timer:
|
||||
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
|
||||
# Here also allocate extra memory on CUDA. #cpu_row_idxs
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=cpu_row_idxs.cpu(),
|
||||
tgt_index=slots,
|
||||
src=self.cpu_weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
|
||||
slot_offsets = slots
|
||||
self.cached_idx_map[slots] = cpu_row_idxs
|
||||
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
||||
self._cuda_available_row_num -= cpu_row_idxs.numel()
|
||||
self._cpu_to_cuda_elpase += timer.elapsed
|
||||
weight_size = cpu_row_idxs.numel() * self.embedding_dim
|
||||
self._cpu_to_cuda_numel += weight_size
|
||||
# print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
|
||||
|
||||
def _evict(self) -> int:
|
||||
"""
|
||||
evict one chunk from cuda to cpu.
|
||||
Returns:
|
||||
(int) : the slot id be evicted.
|
||||
"""
|
||||
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
|
||||
buf = self.cached_idx_map[mask].clone()
|
||||
idx = torch.nonzero(mask).squeeze(1)
|
||||
self.cached_idx_map.index_fill_(0, idx, -1)
|
||||
max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0)
|
||||
max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx]
|
||||
|
||||
if max_gpu_row_idx == -1:
|
||||
raise RuntimeError("Can not evict a row")
|
||||
|
||||
max_gpu_row_idx = max_gpu_row_idx.item()
|
||||
max_offset = self.inverted_cached_idx[max_gpu_row_idx]
|
||||
# recover
|
||||
self.cached_idx_map.index_copy_(0, idx, buf)
|
||||
|
||||
with Timer() as timer:
|
||||
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)
|
||||
|
||||
# update inverted_cached_idx, min_slot_id is evicted from cuda
|
||||
self.cached_idx_map[max_cpu_row_idx] = -1
|
||||
|
||||
self.inverted_cached_idx[max_gpu_row_idx] = -1
|
||||
|
||||
self._cuda_available_row_num += 1
|
||||
|
||||
self._cuda_to_cpu_numel += self.embedding_dim
|
||||
self._cuda_to_cpu_elapse += timer.elapsed
|
||||
# self.num_write_back_history[-1] += 1
|
||||
return max_cpu_row_idx
|
||||
|
||||
def _find_free_cuda_row(self) -> int:
|
||||
if self._cuda_available_row_num == 0:
|
||||
return -1
|
||||
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
|
||||
return candidates[0].item()
|
||||
|
||||
@torch.no_grad()
|
||||
def _admit(self, row_id: int):
|
||||
"""
|
||||
move in row_id to CUDA
|
||||
|
||||
Args:
|
||||
row_id (int): the id of row to be moved in
|
||||
"""
|
||||
# find a free slot in partial cuda weight
|
||||
slot_id = self._find_free_cuda_row()
|
||||
|
||||
if slot_id == -1:
|
||||
# evict one row
|
||||
slot_id = self._evict()
|
||||
slot_offset = slot_id
|
||||
# copy payload from cpu to cuda
|
||||
with Timer() as timer:
|
||||
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
cuda_tensor.data.copy_(self.cpu_weight_data(row_id))
|
||||
|
||||
# update the inverted_cached_idx
|
||||
self.cached_idx_map[slot_id] = row_id
|
||||
self.inverted_cached_idx[row_id] = slot_offset
|
||||
|
||||
self._cuda_available_row_num -= 1
|
||||
|
||||
self._cpu_to_cuda_numel += self.embedding_dim
|
||||
self._cpu_to_cuda_elpase += timer.elapsed
|
48
colossalai/nn/parallel/layers/cache_embedding/copyer.py
Normal file
48
colossalai/nn/parallel/layers/cache_embedding/copyer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch import LongTensor
|
||||
|
||||
|
||||
class LimitBuffIndexCopyer(object):
|
||||
"""LimitBuffIndexCopyer
|
||||
Index Copy using limited temp buffer on CUDA.
|
||||
|
||||
Args:
|
||||
size (int): buffer size
|
||||
"""
|
||||
|
||||
def __init__(self, size: int) -> None:
|
||||
self._buff_size = size
|
||||
|
||||
@torch.no_grad()
|
||||
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
|
||||
"""copy
|
||||
src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
|
||||
The valid part in src is continous, while in tgt is scatter.
|
||||
Args:
|
||||
dim (int): dimension along which to index
|
||||
src_index (int): indices of src tensor to select from
|
||||
tgt_index (int): indices of tgt tensor to select from
|
||||
src (torch.Tensor): the tensor containing values to copy
|
||||
tgt (torch.Tensor): the tensor to be copied
|
||||
"""
|
||||
# tgt.index_copy_(dim, index, src)
|
||||
assert dim == 0, "only support index_copy on dim 0"
|
||||
assert tgt.dim() == 2
|
||||
assert src.dim() == 2
|
||||
tgt_device = tgt.device
|
||||
src_device = src.device
|
||||
|
||||
assert src_index.numel() == tgt_index.numel()
|
||||
dim_size = src_index.numel()
|
||||
src_index = src_index.to(src_device)
|
||||
for begin_pos in range(0, dim_size, self._buff_size):
|
||||
cur_len = min(self._buff_size, dim_size - begin_pos)
|
||||
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
|
||||
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
|
||||
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
|
||||
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
|
||||
tmp_buffer.copy_(cpu_tmp_buffer)
|
||||
else:
|
||||
tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
|
||||
tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
|
||||
tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
|
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Optional, Iterator, Tuple
|
||||
|
||||
from .base_embedding import BaseEmbeddingBag
|
||||
from .cache_mgr import CachedParamMgr
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs):
|
||||
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs)
|
||||
self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype)
|
||||
|
||||
def preprocess(self,
|
||||
cuda_row_num: int,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio=0.7,
|
||||
buffer_size=50_000):
|
||||
"""
|
||||
Called after initialized.
|
||||
Reorder the weight rows according to the ids_freq_mapping.
|
||||
Then, let the weights of the Module be managed by a CachedParamMgr.
|
||||
Args:
|
||||
cuda_row_num (int): number of rows can be hosted in CUDA memory
|
||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
|
||||
warmup_ratio (float): the amount of rows preloaded in cuda cache
|
||||
"""
|
||||
self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size)
|
||||
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
||||
|
||||
def forward(self, indices, offsets=None, per_sample_weights=None):
|
||||
with torch.no_grad():
|
||||
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
||||
|
||||
embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
|
||||
return embeddings
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
assert self.cache_weight_mgr is not None
|
||||
return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
|
||||
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
yield self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
@property
|
||||
def num_hits_history(self):
|
||||
return self.cache_weight_mgr.num_hits_history
|
||||
|
||||
@property
|
||||
def num_miss_history(self):
|
||||
return self.cache_weight_mgr.num_miss_history
|
||||
|
||||
@property
|
||||
def num_write_back_history(self):
|
||||
return self.cache_weight_mgr.num_write_back_history
|
||||
|
||||
@property
|
||||
def swap_in_bandwidth(self):
|
||||
if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
|
||||
return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
|
||||
self.cache_weight_mgr._cpu_to_cuda_elpase
|
||||
else:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def swap_out_bandwidth(self):
|
||||
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
|
||||
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
|
||||
self.cache_weight_mgr._cuda_to_cpu_elapse
|
||||
return 0
|
||||
|
||||
@property
|
||||
def input_id_percent_in_load_chunk(self):
|
||||
return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
|
@@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Optional, Iterator, Tuple
|
||||
|
||||
from .base_embedding import BaseEmbeddingBag
|
||||
from .cache_mgr import CachedParamMgr
|
||||
from torch.nn.parameter import Parameter
|
||||
from colossalai.nn._ops._utils import dual_all_to_all
|
||||
|
||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
||||
|
||||
|
||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||
if world_size == 1:
|
||||
return 0, embedding_dim, True
|
||||
|
||||
assert embedding_dim >= world_size, \
|
||||
f"Embedding dimension {embedding_dim} must be larger than the world size " \
|
||||
f"{world_size} of the process group"
|
||||
chunk_size = embedding_dim // world_size
|
||||
threshold = embedding_dim % world_size
|
||||
# if embedding dim is divisible by world size
|
||||
if threshold == 0:
|
||||
return rank * chunk_size, (rank + 1) * chunk_size, True
|
||||
|
||||
# align with the split strategy of torch.tensor_split
|
||||
size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
|
||||
offset = sum(size_list[:rank])
|
||||
return offset, offset + size_list[rank], False
|
||||
|
||||
|
||||
class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
_weight=None,
|
||||
mode='mean',
|
||||
include_last_offset=False,
|
||||
dtype=None,
|
||||
debug=True):
|
||||
super(ParallelFreqAwareEmbeddingBag,
|
||||
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||
sparse, mode, include_last_offset)
|
||||
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.world_size = torch.distributed.get_world_size()
|
||||
self.debug = debug
|
||||
|
||||
self.partition_start_index, self.partition_end_index, divisible = get_partition(
|
||||
embedding_dim, self.rank, self.world_size)
|
||||
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
|
||||
|
||||
if _weight is None:
|
||||
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
|
||||
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
|
||||
compute_attr=ComputePattern.TP1D)
|
||||
self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
|
||||
self.embedding_dim_per_partition,
|
||||
device='cpu',
|
||||
dtype=dtype),
|
||||
requires_grad=True,
|
||||
spec=colo_tensor_spec)
|
||||
self.init_parameters()
|
||||
else:
|
||||
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
|
||||
self._weight = _weight
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.cache_weight_mgr.cpu_weight
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
|
||||
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
yield self.cache_weight_mgr.cuda_cached_weight
|
||||
|
||||
@torch.no_grad()
|
||||
def init_parameters(self):
|
||||
self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
|
||||
if self.padding_idx is not None:
|
||||
self._weight[self.padding_idx].fill_(0)
|
||||
|
||||
def preprocess(self,
|
||||
cuda_row_num: int,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio: float = 0.7,
|
||||
buffer_size: int = 50_000):
|
||||
self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
|
||||
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
||||
|
||||
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
|
||||
with torch.no_grad():
|
||||
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
||||
|
||||
output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
|
||||
if shape_hook is not None:
|
||||
output_shard = shape_hook(output_shard)
|
||||
|
||||
output_full = dual_all_to_all(output_shard,
|
||||
self._weight.get_process_group(),
|
||||
scatter_dim=scatter_dim,
|
||||
gather_dim=gather_dim)
|
||||
return output_full
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
embedding: torch.Tensor,
|
||||
freeze: bool = True,
|
||||
padding_idx: Optional[int] = None,
|
||||
max_norm: Optional[float] = None,
|
||||
norm_type: float = 2.,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
mode: str = 'mean',
|
||||
include_last_offset: bool = False,
|
||||
debug: bool = True,
|
||||
cuda_row_num: int = 100_000,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
|
||||
rows, cols = embedding.shape
|
||||
embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
|
||||
include_last_offset, debug)
|
||||
embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
|
||||
embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
|
||||
return embedding_bag
|
Reference in New Issue
Block a user