diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index d962057b1..b4a1f4bd8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -506,8 +506,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) + # TODO(review): suspected LoRA incompatibility in this branch — add a LoRA test before narrowing `else` to `elif use_zbv` # enable tp, replace layer to LinearWithGradAccum - else: + elif use_zbv: # add a new item for sequence classification new_item = { LlamaForSequenceClassification: ModulePolicyDescription(