mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[FAW] cpu caching operations (#1520)
This commit is contained in:
@@ -83,15 +83,16 @@ def test_reorder_with_freq():
|
||||
chunkid.append(idx // chunk_size)
|
||||
offset_in_chunk.append(idx % chunk_size)
|
||||
|
||||
chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device())
|
||||
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device())
|
||||
dev = torch.device('cuda')
|
||||
chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
|
||||
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
|
||||
|
||||
weight = torch.rand(num_embed, 2)
|
||||
mgr = CachedParamMgr(weight, num_chunk)
|
||||
mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu')
|
||||
|
||||
mgr.reorder(idx_map)
|
||||
|
||||
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device()))
|
||||
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
|
||||
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
|
||||
mgr_offsets = torch.remainder(indices, chunk_size)
|
||||
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
|
||||
@@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_freq_aware_embed(True)
|
||||
test_freq_aware_embed(True)
|
||||
# test_parallel_freq_aware_embed(2)
|
||||
test_lfu_strategy(False)
|
||||
# test_lfu_strategy(False)
|
||||
|
Reference in New Issue
Block a user