mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[FCE] update interface for frequency statistics in FreqCacheEmbedding (#1462)
This commit is contained in:
@@ -44,7 +44,7 @@ def synthesize_1d_sparse_feature(
|
||||
def test_cachemgr():
|
||||
model = torch.nn.EmbeddingBag(10000, 128)
|
||||
# 10 chunks, 5 in cuda
|
||||
mgr = CachedParamMgr(model.weight, 5)
|
||||
mgr = CachedParamMgr(model.weight.detach(), 5)
|
||||
assert mgr.cuda_row_num == 5
|
||||
|
||||
mgr._admit(1)
|
||||
@@ -74,8 +74,8 @@ def test_reorder_with_freq():
|
||||
chunk_size = 1
|
||||
num_chunk = 5
|
||||
|
||||
idx_map = np.random.randint(10000, size=(num_embed,))
|
||||
sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
|
||||
idx_map = torch.randint(10000, size=(num_embed,))
|
||||
sorted_idx = torch.argsort(idx_map, descending=True).tolist()
|
||||
chunkid, offset_in_chunk = [], []
|
||||
for i in range(num_embed):
|
||||
idx = sorted_idx.index(i)
|
||||
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_cachemgr()
|
||||
test_cachemgr()
|
||||
# test_freq_aware_embed()
|
||||
test_parallel_freq_aware_embed(2)
|
||||
# test_parallel_freq_aware_embed(2)
|
||||
|
Reference in New Issue
Block a user