mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[gemini] fix argument naming during chunk configuration searching
This commit is contained in:
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
|
Reference in New Issue
Block a user