[Shardformer] Add an assertion that the number of attention heads is divisible by tp_size (#5670)

* Add an assertion that the number of attention heads is divisible by tp_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Wang Binluo
2024-04-29 05:47:47 -05:00
committed by GitHub
parent 6af6d6fc9f
commit d3f34ee8cc
13 changed files with 48 additions and 0 deletions

View File

@@ -44,6 +44,9 @@ class ViTPolicy(Policy):
warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[ViTEmbeddings] = ModulePolicyDescription(
attribute_replacement={},
param_replacement=[],