[autoparallel] gpt2 autoparallel examples (#2267)

* [autoparallel] gpt2 autoparallel examples

* polish code

* polish code
This commit is contained in:
YuliangLiu0306
2023-01-03 14:23:33 +08:00
committed by GitHub
parent 8b045b3c1f
commit 4b29112ab2
5 changed files with 440 additions and 10 deletions

View File

@@ -172,7 +172,8 @@ def initialize_model(model: nn.Module,
memory_budget: float = -1.0,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None):
solution_path: str = None,
return_solution: bool = False):
'''
This method is used to initialize the sharded model, which can be used as a normal PyTorch model.
@@ -187,6 +188,9 @@ def initialize_model(model: nn.Module,
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
from the solution_path.
solution_path(optional): the path to save or load the solution.
return_solution(optional): if return_solution is True, the solution will be returned together with the model.
The returned solution is meant for debugging and for analyzing the sharding result. Therefore, it is not
a bare series of integers, but a list naming the best strategy chosen for each node.
'''
tracer = ColoTracer()
@@ -204,7 +208,14 @@ def initialize_model(model: nn.Module,
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
return model_to_return
if return_solution:
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
return model_to_return, solution_to_return
else:
return model_to_return
def autoparallelize(model: nn.Module,
@@ -216,6 +227,7 @@ def autoparallelize(model: nn.Module,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0):
'''
This method is used to initialize the device mesh, extract the meta_args, and
@@ -238,18 +250,26 @@ def autoparallelize(model: nn.Module,
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
from the solution_path.
solver_solution_path(optional): the path to save or load the solution.
return_solution(optional): if return_solution is True, the solution will be returned together with the model.
memory_budget(optional): the maximum CUDA memory that can be used. If the memory budget is -1.0,
the memory budget will be infinite.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
model = initialize_model(model,
meta_args,
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
memory_budget=memory_budget)
return model
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
if return_solution:
model, solution = rst_to_unpack
return model, solution
else:
model = rst_to_unpack
return model