fix precommit
@@ -3,22 +3,12 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.models.cohere.modeling_cohere import (
-    CohereForCausalLM,
-    CohereModel,
-    StaticCache,
-    repeat_kv,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
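For context, `repeat_kv`, kept in the consolidated import above, duplicates key/value heads so grouped-query attention can be computed with an ordinary multi-head matmul. A minimal sketch of the idea, not necessarily the library's exact code:

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim) by repeating each KV head."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast across it, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```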
@@ -343,10 +333,9 @@ class CommandPipelineForwards:
         hidden_states = outputs.get("hidden_states")
         return {"hidden_states": hidden_states}
 
 
 def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
-    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb
-    from transformers.models.cohere.modeling_cohere import repeat_kv
+    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv
 
     def forward(
         self: CohereAttention,
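`apply_rotary_pos_emb`, folded into the same import line here, rotates query/key vectors by position-dependent angles before attention scores are computed. A rough sketch of the common LLaMA-style formulation; Cohere's exact variant may interleave dimensions differently, so treat this as illustrative only:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension, negating the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_sketch(q, k, cos, sin):
    # Each pair of head dimensions is rotated by an angle that grows with position,
    # encoding relative offsets directly into the q/k dot product.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```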
@@ -728,7 +717,6 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
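This last hunk only drops a stray blank line; the surrounding code is the standard attention epilogue, where per-head outputs are merged back into `hidden_size` and passed through the output projection. A self-contained sketch with hypothetical shapes (`bsz`, `q_len`, etc. are illustrative, not the file's values):

```python
import torch
from torch import nn

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
hidden_size = num_heads * head_dim
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
attn_output = torch.randn(bsz, num_heads, q_len, head_dim).transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, hidden_size)  # merge heads
attn_output = o_proj(attn_output)  # project back to the residual stream
```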