add overlap option (#2613)

Author: YuliangLiu0306
Date: 2023-02-08 15:02:31 +08:00
Committed by: GitHub
Parent: cb3d1bef62
Commit: 28398f1c70
2 changed files with 32 additions and 16 deletions


@@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor
     return solution


-def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
-                               strategies_constructor: StrategiesConstructor):
+def transform_to_sharded_model(gm: ColoGraphModule,
+                               solution: List[int],
+                               device_mesh: DeviceMesh,
+                               strategies_constructor: StrategiesConstructor,
+                               overlap: bool = False):
     '''
     This method is used to transform the original graph to the sharded graph.
     The model parameters will be sharded according to the solution and the grad hooks
     will be added to the sharded graph using the runtime_preparation_pass.
     The communication node will be added into the graph using the runtime_apply_pass.
     '''
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
+                                                                                           solution,
+                                                                                           device_mesh,
+                                                                                           strategies_constructor,
+                                                                                           overlap=overlap)
     gm = runtime_apply_pass(gm)
     gm.recompile()
     sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
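With the new keyword in place, callers can opt into overlapping gradient communication with backward computation when lowering the traced graph. A minimal sketch of a call site under the updated signature; gm, solution, device_mesh, and strategies_constructor are assumed to come from the earlier tracing and solver steps in this file, not from this diff:

    # Hypothetical call site for the updated signature.
    gm, sharding_spec_dicts = transform_to_sharded_model(gm,
                                                         solution,
                                                         device_mesh,
                                                         strategies_constructor,
                                                         overlap=True)
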
@@ -176,6 +182,7 @@ def initialize_model(model: nn.Module,
                      meta_args: Dict[str, torch.Tensor],
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
+                     overlap: bool = False,
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -189,6 +196,8 @@ def initialize_model(model: nn.Module,
         device_mesh: the device mesh to execute the model.
         memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
+        overlap(optional): whether to overlap gradient communication with
+            backward computation.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
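For context on what the flag controls: overlapping hides gradient all-reduce latency behind the backward pass that is still running for earlier layers. The sketch below is a generic PyTorch illustration of that idea, not the runtime_preparation_pass implementation; attach_overlap_hooks is a hypothetical name, and an initialized process group is assumed:

    import torch
    import torch.distributed as dist

    def attach_overlap_hooks(model: torch.nn.Module):
        # Hypothetical helper: launch an asynchronous all-reduce as soon as a
        # parameter's gradient is produced, so the collective for one layer
        # overlaps with backward computation of the layers before it.
        pending = []

        def hook(grad: torch.Tensor):
            # async_op=True returns a Work handle immediately instead of blocking.
            pending.append(dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True))
            return grad

        for param in model.parameters():
            if param.requires_grad:
                param.register_hook(hook)
        return pending

    # Register hooks once, run backward, then drain the outstanding collectives:
    #     pending = attach_overlap_hooks(model)
    #     loss.backward()
    #     for work in pending:
    #         work.wait()
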
@@ -211,7 +220,7 @@ def initialize_model(model: nn.Module,
     if save_solver_solution:
         torch.save(solution, solution_path)
-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
     if return_solution:
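
At the public entry point, the flag can now be passed straight through. A hedged end-to-end sketch; MyModel, the meta_args shape, and the device_mesh value are illustrative placeholders, not part of this commit:

    # Hypothetical usage of initialize_model with the new flag.
    model = initialize_model(MyModel(),
                             meta_args={'x': torch.rand(8, 512, device='meta')},
                             device_mesh=device_mesh,
                             overlap=True)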