mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[NFC] polish doc style for ColoTensor (#1457)
This commit is contained in:
@@ -99,9 +99,11 @@ class CachedParamMgr(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
|
||||
"""reorder the weight according to ids' frequency in dataset before training.
|
||||
"""reorder
|
||||
reorder the weight according to ids' frequency in dataset before training.
|
||||
Also Build the IndexMappingTable, aka index_mapping_table.
|
||||
Execute only once before training.
|
||||
|
||||
Args:
|
||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
|
||||
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
||||
|
@@ -16,8 +16,9 @@ class LimitBuffIndexCopyer(object):
|
||||
@torch.no_grad()
|
||||
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
|
||||
"""copy
|
||||
src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
|
||||
The valid part in src is continous, while in tgt is scatter.
|
||||
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
|
||||
The valid rows in the src tensor are continous, while rows in tgt tensor is scattered.
|
||||
|
||||
Args:
|
||||
dim (int): dimension along which to index
|
||||
src_index (int): indices of src tensor to select from
|
||||
|
@@ -57,6 +57,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||
Called after initialized.
|
||||
Reorder the weight rows according to the ids_freq_mapping.
|
||||
Then, let the weights of the Module be managed by a CachedParamMgr.
|
||||
|
||||
Args:
|
||||
cuda_row_num (int): number of rows can be hosted in CUDA memory
|
||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
|
||||
|
Reference in New Issue
Block a user