fix precommit

GuangyaoZhang
2024-06-14 08:09:24 +00:00
parent 98da648a4a
commit fe2e74c03a
7 changed files with 35 additions and 86 deletions


@@ -3,22 +3,12 @@ import warnings
 from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.models.cohere.modeling_cohere import (
-    CohereForCausalLM,
-    CohereModel,
-    StaticCache,
-    repeat_kv,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv
 from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -343,10 +333,9 @@ class CommandPipelineForwards:
         hidden_states = outputs.get("hidden_states")
         return {"hidden_states": hidden_states}
 def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
-    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb
-    from transformers.models.cohere.modeling_cohere import repeat_kv
+    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv
     def forward(
         self: CohereAttention,
@@ -728,7 +717,6 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
         if not output_attentions: