[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -3,7 +3,6 @@ from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
@@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
'''
"""
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
'''
"""
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
'''
def __init__(
self,
module: ColoGraphModule,
sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec],
comm_actions_dict: Dict[int, Dict[str, CommAction]],
):
"""
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
'''
"""
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
@@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module):
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
return self.module(*args,
sharding_spec_convert_dict=self.sharding_spec_dict,
origin_node_sharding_spec_dict=self.origin_spec_dict,
comm_actions_dict=self.comm_actions_dict,
**kwargs)
return self.module(
*args,
sharding_spec_convert_dict=self.sharding_spec_dict,
origin_node_sharding_spec_dict=self.origin_spec_dict,
comm_actions_dict=self.comm_actions_dict,
**kwargs,
)
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
'''
"""
This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.
'''
"""
# TODO: implement this function
pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
"""
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
'''
"""
# TODO: implement this function
pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
'''
def build_strategy_constructor(
graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
):
"""
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
if solver_preference == 'standard':
"""
if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
elif solver_preference == "tp":
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
elif solver_preference == "dp":
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')
raise ValueError(f"Invalid solver_preference: {solver_preference}")
if dataloader_option == 'replicated':
if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
if shard_option == 'standard':
if shard_option == "standard":
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
elif shard_option == "shard":
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')
raise ValueError(f"Invalid shard_option: {shard_option}")
solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
solver_options = SolverOptions(
solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
'''
"""
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
"""
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
def transform_to_sharded_model(gm: ColoGraphModule,
meta_args: Dict,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap: bool = False):
'''
def transform_to_sharded_model(
gm: ColoGraphModule,
meta_args: Dict,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap: bool = False,
):
"""
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
solution,
device_mesh,
strategies_constructor,
overlap=overlap)
"""
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh, strategies_constructor, overlap=overlap
)
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts
def initialize_device_mesh(world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None):
'''
def initialize_device_mesh(
world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
):
"""
This method is used to initialize the device mesh.
Args:
@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
"""
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
device_mesh = DeviceMesh(
physical_mesh_id=physical_mesh,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True,
)
return device_mesh
def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
return_solution: bool = False):
'''
def initialize_model(
model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = "standard",
dataloader_option: str = "replicated",
shard_option: str = "standard",
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
return_solution: bool = False,
):
"""
This method is used to initialize the sharded model which could be used as normal pytorch model.
Args:
@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
"""
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = build_strategy_constructor(
graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
)
if load_solver_solution:
solution = torch.load(solution_path)
else:
@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
overlap)
gm, sharding_spec_dicts = transform_to_sharded_model(
gm, meta_args, solution, device_mesh, strategies_constructor, overlap
)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return
else:
return model_to_return
def autoparallelize(model: nn.Module,
meta_args: Dict[str, torch.Tensor] = None,
data_loader: torch.utils.data.DataLoader = None,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0):
'''
def autoparallelize(
model: nn.Module,
meta_args: Dict[str, torch.Tensor] = None,
data_loader: torch.utils.data.DataLoader = None,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = "standard",
dataloader_option: str = "replicated",
shard_option: str = "standard",
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0,
):
"""
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
logical_mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id)
"""
device_mesh = initialize_device_mesh(
alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
rst_to_unpack = initialize_model(
model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget,
)
if return_solution:
model, solution = rst_to_unpack