change 'xxx if xxx else None' to 'xxx or None'

GuangyaoZhang
2024-06-18 03:32:42 +00:00
parent a83a2336e8
commit d84d68601a
7 changed files with 19 additions and 18 deletions
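For context, a minimal sketch of the refactor named in the commit title (the variable names are hypothetical, not taken from this diff); the two forms are equivalent because `or` short-circuits and returns its right operand whenever the left one is falsy:

# Illustrative example of the refactor; names are hypothetical, not from the diff below.
candidates = []  # an empty (falsy) value

# Before: explicit conditional expression
result = candidates if candidates else None

# After: `or` returns the right operand when the left is falsy,
# so this is behaviorally identical
result = candidates or None

assert result is None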


@@ -3,13 +3,18 @@ import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import CohereForCausalLM
@@ -683,4 +689,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
attentions=outputs.attentions,
)
-return forward
+return forward