Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 ELEMENTWISE_FUNC_OP = [
-    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+    torch.add,
+    operator.add,
+    torch.abs,
+    torch.cos,
+    torch.exp,
+    torch.mul,
+    operator.mul,
+    operator.floordiv,
+    operator.truediv,
+    operator.neg,
+    torch.multiply,
+    torch.nn.functional.relu,
+    torch.nn.functional.dropout,
 ]
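For context on the whitelist above: in torch.fx, a `call_function` node's `target` is the traced Python callable itself, and `x + y` traces to `operator.add` rather than `torch.add`, which is why both variants appear in `ELEMENTWISE_FUNC_OP`. A minimal sketch of that membership check (the `Toy` module and the shortened `WHITELIST` are illustrative, not code from the repo):

```python
# Sketch only: call_function targets are the function objects themselves,
# so membership tests against the elementwise whitelist work directly.
import operator

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

WHITELIST = [torch.add, operator.add, F.relu, F.dropout]  # shortened stand-in for ELEMENTWISE_FUNC_OP


class Toy(torch.nn.Module):
    def forward(self, x):
        # `x + 1.0` traces to operator.add, not torch.add
        return F.dropout(F.relu(x + 1.0))


for node in symbolic_trace(Toy()).graph.nodes:
    if node.op == "call_function":
        print(node.target, node.target in WHITELIST)
```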
@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
         # traverse the graph to look for consecutive linear layers
         is_linear_module = False
 
-        if node.op == 'call_module':
+        if node.op == "call_module":
             # look for the linear layer
             module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(module, nn.Linear):
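The `call_module` branch above resolves an FX node back to the live submodule before the `isinstance(module, nn.Linear)` test. A small sketch of that resolution, assuming a toy two-layer MLP (`TinyMLP` is invented for illustration):

```python
# Sketch only: for call_module nodes, node.target is the submodule's qualified
# name; the owning GraphModule maps it back to the actual nn.Module instance.
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense_1 = nn.Linear(8, 32)
        self.act = nn.ReLU()
        self.dense_2 = nn.Linear(32, 8)

    def forward(self, x):
        return self.dense_2(self.act(self.dense_1(x)))


traced = symbolic_trace(TinyMLP())
for node in traced.graph.nodes:
    if node.op == "call_module":
        module = node.graph.owning_module.get_submodule(node.target)
        print(node.target, type(module).__name__, isinstance(module, nn.Linear))
```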
@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
                     # it means the first linear has been found and the current module
                     # is the second linear
                     # set the current linear module to be row-sharded
-                    annotation_record['row'] = module
+                    annotation_record["row"] = module
 
                     for shard_type, module in annotation_record.items():
                         # add row sharding spec
-                        if shard_type == 'row':
+                        if shard_type == "row":
                             dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
                             comp_spec = ComputeSpec(ComputePattern.TP1D)
-                            setattr(module.weight, 'pg', process_group)
-                            setattr(module.weight, 'dist_spec', dist_spec)
-                            setattr(module.weight, 'comp_spec', comp_spec)
-                        elif shard_type == 'col':
+                            setattr(module.weight, "pg", process_group)
+                            setattr(module.weight, "dist_spec", dist_spec)
+                            setattr(module.weight, "comp_spec", comp_spec)
+                        elif shard_type == "col":
                             weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
                             weight_comp_spec.output_replicate = False
-                            setattr(module.weight, 'pg', process_group)
-                            setattr(module.weight, 'dist_spec', weight_dist_spec)
-                            setattr(module.weight, 'comp_spec', weight_comp_spec)
+                            setattr(module.weight, "pg", process_group)
+                            setattr(module.weight, "dist_spec", weight_dist_spec)
+                            setattr(module.weight, "comp_spec", weight_comp_spec)
 
                             if module.bias is not None:
                                 bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                                 bias_comp_spec.output_replicate = False
-                                setattr(module.bias, 'pg', process_group)
-                                setattr(module.bias, 'dist_spec', bias_dist_spec)
-                                setattr(module.bias, 'comp_spec', bias_comp_spec)
+                                setattr(module.bias, "pg", process_group)
+                                setattr(module.bias, "dist_spec", bias_dist_spec)
+                                setattr(module.bias, "comp_spec", bias_comp_spec)
                     start_tracking = False
                     annotation_record.clear()
                 else:
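The annotations above implement the usual 1D tensor-parallel MLP scheme: the first linear is column-sharded (weight split on dim 0, the output features, with `output_replicate = False` so its activation stays sharded) and the second is row-sharded (weight split on dim -1, the input features), so a single reduction at the end recovers the full output. A plain-PyTorch numerical sketch of that identity, with made-up shapes and no ColossalAI involved:

```python
# Sketch only: why col-sharding linear 1 and row-sharding linear 2 is exact.
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)      # (batch, hidden)
w1 = torch.randn(32, 8)    # first nn.Linear weight:  (out=4*hidden, in=hidden)
w2 = torch.randn(8, 32)    # second nn.Linear weight: (out=hidden, in=4*hidden)

# unsharded reference MLP
ref = torch.relu(x @ w1.t()) @ w2.t()

# "col": split w1 along dim 0 (output features); "row": split w2 along dim -1 (input features)
w1_shards = w1.chunk(world_size, dim=0)
w2_shards = w2.chunk(world_size, dim=-1)

partials = []
for rank in range(world_size):
    h = torch.relu(x @ w1_shards[rank].t())   # activation stays sharded (output_replicate=False)
    partials.append(h @ w2_shards[rank].t())  # each rank's partial output

out = sum(partials)  # the all-reduce step in a real distributed run
print(torch.allclose(ref, out, atol=1e-5))  # True
```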
@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
                     # it means the current layer is the first linear
                     # set the linear layer to be col-sharded
                     start_tracking = True
-                    annotation_record['col'] = module
+                    annotation_record["col"] = module
 
         if start_tracking and not is_linear_module:
             # check against the white list
             # if non-element wise op is found, we reset the tracking
-            if node.op == 'call_module':
+            if node.op == "call_module":
                 module = node.graph.owning_module.get_submodule(node.target)
                 if module.__class__ not in ELEMENTWISE_MODULE_OP:
                     start_tracking = False
-            elif node.op == 'call_function' or node.op == 'call_method':
+            elif node.op == "call_function" or node.op == "call_method":
                 if node.target not in ELEMENTWISE_FUNC_OP:
                     start_tracking = False
         elif len(node.users.keys()) > 1:
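Putting the hunks together, the pass is a small state machine over the FX graph: remember the first `nn.Linear` as the column-sharded candidate, keep tracking while only whitelisted elementwise ops appear, and commit the next `nn.Linear` as the row-sharded partner. A simplified, self-contained re-implementation of just that detection (it omits the ColossalAI annotation and the multi-consumer reset in the `elif len(node.users.keys()) > 1` branch; `Block` and `find_mlp_pairs` are invented names):

```python
# Sketch only: detect consecutive Linear pairs separated by elementwise ops.
import operator

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

ELEMENTWISE_MODULE_OP = [nn.Dropout, nn.ReLU]
ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.nn.functional.relu, torch.nn.functional.dropout]


def find_mlp_pairs(graph_module):
    pairs, first, tracking = [], None, False
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            module = graph_module.get_submodule(node.target)
            if isinstance(module, nn.Linear):
                if tracking:
                    pairs.append((first, node.target))  # first -> "col", this one -> "row"
                    tracking, first = False, None
                else:
                    tracking, first = True, node.target
            elif tracking and module.__class__ not in ELEMENTWISE_MODULE_OP:
                tracking = False  # non-elementwise module breaks the chain
        elif tracking and node.op in ("call_function", "call_method"):
            if node.target not in ELEMENTWISE_FUNC_OP:
                tracking = False  # non-elementwise function breaks the chain
    return pairs


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(8, 32), nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))


print(find_mlp_pairs(symbolic_trace(Block())))  # [('fc1', 'fc2')]
```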