[NFC] polish doc style for ColoTensor (#1457)

This commit is contained in:
Jiarui Fang
2022-08-16 09:21:05 +08:00
committed by GitHub
parent 0dbd61c29b
commit a1476ea882
9 changed files with 197 additions and 48 deletions

View File

@@ -99,9 +99,11 @@ class CachedParamMgr(torch.nn.Module):
@torch.no_grad()
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder the weight according to ids' frequency in dataset before training.
"""reorder
reorder the weight according to ids' frequency in dataset before training.
Also Build the IndexMappingTable, aka index_mapping_table.
Execute only once before training.
Args:
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
warmup_ratio (float): the amount of chunks preloaded in cuda cache

View File

@@ -16,8 +16,9 @@ class LimitBuffIndexCopyer(object):
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
"""copy
src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
The valid part in src is continous, while in tgt is scatter.
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
The valid rows in the src tensor are continous, while rows in tgt tensor is scattered.
Args:
dim (int): dimension along which to index
src_index (int): indices of src tensor to select from

View File

@@ -57,6 +57,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
Called after initialized.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
Args:
cuda_row_num (int): number of rows can be hosted in CUDA memory
ids_freq_mapping (List[int]): a list, idx is id number, value is freq