mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[gemini] fix argument naming during chunk configuration searching
This commit is contained in:
@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
@@ -113,7 +113,7 @@ def exam_gpt_inference(
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
|
@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
||||
assert len(step_list) == 4
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
|
@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
|
@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
|
||||
def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
|
||||
world_size = dist.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
|
@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
|
||||
p.data.copy_(torch_p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
p.data.copy_(torch_p.data)
|
||||
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
|
@@ -30,9 +30,9 @@ def exam_search_chunk_size():
|
||||
model = model_builder()
|
||||
init_1d_row_spec(model, pg_tp)
|
||||
config_dict, *_ = search_chunk_configuration(model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True)
|
||||
|
||||
for key in config_dict:
|
||||
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
ddp_model = model_builder()
|
||||
re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True,
|
||||
strict_ddp_flag=False)
|
||||
# get the chunk configuration over sharded ddp models
|
||||
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
|
||||
default_dist_spec=default_shard_spec):
|
||||
sharded_ddp_model = model_builder()
|
||||
sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
search_range_m=1,
|
||||
search_interval=16,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True,
|
||||
strict_ddp_flag=True)
|
||||
assert re_dict == sh_dict
|
||||
@@ -91,8 +91,8 @@ def exam_chunk_manager():
|
||||
chunk_manager = init_chunk_manager(sharded_ddp_model,
|
||||
get_current_device(),
|
||||
hidden_dim=16,
|
||||
search_range_mb=1,
|
||||
min_chunk_size_mb=0,
|
||||
search_range_m=1,
|
||||
min_chunk_size_m=0,
|
||||
filter_exlarge_params=True,
|
||||
strict_ddp_flag=True)
|
||||
config_dict = chunk_manager.dp_degree_chunk_size_dict
|
||||
|
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
|
@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
|
||||
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
|
||||
assert key in zero_dict, f"{key} not in ZeRO dictionary."
|
||||
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
|
Reference in New Issue
Block a user