Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
change 'xxx if xxx else None' to 'xxx or None'
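The simplification in the commit title relies on Python's short-circuiting `or`: for any value `x`, `x if x else None` and `x or None` evaluate to the same result, since both return `x` when it is truthy and `None` when it is falsy. A minimal sketch of the equivalence follows; the function and variable names are illustrative, not taken from the ColossalAI source.

def keep_if_truthy_old(value):
    # Original style: explicit conditional expression.
    return value if value else None

def keep_if_truthy_new(value):
    # Refactored style: `or` returns the left operand when it is truthy,
    # otherwise the right operand (None).
    return value or None

# The two forms agree for truthy and falsy inputs alike.
for sample in ([1, 2], [], None, "mask", 0):
    assert keep_if_truthy_old(sample) == keep_if_truthy_new(sample)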
@@ -3,13 +3,18 @@ import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
     return forward


 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import CohereForCausalLM

@@ -683,4 +689,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )

-    return forward
+    return forward
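Both hunks end in `return forward`, which reflects the structure these files use: each `get_*_forward` helper is a factory that takes a `ShardConfig` (and, for the attention variant, sequence-parallel arguments) and returns a replacement `forward` closed over that configuration, which the shardformer policy then installs on the Hugging Face module. The snippet below is only a sketch of that factory shape under assumed names (`ShardConfig` is reduced to a stand-in dataclass and `get_example_forward` is hypothetical); it is not the ColossalAI implementation.

from dataclasses import dataclass


@dataclass
class ShardConfig:
    # Stand-in for colossalai.shardformer.ShardConfig; only one field is modeled here.
    enable_flash_attention: bool = False


def get_example_forward(shard_config: ShardConfig):
    # Build a forward specialized for this shard configuration, mirroring the
    # `return forward` pattern visible in the diff above.
    def forward(hidden_states):
        # A real replacement forward would run the sharded Cohere layers; this one
        # only tags the output so the closure over `shard_config` is observable.
        mode = "flash" if shard_config.enable_flash_attention else "eager"
        return hidden_states, mode

    return forward


# Usage: the factory is called once with the shard configuration, and the returned
# function is what gets bound in place of the model's original forward.
forward_fn = get_example_forward(ShardConfig(enable_flash_attention=True))
print(forward_fn([0.1, 0.2]))  # ([0.1, 0.2], 'flash')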