mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-02 21:48:15 +00:00

[format] applied code formatting on changed files in pull request 4441 ()

Co-authored-by: github-actions <github-actions@github.com>
github-actions[bot] 2023-08-16 10:47:23 +08:00 committed by GitHub
parent 5d4efdf58f
commit d20dceb9a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 41 deletions
colossalai/shardformer/policies
tests/test_pipeline


@@ -40,7 +40,7 @@ class ViTPolicy(Policy):
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
-])
+])
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
@@ -48,45 +48,45 @@ class ViTPolicy(Policy):
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
-param_replacement=[],
-sub_module_replacement=[
-SubModuleReplacementDescription(
-suffix="attention.attention.query",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="attention.attention.key",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="attention.attention.value",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="attention.attention.dropout",
-target_module=col_nn.DropoutForParallelInput,
-),
-SubModuleReplacementDescription(
-suffix="attention.output.dense",
-target_module=col_nn.Linear1D_Row,
-),
-SubModuleReplacementDescription(
-suffix="attention.output.dropout",
-target_module=col_nn.DropoutForReplicatedInput,
-),
-SubModuleReplacementDescription(
-suffix="intermediate.dense",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="output.dense",
-target_module=col_nn.Linear1D_Row,
-),
-SubModuleReplacementDescription(
-suffix="output.dropout",
-target_module=col_nn.DropoutForReplicatedInput,
-),
-])
+param_replacement=[],
+sub_module_replacement=[
+SubModuleReplacementDescription(
+suffix="attention.attention.query",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="attention.attention.key",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="attention.attention.value",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="attention.attention.dropout",
+target_module=col_nn.DropoutForParallelInput,
+),
+SubModuleReplacementDescription(
+suffix="attention.output.dense",
+target_module=col_nn.Linear1D_Row,
+),
+SubModuleReplacementDescription(
+suffix="attention.output.dropout",
+target_module=col_nn.DropoutForReplicatedInput,
+),
+SubModuleReplacementDescription(
+suffix="intermediate.dense",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="output.dense",
+target_module=col_nn.Linear1D_Row,
+),
+SubModuleReplacementDescription(
+suffix="output.dropout",
+target_module=col_nn.DropoutForReplicatedInput,
+),
+])
# use flash attention
if self.shard_config.enable_flash_attention:
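
The ViT policy hunk above is a whitespace-only reformat; the substance it touches is that every attention and MLP sub-module gets a tensor-parallel replacement: query/key/value and the intermediate projection become column-parallel Linear1D_Col, the output projections become row-parallel Linear1D_Row, and the dropouts switch to their parallel/replicated variants. A minimal, hypothetical sketch of how such a replacement list is declared follows; the class and module names come from the diff, while the import paths and the helper function are assumptions.

# Hypothetical sketch, not the ColossalAI source; import paths are assumptions.
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

def vit_attention_replacements():
    # Q/K/V are split column-wise across tensor-parallel ranks; the attention
    # output projection is split row-wise so the partial results are reduced back.
    return [
        SubModuleReplacementDescription(suffix="attention.attention.query", target_module=col_nn.Linear1D_Col),
        SubModuleReplacementDescription(suffix="attention.attention.key", target_module=col_nn.Linear1D_Col),
        SubModuleReplacementDescription(suffix="attention.attention.value", target_module=col_nn.Linear1D_Col),
        SubModuleReplacementDescription(suffix="attention.output.dense", target_module=col_nn.Linear1D_Row),
    ]

The column-then-row pairing keeps the sharded attention block numerically equivalent to the unsharded module while requiring only a single all-reduce on its output.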


@@ -21,7 +21,7 @@ def check_stage_manager():
1: [0, 1],
2: [2, 3],
3: [2, 3],
-}
+}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
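
The test hunk above drives the pipeline stage manager from a 2-D process-group mesh. A rough sketch of that setup, assuming a 4-rank job split into 2 data-parallel groups and 2 pipeline stages, is shown below; the constant values and import paths are assumptions, and only the two constructor calls appear in the diff.

# Hypothetical sketch of the setup exercised by check_stage_manager();
# constant values and import paths are assumptions.
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager

DP_SIZE, PP_SIZE = 2, 2   # assumed: 4 ranks laid out as 2 DP groups x 2 pipeline stages
PP_DIM = 1                # assumed: the pipeline axis is the second mesh dimension

def build_stage_manager():
    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)            # arrange ranks on a DP x PP grid
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)   # derive this rank's pipeline stage
    return stage_manager, dist.get_rank()

Each rank's stage index then follows from its coordinate along the pipeline dimension of the mesh.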