Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-14 22:23:23 +00:00)
[FCE] update interface for frequency statistics in FreqCacheEmbedding (#1462)
parent ede326298b
commit 0aad53c62b
@@ -14,12 +14,17 @@ class CachedParamMgr(torch.nn.Module):
     During training, GPU needs to transmit rows between CPU and GPU.
     """
 
-    def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
+    def __init__(self,
+                 weight: torch.Tensor,
+                 cuda_row_num: int = 0,
+                 buffer_size: int = 50_000,
+                 pin_weight=False) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
+        self.pin_weight = pin_weight
 
         self.elem_size_in_byte = weight.element_size()
 
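The only behavioral addition in the new keyword-style signature is `pin_weight`. A minimal construction sketch, assuming `CachedParamMgr` is importable from the cache_embedding package (the import path below is an assumption, not something this diff confirms):

import torch
from colossalai.nn.parallel.layers.cache_embedding import CachedParamMgr  # assumed import path

weight = torch.randn(10000, 128)          # full embedding table, kept on CPU
mgr = CachedParamMgr(weight,
                     cuda_row_num=2000,   # rows to cache in CUDA memory
                     buffer_size=50_000,
                     pin_weight=True)     # new flag: keep the CPU copy in pinned memory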
@@ -43,8 +48,7 @@ class CachedParamMgr(torch.nn.Module):
                                         dtype=weight.dtype))
 
         # pin memory cpu for higher CPU-GPU copy bandwidth
-        self.weight = weight.contiguous().cpu().pin_memory()
+        self.weight = weight.pin_memory() if self.pin_weight else weight
 
         # map original id to new id with respect to frequency
         # id -> cpu_row_idx
         self.register_buffer(
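For reference, `Tensor.pin_memory()` returns a copy of the tensor in page-locked host memory, which is what enables faster (and optionally asynchronous) CPU-to-GPU row transfers when `pin_weight=True`. A small stand-alone check, assuming a CUDA-enabled build of PyTorch:

import torch

weight = torch.randn(1000, 128)
print(weight.is_pinned())        # False: ordinary pageable host memory

pinned = weight.pin_memory()     # what the cache manager does when pin_weight=True
print(pinned.is_pinned())        # True
# pinned rows can then be copied to the GPU without blocking the host:
# row = pinned[0].to('cuda', non_blocking=True)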
@@ -109,7 +113,7 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
-            tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True)
+            tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
 
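The change above also implies that `ids_freq_mapping` is now expected to be a torch tensor rather than a numpy array. The double `argsort` converts per-id frequencies into per-id ranks, so the hottest ids get the smallest (cache-resident) row indices; a tiny self-contained check:

import torch

ids_freq_mapping = torch.tensor([5, 100, 1, 50])               # access frequency of ids 0..3
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)     # ids sorted by frequency: [1, 3, 0, 2]
sorted_idx = torch.argsort(tmp_idx)                            # rank of each id:          [2, 0, 3, 1]
# id 1 (freq 100) maps to row 0, id 3 (freq 50) to row 1, id 0 to row 2, id 2 to row 3
print(sorted_idx.tolist())                                     # [2, 0, 3, 1]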
@@ -27,20 +27,19 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
 
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
-        else:
-            _weight = _weight
 
         # configure weight & cache
-        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size)
+        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
 
     def _weight_alloc(self, dtype, device):
-        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device, pin_memory=True)
+        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
         with torch.no_grad():
             weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
             if self.padding_idx is not None:
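Putting the constructor changes together, a hedged usage sketch of the updated interface (the import path is an assumption; the keyword names are the ones threaded through in this diff, and `ids_freq_mapping` is passed as a torch tensor):

import torch
from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag  # assumed import path

num_embeddings, embedding_dim = 10000, 128
ids_freq_mapping = torch.randint(1, 1000, (num_embeddings,))   # per-id access frequencies

bag = FreqAwareEmbeddingBag(num_embeddings,
                            embedding_dim,
                            cuda_row_num=2000,
                            ids_freq_mapping=ids_freq_mapping,
                            warmup_ratio=0.7,
                            buffer_size=50_000,
                            pin_weight=True)                   # forwarded to CachedParamMgr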
@@ -52,7 +51,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                     cuda_row_num: int,
                     ids_freq_mapping: Optional[List[int]] = None,
                     warmup_ratio=0.7,
-                    buffer_size=50_000):
+                    buffer_size=50_000,
+                    pin_weight=False):
         """
         Called after initialized.
         Reorder the weight rows according to the ids_freq_mapping.
@@ -63,17 +63,18 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
-    def forward(self, indices, offsets=None, per_sample_weights=None):
+    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
 
         embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
+        if shape_hook is not None:
+            embeddings = shape_hook(embeddings)
         return embeddings
 
     @property
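The new `shape_hook` argument lets the caller post-process the pooled embeddings before `forward` returns them. A minimal stand-in using plain `F.embedding_bag` (the reshape below is a hypothetical hook, chosen only for illustration):

import torch
import torch.nn.functional as F

weight = torch.randn(100, 8)
indices = torch.randint(0, 100, (12,))
offsets = torch.arange(0, 12, 2)             # 6 bags of 2 ids each

def shape_hook(emb):                          # hypothetical hook
    return emb.view(3, 2, -1)                 # e.g. (batch=3, slots=2, dim=8)

embeddings = F.embedding_bag(indices, weight, offsets)
if shape_hook is not None:                    # mirrors the new branch in forward()
    embeddings = shape_hook(embeddings)
print(embeddings.shape)                       # torch.Size([3, 2, 8])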
@@ -3,8 +3,6 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
 from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .cache_mgr import CachedParamMgr
-from torch.nn.parameter import Parameter
 from colossalai.nn._ops._utils import dual_all_to_all
 
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -49,6 +47,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -60,17 +59,18 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size)
+                             warmup_ratio, buffer_size, pin_weight)
 
     def _weight_alloc(self, dtype, device):
+        weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
+        with torch.no_grad():
+            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+            if self.padding_idx is not None:
+                weight[self.padding_idx].fill_(0)
         colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
                                           dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                           compute_attr=ComputePattern.TP1D)
-        return ColoTensor.from_torch_tensor(torch.empty(self.num_embeddings,
-                                                        self.embedding_dim_per_partition,
-                                                        device=device,
-                                                        dtype=dtype),
-                                            spec=colo_tensor_spec)
+        return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
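In plain torch terms, the reworked `_weight_alloc` now initializes only this rank's slice of the embedding dimension and zeroes the padding row before wrapping the shard in a `ColoTensor`. A sketch of the per-rank allocation (the ColossalAI-specific `ShardSpec`/`ColoTensor` wrapping is omitted, and the sizes are illustrative):

import torch

num_embeddings, embedding_dim = 10000, 128
world_size, padding_idx = 4, 0
embedding_dim_per_partition = embedding_dim // world_size      # 32 columns owned by this rank

weight = torch.empty(num_embeddings, embedding_dim_per_partition)
with torch.no_grad():
    weight.uniform_(-1 / num_embeddings, 1 / num_embeddings)
    if padding_idx is not None:
        weight[padding_idx].fill_(0)                           # padding row stays zero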
@@ -44,7 +44,7 @@ def synthesize_1d_sparse_feature(
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
-    mgr = CachedParamMgr(model.weight, 5)
+    mgr = CachedParamMgr(model.weight.detach(), 5)
     assert mgr.cuda_row_num == 5
 
     mgr._admit(1)
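The `.detach()` call matters because `EmbeddingBag.weight` is an `nn.Parameter` that tracks gradients, while the cache manager presumably only needs a plain tensor view of the same storage. A quick check of what `detach()` gives:

import torch

model = torch.nn.EmbeddingBag(10000, 128)
w = model.weight.detach()
print(type(model.weight).__name__, model.weight.requires_grad)   # Parameter True
print(type(w).__name__, w.requires_grad)                         # Tensor False
print(w.data_ptr() == model.weight.data_ptr())                   # True: same underlying storage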
@@ -74,8 +74,8 @@ def test_reorder_with_freq():
     chunk_size = 1
     num_chunk = 5
 
-    idx_map = np.random.randint(10000, size=(num_embed,))
-    sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
+    idx_map = torch.randint(10000, size=(num_embed,))
+    sorted_idx = torch.argsort(idx_map, descending=True).tolist()
     chunkid, offset_in_chunk = [], []
     for i in range(num_embed):
         idx = sorted_idx.index(i)
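The torch-based ordering is equivalent to the old numpy version up to tie-breaking among equal frequencies, which a quick check confirms:

import numpy as np
import torch

values = [5, 100, 1, 50]
old = np.flipud(np.argsort(np.array(values))).tolist()               # [1, 3, 0, 2]
new = torch.argsort(torch.tensor(values), descending=True).tolist()  # [1, 3, 0, 2]
assert old == new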
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    # test_cachemgr()
+    test_cachemgr()
     # test_freq_aware_embed()
-    test_parallel_freq_aware_embed(2)
+    # test_parallel_freq_aware_embed(2)