This commit is contained in:
flybird11111
2024-06-03 11:25:18 +08:00
committed by GitHub
parent 68359ed1e1
commit 3f2be80530

View File

@@ -351,7 +351,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism:
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(