Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
[Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix typo
* remove typos
* fixes
* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
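For context, the feature's core idea: with sequence parallelism (SP), each rank can compute cross-entropy on its own sequence shard instead of first gathering the full logits. A minimal sketch of the math, assuming labels are already shifted and aligned with each shard and a process group is initialized (illustrative names, not ColossalAI's actual `dist_cross_entropy`):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def split_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, sp_group=None) -> torch.Tensor:
    """Mean cross-entropy over a sequence sharded along dim 1.

    Each rank holds logits of shape [B, S/sp_size, V] and the matching
    (pre-shifted) label shard; two all-reduces reconstruct the global mean.
    """
    loss_sum = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )
    num_valid = (labels != -100).sum()
    # Real implementations use autograd-aware reductions; this only shows the math.
    dist.all_reduce(loss_sum, group=sp_group)   # total loss across all shards
    dist.all_reduce(num_valid, group=sp_group)  # total non-ignored tokens
    return loss_sum / num_valid.clamp(min=1)
```

The real implementation additionally handles tensor-parallel vocab shards, which is why the new `dist_cross_entropy` call sites below pass `self.lm_head.out_features`.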
@@ -25,7 +25,6 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
@@ -58,10 +57,7 @@ class LlamaPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-        # Split output only when computing cross entropy using llama_for_causal_lm_forward
-        # or get_lm_forward_with_dist_cross_entropy
-        # Default to True to avoid bug when calling classification forward from huggingface
-        force_sp_output_gather: bool = True,
+        force_sp_gather: bool = True,  # Set to false only when computing cross entropy
     ):
         logger = logging.get_logger(__name__)

@@ -78,8 +74,9 @@ class LlamaPipelineForwards:

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        disable_pp = stage_manager is None
         # retrieve input_ids and inputs_embeds
-        if stage_manager.is_first_stage():
+        if disable_pp or stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
@@ -88,10 +85,10 @@ class LlamaPipelineForwards:
                 batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
             hidden_states = inputs_embeds
+            device = hidden_states.device
         else:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
@@ -101,8 +98,8 @@ class LlamaPipelineForwards:
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size
-        if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
-            # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
+        # Generating full positions ids for modes that gather sequence before attn
+        if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
             seq_length *= sp_size

         past_seen_tokens = 0
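The `seq_length *= sp_size` line exists because each rank holds only a sequence shard, yet position ids must cover the full sequence that will be gathered before attention. A toy illustration with assumed sizes:

```python
import torch

# Assumed toy sizes: each rank holds a 512-token shard, sp_size = 4.
seq_length, sp_size = 512, 4
seq_length *= sp_size                                 # 2048: the full sequence
position_ids = torch.arange(seq_length).unsqueeze(0)  # [1, 2048], not [1, 512]
```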
@@ -117,7 +114,6 @@ class LlamaPipelineForwards:

         seq_length_with_past = seq_length + past_seen_tokens

-        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if output_attentions:
             logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
@@ -130,14 +126,13 @@ class LlamaPipelineForwards:

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
-        if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
+        no_split_input = disable_pp or not stage_manager.is_first_stage()
+        if no_split_input and sp_mode == "ring_attn":
             _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
         elif shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
-            attn_kwargs = ColoAttention.prepare_attn_kwargs(
+            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 hidden_states.dtype,
                 hidden_states.device,
@@ -146,15 +141,15 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)

-        # Support SP + PP
-        # TODO: support padded casual cu_seqlens across stages
-        if stage_manager.is_first_stage():
+        # Support SP + PP. Later stages have already received the split input.
+        split_input = disable_pp or stage_manager.is_first_stage()
+        if split_input:
             # Ring Attention zigzag batch processing
             if sp_mode == "ring_attn":
                 assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-                if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+                if not attention_mask.bool().all():
                     hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
                         attention_mask, sp_group, hidden_states, position_ids
                     )
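The "zigzag" layout used here balances causal attention across SP ranks: the sequence is cut into 2 * sp_size chunks, and rank i keeps chunks i and 2*sp_size-1-i, pairing a cheap early chunk with an expensive late one. A minimal sketch of the idea (illustrative; the real `split_batch_zigzag` also handles labels and variable-length batches):

```python
import torch

def zigzag_split(batch: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Return rank `sp_rank`'s shard of a [B, S, ...] batch in zigzag order."""
    chunks = batch.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

# Example: with sp_size=2, rank 0 gets chunks (0, 3) and rank 1 gets (1, 2).
x = torch.arange(8).reshape(1, 8)
assert zigzag_split(x, sp_size=2, sp_rank=0).tolist() == [[0, 1, 6, 7]]
```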
@@ -181,8 +176,8 @@ class LlamaPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
+        start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])

-        start_idx, end_idx = stage_index[0], stage_index[1]
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
             num_ckpt_layers = end_idx - start_idx
@@ -228,18 +223,16 @@ class LlamaPipelineForwards:
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

-        if stage_manager.is_last_stage():
+        if disable_pp or stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
-            if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_sp_output(
-                    hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
-                )
+            if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode):  # noqa
+                hidden_states = gather_sp_output(hidden_states, shard_config)

         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
         next_cache = next_decoder_cache if use_cache else None
-        if stage_manager.is_last_stage():
+        if disable_pp or stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
                     v
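`gather_sp_output` now takes only the `shard_config`, which already carries the process group, SP mode, and fp8 flag. For the plain split modes, the gather is an all-gather along the sequence dimension; a rough sketch, ignoring autograd, fp8 communication, and the zigzag reordering that `ring_attn` would need:

```python
import torch
import torch.distributed as dist

def gather_seq_dim(hidden: torch.Tensor, sp_group) -> torch.Tensor:
    """[B, S/sp_size, H] per rank -> [B, S, H] replicated on every rank."""
    world = dist.get_world_size(group=sp_group)
    shards = [torch.empty_like(hidden) for _ in range(world)]
    dist.all_gather(shards, hidden.contiguous(), group=sp_group)
    return torch.cat(shards, dim=1)
```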
@@ -257,7 +250,7 @@ class LlamaPipelineForwards:
                 hidden_states=all_hidden_states,
                 attentions=all_self_attns,
             )
-        # always return dict for imediate stage
+        # always return dict for intermediate stage
         return {"hidden_states": hidden_states}

     @staticmethod
@@ -323,7 +316,7 @@ class LlamaPipelineForwards:
             # Split labels in a zigzag fashion too
             sp_group = shard_config.sequence_parallel_process_group
             if attention_mask.bool().all():
-                labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
+                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
             else:
                 # [B, max_seqlen // sp_size]
                 labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
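Labels must land on the same rank, in the same zigzag order, as the hidden states that predict them; passing `is_label=True` presumably keeps the causal shift consistent across shard boundaries. A hypothetical illustration of that rationale (assumed behavior, not the library's exact code): shift once globally, then split.

```python
import torch

def shift_then_zigzag(labels: torch.Tensor, sp_size: int, sp_rank: int, ignore_index: int = -100) -> torch.Tensor:
    """Shift labels left by one *before* the zigzag split, so each shard's
    labels align with its logits without losing tokens at shard boundaries."""
    shifted = torch.full_like(labels, ignore_index)
    shifted[:, :-1] = labels[:, 1:]  # the label at position t is token t+1
    chunks = shifted.chunk(2 * sp_size, dim=1)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=1)
```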
@@ -345,16 +338,17 @@ class LlamaPipelineForwards:
             hidden_states=hidden_states,
             stage_index=stage_index,
             shard_config=shard_config,
-            force_sp_output_gather=False,
+            force_sp_gather=False,
         )
         past_key_values = None

-        if stage_manager.is_last_stage():
+        disable_pp = stage_manager is None
+        if disable_pp or stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = dist_cross_entropy(
-                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-            )
+            loss = None
+            if labels is not None:
+                loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

             if not return_dict:
                 output = (logits,) + outputs[1:]
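`dist_cross_entropy` also has to handle a tensor-parallel lm_head, where each rank holds only `out_features = vocab_size / tp_size` logit columns; that is why the new call passes `self.lm_head.out_features`. The classic Megatron-style reduction, sketched under the assumption of one contiguous vocab shard per rank (illustrative, not ColossalAI's implementation):

```python
import torch
import torch.distributed as dist

def vocab_parallel_nll(logits: torch.Tensor, target: torch.Tensor, vocab_start: int, tp_group=None) -> torch.Tensor:
    """logits: [N, V/tp_size] local vocab shard; target: [N] global token ids."""
    logits_max = logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)
    logits = logits - logits_max.unsqueeze(-1)     # numerical stability
    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=tp_group)       # global softmax denominator
    vocab_end = vocab_start + logits.size(-1)
    in_range = (target >= vocab_start) & (target < vocab_end)
    local_idx = (target - vocab_start).clamp(0, logits.size(-1) - 1)
    target_logit = logits.gather(1, local_idx.unsqueeze(1)).squeeze(1)
    target_logit = torch.where(in_range, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, group=tp_group)  # exactly one rank contributes
    return (sum_exp.log() - target_logit).mean()   # ignore_index handling omitted
```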
@@ -629,263 +623,3 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None)
         return attn_output, attn_weights, past_key_value

     return forward
-
-
-def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
-    logger = logging.get_logger(__name__)
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        # Split output only when computing cross entropy using llama_for_causal_lm_forward
-        # or get_lm_forward_with_dist_cross_entropy
-        # Default to True to avoid bug when calling classification forward from huggingface
-        force_sp_output_gather: bool = True,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        past_seen_tokens = 0
-        seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
-
-        if cache_position is None:
-            if isinstance(past_key_values, StaticCache):
-                raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        if shard_config.enable_flash_attention:
-            mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
-            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                inputs_embeds.dtype,
-                inputs_embeds.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-                invert=(sp_mode != "ring_attn"),
-            )
-
-        else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
-
-        # Ring Attention zigzag batch processing
-        if sp_mode == "ring_attn":
-            assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-            if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
-                inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
-                    attention_mask, sp_group, inputs_embeds, position_ids
-                )
-            else:
-                inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
-                attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]}  # drop redundant tensors
-
-        elif is_share_sp_tp(sp_mode):
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
-            )
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
-
-        for decoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attn_kwargs,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attn_kwargs,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-        # Cases that don't support parallelizing cross entropy computation along sequence
-        if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
-            hidden_states = gather_sp_output(
-                hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
-            )
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-            )
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    return forward
-
-
-def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
-    from transformers import LlamaForCausalLM
-
-    def forward(
-        self: LlamaForCausalLM,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
-        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
-        >>> prompt = "Hey, are you conscious? Can you talk to me?"
-        >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-        >>> # Generate
-        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-        ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
-            # Special processing: Split labels in a zigzag fashion too
-            sp_group = shard_config.sequence_parallel_process_group
-            if attention_mask.bool().all():
-                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
-            else:
-                # [B, max_seq_len // sp_size]
-                labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            force_sp_output_gather=False,
-        )
-
-        hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    return forward