[FAW] cpu caching operations (#1520)

This commit is contained in:
Jiarui Fang
2022-08-30 14:50:02 +08:00
committed by GitHub
parent 481aecb05a
commit 9a9ef65313
4 changed files with 86 additions and 64 deletions

View File

@@ -83,15 +83,16 @@ def test_reorder_with_freq():
chunkid.append(idx // chunk_size)
offset_in_chunk.append(idx % chunk_size)
chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device())
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device())
dev = torch.device('cuda')
chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
weight = torch.rand(num_embed, 2)
mgr = CachedParamMgr(weight, num_chunk)
mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu')
mgr.reorder(idx_map)
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device()))
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
mgr_offsets = torch.remainder(indices, chunk_size)
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
@@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__':
# test_freq_aware_embed(True)
test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
test_lfu_strategy(False)
# test_lfu_strategy(False)