mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] gpt2 autoparallel examples (#2267)
* [autoparallel] gpt2 autoparallel examples
* polish code
* polish code
@@ -172,7 +172,8 @@ def initialize_model(model: nn.Module,
                      memory_budget: float = -1.0,
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
-                     solution_path: str = None):
+                     solution_path: str = None,
+                     return_solution: bool = False):
     '''
     This method is used to initialize the sharded model which could be used as normal pytorch model.

@@ -187,6 +188,9 @@ def initialize_model(model: nn.Module,
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
             from the solution_path.
         solution_path(optional): the path to save or load the solution.
+        return_solution(optional): if the return_solution is True, the solution will be returned. The returned
+            solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
+            return a series of integers, but return the best strategies.
     '''
     tracer = ColoTracer()

@@ -204,7 +208,14 @@ def initialize_model(model: nn.Module,
     gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

-    return model_to_return
+    if return_solution:
+        solution_to_return = []
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        for index, node in enumerate(nodes):
+            solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
+        return model_to_return, solution_to_return
+    else:
+        return model_to_return


 def autoparallelize(model: nn.Module,
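For context on the new return value: each entry built in the loop above pairs an FX node name with the name of the sharding strategy the solver picked for that node, joined into one string. A minimal sketch of consuming such a list; the node and strategy names below are made up for illustration, not taken from a real run:

# Hypothetical result of initialize_model(..., return_solution=True).
# Each entry has the format '<node_name> <strategy_name>'.
solution = [
    'linear_1 S01R = S01R x RR',      # illustrative names only
    'layer_norm_1 RR = RR x R',
]
for entry in solution:
    # The node name never contains a space, so split once from the left.
    node_name, strategy_name = entry.split(' ', 1)
    print(f'{node_name:>16} -> {strategy_name}')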
@@ -216,6 +227,7 @@ def autoparallelize(model: nn.Module,
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
+                    return_solution: bool = False,
                     memory_budget: float = -1.0):
     '''
     This method is used to initialize the device mesh, extract the meta_args, and
@@ -238,18 +250,26 @@ def autoparallelize(model: nn.Module,
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
             from the solution_path.
         solver_solution_path(optional): the path to save or load the solution.
+        return_solution(optional): if the return_solution is True, the solution will be returned.
         memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
     '''
     device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
     if meta_args is None:
         meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
-    model = initialize_model(model,
-                             meta_args,
-                             device_mesh,
-                             save_solver_solution=save_solver_solution,
-                             load_solver_solution=load_solver_solution,
-                             solver_solution_path=solver_solution_path,
-                             memory_budget=memory_budget)
-
-    return model
+    rst_to_unpack = initialize_model(model,
+                                     meta_args,
+                                     device_mesh,
+                                     save_solver_solution=save_solver_solution,
+                                     load_solver_solution=load_solver_solution,
+                                     solver_solution_path=solver_solution_path,
+                                     return_solution=return_solution,
+                                     memory_budget=memory_budget)
+
+    if return_solution:
+        model, solution = rst_to_unpack
+        return model, solution
+    else:
+        model = rst_to_unpack
+        return model
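A minimal usage sketch of the new flag at the autoparallelize level, assuming a small GPT-2 style model to match the examples this commit adds. The import path, model construction, and meta tensor shapes are assumptions for illustration, not part of this diff:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Import path assumed from the ColossalAI auto-parallel module layout.
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=12, n_embd=768))
# Meta tensors let the tracer see input shapes without allocating memory.
meta_args = {'input_ids': torch.zeros((1, 128), dtype=torch.int64, device='meta')}

# With return_solution=True the result unpacks to (model, solution);
# memory_budget=-1.0 keeps the unlimited budget described in the docstring.
model, solution = autoparallelize(model,
                                  meta_args=meta_args,
                                  return_solution=True,
                                  memory_budget=-1.0)
for entry in solution:    # '<node_name> <strategy_name>' strings
    print(entry)

Note that this has to run inside an initialized torch.distributed process group so the device mesh can be built; executing it as a plain single-process script will not exercise the sharding.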