mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 13:00:52 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -37,10 +37,9 @@ Note that the license is subject to update to a more comprehensive version. For
 import copy
 import math
 import re
 import sys
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -80,7 +79,6 @@ def default_init(cls, *args, **kwargs):
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
-
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
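For context, `InvalidScoreLogitsProcessor` follows the standard `transformers.LogitsProcessor` protocol: it is applied to the raw scores before sampling and zeroes them whenever a NaN/Inf appears. A minimal standalone sketch of that guard pattern, assuming only that `torch` and `transformers` are installed (the `NanGuard` name is invented for illustration):

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    # Same NaN/Inf guard pattern as above, runnable outside the model.
    class NanGuard(LogitsProcessor):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
            return scores

    procs = LogitsProcessorList([NanGuard()])
    scores = procs(torch.tensor([[0]]), torch.tensor([[0.5, float("nan"), 1.0]]))
    print(scores)  # tensor([[0., 0., 0.]])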
@@ -100,7 +98,7 @@ class PrefixEncoder(torch.nn.Module):
         self.prefix_projection = config.prefix_projection
         if self.prefix_projection:
             # Use a two-layer MLP to encode the prefix
-            kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
             self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(kv_size, config.hidden_size),
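The `kv_size` computed above is the flattened size of the per-layer key/value prefix: `num_layers * kv_channels * multi_query_group_num * 2`, where the final factor 2 covers keys and values. A worked example with hypothetical ChatGLM2-6B-like numbers (the config values are assumptions, not read from this diff):

    # Hypothetical config values, for illustration only.
    num_layers, kv_channels, multi_query_group_num = 28, 128, 2

    # Each prefix token carries a key and a value vector per layer and KV group.
    kv_size = num_layers * kv_channels * multi_query_group_num * 2
    print(kv_size)  # 28 * 128 * 2 * 2 = 14336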
@@ -151,10 +149,9 @@ def split_tensor_along_last_dim(
 
 
 class RotaryEmbedding(nn.Module):
-
     def __init__(self, dim, original_impl=False, device=None, dtype=None):
         super().__init__()
-        inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.dim = dim
         self.original_impl = original_impl
@@ -174,7 +171,7 @@ class RotaryEmbedding(nn.Module):
         https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
         """
         # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
 
         # Create position indexes `[0, 1, ..., seq_len - 1]`
         seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
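The `theta` line implements the commented formula: one frequency `theta_i = base^(-2i/d)` per pair of channels, which is then combined with the position indexes into a cos/sin cache. A self-contained sketch of that computation (sizes are illustrative only):

    import torch

    # Rotary frequency table: theta_i = base^(-2i/d), then angles per position.
    base, n_elem, seq_len = 10000, 8, 4
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float) / n_elem))
    seq_idx = torch.arange(seq_len, dtype=torch.float)
    idx_theta = torch.outer(seq_idx, theta)          # [seq_len, n_elem // 2]
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
    print(cache.shape)                               # torch.Size([4, 4, 2])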
@@ -220,7 +217,6 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
         self.elementwise_affine = True
@@ -236,7 +232,6 @@ class RMSNorm(torch.nn.Module):
 
 
 class CoreAttention(torch.nn.Module):
-
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
 
@@ -250,7 +245,7 @@ class CoreAttention(torch.nn.Module):
 
         # Per attention head and per partition values.
         self.hidden_size_per_partition = projection_size
-        self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads)
+        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
         self.num_attention_heads_per_partition = config.num_attention_heads
 
         coeff = None
@@ -267,15 +262,15 @@ class CoreAttention(torch.nn.Module):
         if pytorch_major_version >= 2:
             query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 is_causal=True)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer, key_layer, value_layer, is_causal=True
+                )
             else:
                 if attention_mask is not None:
                     attention_mask = ~attention_mask
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 attention_mask)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer, key_layer, value_layer, attention_mask
+                )
             context_layer = context_layer.permute(2, 0, 1, 3)
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.reshape(*new_context_layer_shape)
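On the PyTorch >= 2 path the model calls the fused `scaled_dot_product_attention` kernel, which computes `softmax(QK^T / sqrt(d))V` in one op; `is_causal=True` stands in for an explicit mask in the square self-attention case. A standalone sketch with random tensors (shapes are illustrative only):

    import torch
    import torch.nn.functional as F

    # SDPA expects [batch, heads, seq, head_dim], hence the permute(1, 2, 0, 3)
    # above from the model's [seq, batch, heads, head_dim] layout.
    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 16, 64])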
@@ -307,8 +302,8 @@ class CoreAttention(torch.nn.Module):
             # Raw attention scores. [b * np, sq, sk]
             matmul_result = torch.baddbmm(
                 matmul_input_buffer,
-                query_layer.transpose(0, 1),    # [b * np, sq, hn]
-                key_layer.transpose(0, 1).transpose(1, 2),    # [b * np, hn, sk]
+                query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                 beta=0.0,
                 alpha=(1.0 / self.norm_factor),
             )
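The pre-2.0 fallback builds raw scores with `torch.baddbmm`: with `beta=0.0` the preallocated buffer is ignored (only its shape and dtype matter), and `alpha` applies the `1/sqrt(d)` scaling during the batched matmul. A minimal sketch (names and sizes are invented):

    import torch

    # scores = alpha * (q @ k_t); beta=0.0 means buf's contents are ignored.
    b_np, sq, sk, hn = 16, 8, 8, 64
    q = torch.randn(b_np, sq, hn)      # [b * np, sq, hn]
    k_t = torch.randn(b_np, hn, sk)    # [b * np, hn, sk]
    buf = torch.empty(b_np, sq, sk)    # stand-in for matmul_input_buffer
    scores = torch.baddbmm(buf, q, k_t, beta=0.0, alpha=1.0 / (hn ** 0.5))
    print(scores.shape)                # torch.Size([16, 8, 8])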
@@ -325,7 +320,7 @@ class CoreAttention(torch.nn.Module):
             attention_scores = attention_scores.float()
         if self.coeff is not None:
             attention_scores = attention_scores * self.coeff
-        if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]):
+        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
             attention_mask = torch.ones(
                 output_size[0],
                 1,
@@ -388,15 +383,16 @@ class SelfAttention(torch.nn.Module):
 
         self.projection_size = config.kv_channels * config.num_attention_heads
         # Per attention head and per partition values.
-        self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads)
+        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
         self.num_attention_heads_per_partition = config.num_attention_heads
 
         self.multi_query_attention = config.multi_query_attention
         self.qkv_hidden_size = 3 * self.projection_size
         if self.multi_query_attention:
             self.num_multi_query_groups_per_partition = config.multi_query_group_num
-            self.qkv_hidden_size = (self.projection_size +
-                                    2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
+            self.qkv_hidden_size = (
+                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+            )
         self.query_key_value = nn.Linear(
             config.hidden_size,
             self.qkv_hidden_size,
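Under multi-query attention the fused QKV projection shrinks, because only the queries keep all heads while keys and values share `multi_query_group_num` heads. Worked numbers with hypothetical ChatGLM2-6B-like config values (assumed, not read from this diff):

    # Hypothetical config values, for illustration only.
    kv_channels, num_attention_heads, multi_query_group_num = 128, 32, 2

    projection_size = kv_channels * num_attention_heads  # 4096, all query heads
    head_dim = projection_size // num_attention_heads    # 128
    qkv_mha = 3 * projection_size                        # 12288 without MQA
    qkv_mqa = projection_size + 2 * head_dim * multi_query_group_num
    print(qkv_mha, qkv_mqa)                              # 12288 4608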
@@ -459,18 +455,27 @@ class SelfAttention(torch.nn.Module):
                 ],
                 dim=-1,
             )
-            query_layer = query_layer.view(query_layer.size()[:-1] + (
-                self.num_attention_heads_per_partition,
-                self.hidden_size_per_attention_head,
-            ))
-            key_layer = key_layer.view(key_layer.size()[:-1] + (
-                self.num_multi_query_groups_per_partition,
-                self.hidden_size_per_attention_head,
-            ))
-            value_layer = value_layer.view(value_layer.size()[:-1] + (
-                self.num_multi_query_groups_per_partition,
-                self.hidden_size_per_attention_head,
-            ))
+            query_layer = query_layer.view(
+                query_layer.size()[:-1]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            key_layer = key_layer.view(
+                key_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
+            value_layer = value_layer.view(
+                value_layer.size()[:-1]
+                + (
+                    self.num_multi_query_groups_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
         else:
             new_tensor_shape = mixed_x_layer.size()[:-1] + (
                 self.num_attention_heads_per_partition,
@@ -504,10 +509,13 @@ class SelfAttention(torch.nn.Module):
                 self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                 -1,
             )
-            key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (
-                self.num_attention_heads_per_partition,
-                self.hidden_size_per_attention_head,
-            ))
+            key_layer = key_layer.contiguous().view(
+                key_layer.size()[:2]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
             value_layer = value_layer.unsqueeze(-2)
             value_layer = value_layer.expand(
                 -1,
@@ -516,10 +524,13 @@ class SelfAttention(torch.nn.Module):
                 self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                 -1,
             )
-            value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (
-                self.num_attention_heads_per_partition,
-                self.hidden_size_per_attention_head,
-            ))
+            value_layer = value_layer.contiguous().view(
+                value_layer.size()[:2]
+                + (
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                )
+            )
 
         # ==================================
         # core attention computation
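The `unsqueeze`/`expand`/`contiguous().view` sequence above replicates each of the few KV heads across the query heads that share it; `expand` is a zero-copy view, so memory is only materialized by the final `contiguous()`. A shape-only sketch:

    import torch

    # Broadcast 2 KV groups up to 32 query heads, as the code above does.
    sq, b, groups, heads, hd = 16, 1, 2, 32, 128
    key = torch.randn(sq, b, groups, hd)
    key = key.unsqueeze(-2)                            # [sq, b, groups, 1, hd]
    key = key.expand(-1, -1, -1, heads // groups, -1)  # still a view, no copy
    key = key.contiguous().view(sq, b, heads, hd)      # materialize [16, 1, 32, 128]
    print(key.shape)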
@@ -600,7 +611,7 @@ class GLMBlock(torch.nn.Module):
         super(GLMBlock, self).__init__()
         self.layer_number = layer_number
 
-        self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm)
+        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
@@ -724,7 +735,8 @@ class GLMTransformer(torch.nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False
 
         all_self_attentions = None
@@ -806,7 +818,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
     def get_position_ids(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
-        position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -843,7 +855,6 @@ class Embedding(torch.nn.Module):
 
 
 class ChatGLMModel(ChatGLMPreTrainedModel):
-
     def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
         super().__init__(config)
         if empty_init:
@@ -860,8 +871,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
-        rotary_dim = (config.hidden_size //
-                      config.num_attention_heads if config.kv_channels is None else config.kv_channels)
+        rotary_dim = (
+            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
+        )
 
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
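`rotary_dim` falls back to `hidden_size // num_attention_heads` only when `kv_channels` is unset, and `RotaryEmbedding` receives `rotary_dim // 2` because only half of each head dimension is rotated. With hypothetical ChatGLM2-6B-like numbers (assumed for illustration):

    # Hypothetical config values, for illustration only.
    hidden_size, num_attention_heads, kv_channels = 4096, 32, 128

    rotary_dim = hidden_size // num_attention_heads if kv_channels is None else kv_channels
    print(rotary_dim, rotary_dim // 2)  # 128 64 -> half the head dim is rotated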
@@ -891,7 +903,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return self.embedding.word_embeddings
 
     def get_prompt(self, batch_size, device, dtype=torch.half):
-        prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device))
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
         past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
         past_key_values = past_key_values.view(
             batch_size,
@@ -917,10 +929,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         batch_size, seq_length = input_ids.shape
 
@@ -966,12 +979,16 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         if not return_dict:
-            return tuple(v for v in [
-                hidden_states,
-                presents,
-                all_hidden_states,
-                all_self_attentions,
-            ] if v is not None)
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -988,7 +1005,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-
     def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
         super().__init__(config)
 
@@ -1009,7 +1025,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format)
+            outputs, standardize_cache_format=standardize_cache_format
+        )
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -1067,7 +1084,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return_last_logit: Optional[bool] = False,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
             input_ids=input_ids,
@@ -1113,8 +1130,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
-                       beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+    def _reorder_cache(
+        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
@@ -1122,10 +1140,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         Output shares the same memory storage as `past`.
         """
-        return tuple((
-            layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
-            layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
-        ) for layer_past in past)
+        return tuple(
+            (
+                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+            )
+            for layer_past in past
+        )
 
     def process_response(self, response):
         response = response.strip()
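`index_select` along dim 1 is what reorders the cache here: ChatGLM stores `past_key_values` entries as `[seq, batch, heads, head_dim]`, so dim 1 is the beam dimension. A small sketch of reordering one cached layer (shapes are illustrative):

    import torch

    # One layer's cached keys: [seq, batch(=beams), heads, head_dim].
    layer_key = torch.randn(5, 4, 2, 8)
    beam_idx = torch.tensor([2, 2, 0, 3])  # which beams survived this step
    reordered = layer_key.index_select(1, beam_idx)
    print(reordered.shape)  # torch.Size([5, 4, 2, 8]); rows follow beam_idx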
@@ -1180,7 +1201,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         }
         inputs = self.build_inputs(tokenizer, query, history=history)
         outputs = self.generate(**inputs, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = tokenizer.decode(outputs)
         response = self.process_response(response)
         history = history + [(query, response)]
@@ -1227,14 +1248,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
             inputs["attention_mask"] = attention_mask
         for outputs in self.stream_generate(
-                **inputs,
-                past_key_values=past_key_values,
-                return_past_key_values=return_past_key_values,
-                **gen_kwargs,
+            **inputs,
+            past_key_values=past_key_values,
+            return_past_key_values=return_past_key_values,
+            **gen_kwargs,
         ):
             if return_past_key_values:
                 outputs, past_key_values = outputs
-            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
             response = tokenizer.decode(outputs)
             if response and response[-1] != "�":
                 response = self.process_response(response)
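The `"�"` comparison guards streaming output: decoding at a token boundary that splits a multi-byte UTF-8 character yields U+FFFD (the replacement character), so the loop waits for the next token before yielding a response. A tiny illustration of why:

    # A multi-byte UTF-8 character decoded halfway gives U+FFFD until the
    # remaining bytes arrive.
    data = "中".encode("utf-8")                         # 3 bytes: e4 b8 ad
    print(data[:2].decode("utf-8", errors="replace"))  # '�' (incomplete)
    print(data.decode("utf-8"))                        # '中' (complete)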
@@ -1269,7 +1290,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
 
-        has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None)
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
                 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
@@ -1278,7 +1299,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length)
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warn(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -1289,14 +1310,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             )
 
         if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids")
-            logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                           f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                           " increasing `max_new_tokens`.")
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
 
         # 2. Set generation parameters if not already defined
-        logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList())
-        stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList())
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
@@ -1306,8 +1329,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             logits_processor=logits_processor,
         )
 
-        stopping_criteria = self._get_stopping_criteria(generation_config=generation_config,
-                                                        stopping_criteria=stopping_criteria)
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
         logits_warper = self._get_logits_warper(generation_config)
 
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@@ -1337,9 +1361,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs,
-                                                                    is_encoder_decoder=self.config.is_encoder_decoder)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
             unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
             if return_past_key_values:
                 yield input_ids, outputs.past_key_values