[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fixes. (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
Author: flybird11111
Committed by: GitHub
Date: 2023-08-24 15:50:02 +08:00
Parent: e04436a82a
Commit: 3353e55c80
7 changed files with 46 additions and 21 deletions


@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, List, Union
 import torch.nn as nn
@@ -32,6 +33,10 @@ class ViTPolicy(Policy):
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("ViT doesn't support sequence parallelism now; the sequence parallelism flag will be ignored.")
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],
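
For context, the downgrade-and-warn pattern in the diff can be exercised in isolation. The sketch below is a minimal, hypothetical stand-in: the ShardConfig dataclass and ViTPolicy class here are simplified mock-ups, not the real ColossalAI classes, and the policy value is a placeholder string rather than a ModulePolicyDescription.

    import warnings
    from dataclasses import dataclass

    @dataclass
    class ShardConfig:
        # Simplified stand-in for the shard config; only the two flags
        # touched by this commit are modeled.
        enable_sequence_parallelism: bool = False
        enable_tensor_parallelism: bool = True

    class ViTPolicy:
        def __init__(self, shard_config: ShardConfig):
            self.shard_config = shard_config

        def module_policy(self):
            policy = {}
            # Mirrors the diff above: ViT does not support sequence
            # parallelism, so the flag is reset and a warning is emitted
            # instead of raising an error.
            if self.shard_config.enable_sequence_parallelism:
                self.shard_config.enable_sequence_parallelism = False
                warnings.warn("ViT doesn't support sequence parallelism now; "
                              "the sequence parallelism flag will be ignored.")
            if self.shard_config.enable_tensor_parallelism:
                policy["ViTEmbeddings"] = "...ModulePolicyDescription..."
            return policy

    config = ShardConfig(enable_sequence_parallelism=True)
    policy = ViTPolicy(config).module_policy()  # emits a UserWarning
    assert config.enable_sequence_parallelism is False

Resetting the flag on the shared config (rather than only skipping the replacement locally) keeps downstream consumers of the same ShardConfig consistent, which is why the commit mutates the config in place before continuing.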