[shardformer] fix import (#5788)

This commit is contained in:
Hongxin Liu
2024-06-06 19:09:50 +08:00
committed by GitHub
parent 5ead00ffc5
commit 73e88a5553
2 changed files with 8 additions and 4 deletions

View File

@@ -8,6 +8,10 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)