mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[embedding] polish parallel embedding tablewise (#1545)
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TablewiseEmbeddingBagConfig:
|
||||
'''
|
||||
example:
|
||||
def prepare_tablewise_config(args, cache_ratio, ...):
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
|
||||
...
|
||||
return embedding_bag_config_list
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
cuda_row_num: int,
|
||||
assigned_rank: int = 0,
|
||||
buffer_size=50_000,
|
||||
ids_freq_mapping=None,
|
||||
initial_weight: torch.tensor = None,
|
||||
name: str = ""):
|
||||
self.num_embeddings = num_embeddings
|
||||
self.cuda_row_num = cuda_row_num
|
||||
self.assigned_rank = assigned_rank
|
||||
self.buffer_size = buffer_size
|
||||
self.ids_freq_mapping = ids_freq_mapping
|
||||
self.initial_weight = initial_weight
|
||||
self.name = name
|
Reference in New Issue
Block a user