Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[autoparallel] add shard option (#2696)
* [autoparallel] add shard option
* polish
@@ -8,14 +8,9 @@ from torch.fx.graph import Graph
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
     pass


-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
+                               shard_option: str):
     '''
     This method is used to build the strategy_constructor for the given graph.
     After this method, each node in the graph will have a strategies_vector which
     is constructed by the related node handler.
     '''
-    solver_options = SolverOptions()
+    if solver_preference == 'standard':
+        solver_preference = SolverPerference.STANDARD
+    elif solver_preference == 'tp':
+        solver_preference = SolverPerference.TP
+    elif solver_preference == 'dp':
+        solver_preference = SolverPerference.DP
+    else:
+        raise ValueError(f'Invalid solver_preference: {solver_preference}')
+
+    if dataloader_option == 'replicated':
+        dataloader_option = DataloaderOption.REPLICATED
+    elif dataloader_option == 'distributed':
+        dataloader_option = DataloaderOption.DISTRIBUTED
+    else:
+        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+
+    if shard_option == 'standard':
+        shard_option = ShardOption.STANDARD
+    elif shard_option == 'shard':
+        shard_option = ShardOption.SHARD
+    elif shard_option == 'shard_last_axis':
+        shard_option = ShardOption.SHARD_LAST_AXIS
+    elif shard_option == 'full_shard':
+        shard_option = ShardOption.FULL_SHARD
+    else:
+        raise ValueError(f'Invalid shard_option: {shard_option}')
+
+    solver_options = SolverOptions(solver_perference=solver_preference,
+                                   dataloader_option=dataloader_option,
+                                   shard_option=shard_option)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
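The new build_strategy_constructor turns user-facing option strings into the enums that SolverOptions consumes, raising ValueError on anything unrecognised. Below is a standalone sketch of the same dispatch pattern for the shard_option values, written with a dictionary instead of an if/elif chain; the ShardOption enum defined here is only a stand-in for colossalai.auto_parallel.tensor_shard.options.ShardOption, and its member values are invented purely for illustration.

from enum import Enum


class ShardOption(Enum):
    # Stand-in for colossalai.auto_parallel.tensor_shard.options.ShardOption;
    # the member values below are invented for this sketch.
    STANDARD = 0
    SHARD = 1
    SHARD_LAST_AXIS = 2
    FULL_SHARD = 3


# Dictionary equivalent of the if/elif chain in the patch: unknown strings
# still surface as a ValueError naming the bad option.
_SHARD_OPTIONS = {
    'standard': ShardOption.STANDARD,
    'shard': ShardOption.SHARD,
    'shard_last_axis': ShardOption.SHARD_LAST_AXIS,
    'full_shard': ShardOption.FULL_SHARD,
}


def parse_shard_option(name: str) -> ShardOption:
    try:
        return _SHARD_OPTIONS[name]
    except KeyError:
        raise ValueError(f'Invalid shard_option: {name}') from None


assert parse_shard_option('shard_last_axis') is ShardOption.SHARD_LAST_AXIS

A table-driven mapping keeps the accepted strings in one place, but the explicit if/elif chain used in the patch is equally valid and matches the style of the surrounding file.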
@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
                      overlap: bool = False,
+                     solver_preference: str = 'standard',
+                     dataloader_option: str = 'replicated',
+                     shard_option: str = 'standard',
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
             the memory budget will be infinity.
         overlap(optional): the overlap is used to specify whether to overlap gradient communication and
             backward computing.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
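With the three new keyword arguments documented above, callers of initialize_model select the parallelisation behaviour by plain strings. A hedged usage sketch follows; the import path mirrors the file touched by this patch but is an assumption, and the toy MLP, the 1 x 4 mesh, and the meta-tensor shapes are invented for illustration. The snippet presumes ColossalAI is installed and a distributed environment has already been launched.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model  # assumed path
from colossalai.device.device_mesh import DeviceMesh


class ToyMLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(8, 16)
        self.linear2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = ToyMLP().cuda()

# Map four physical devices onto a 1 x 4 logical mesh (shape chosen for this sketch).
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (1, 4), init_process_group=True)

# Trace with meta tensors so no real activation memory is allocated;
# the key must match the name of the forward argument.
meta_args = {'x': torch.rand(4, 8).to('meta')}

# The three new options are plain strings, validated inside build_strategy_constructor.
parallel_model = initialize_model(model,
                                  meta_args,
                                  device_mesh,
                                  solver_preference='tp',
                                  dataloader_option='replicated',
                                  shard_option='shard_last_axis')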
@@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+
+    strategies_constructor = build_strategy_constructor(graph,
+                                                        device_mesh,
+                                                        solver_preference=solver_preference,
+                                                        dataloader_option=dataloader_option,
+                                                        shard_option=shard_option)
     if load_solver_solution:
         solution = torch.load(solution_path)
     else:
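The load branch above is just a torch.load of a previously saved solution, so an expensive solver run can be cached and reused across jobs. A rough sketch of the round trip is below; the file path and the placeholder list are arbitrary, and the real solution is whatever per-node strategy assignment the solver produces.

import torch

solution_path = './solver_solution.pt'    # arbitrary example path

# First run: the solver searches for a strategy assignment, which the patched
# API persists when save_solver_solution=True. A placeholder stands in here.
solution = [0, 3, 1, 2]
torch.save(solution, solution_path)

# Later runs: pass load_solver_solution=True and solution_path to skip the
# search; this mirrors the torch.load call in the branch above.
solution = torch.load(solution_path)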
@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
                     logical_mesh_id: torch.Tensor = None,
+                    solver_preference: str = 'standard',
+                    dataloader_option: str = 'replicated',
+                    shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
             mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
             generated by search_best_logical_mesh_shape function.
         logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
     rst_to_unpack = initialize_model(model,
                                      meta_args,
                                      device_mesh,
+                                     solver_preference=solver_preference,
+                                     dataloader_option=dataloader_option,
                                      save_solver_solution=save_solver_solution,
                                      load_solver_solution=load_solver_solution,
                                      solution_path=solver_solution_path,
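At the top level, autoparallelize forwards solver_preference and dataloader_option into its initialize_model call, as the lines above show, and its own signature accepts shard_option as well. A hedged end-to-end sketch follows; the earlier part of the autoparallelize signature is not visible in this hunk, so passing meta_args as a keyword is an assumption, as are the import path, the toy model, and the tensor shapes, and the snippet presumes a launched torch.distributed environment with ColossalAI installed.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize  # assumed path


class ToyMLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = ToyMLP().cuda()
meta_args = {'x': torch.rand(16, 8).to('meta')}  # key matches the forward argument

# The new options default to 'standard' / 'replicated' / 'standard';
# shard_option is overridden here to exercise the feature added by this commit.
parallel_model = autoparallelize(model,
                                 meta_args=meta_args,
                                 solver_preference='standard',
                                 dataloader_option='distributed',
                                 shard_option='full_shard')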