Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature] Support LLaMA-3 CPT and SFT (#5619)
* support LLaMA-3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Run pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
88  applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py  Normal file
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Helper functions for IO
"""

import json
import os
from typing import Any, Dict, Tuple, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)


def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    # Update booster params states.
    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
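For orientation, here is a minimal sketch of how these helpers could sit in a training loop. The plugin, optimizer, model name, paths and hyperparameters are illustrative assumptions, not part of this commit, and the distributed environment is assumed to be initialized already (e.g. via `colossalai.launch_from_torch`).

```python
# Hypothetical usage sketch: wiring ckpt_io into a training loop (names/paths are assumptions).
import torch
from transformers import LlamaForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint

booster = Booster(plugin=GeminiPlugin())
coordinator = DistCoordinator()

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # assumed checkpoint
optimizer = HybridAdam(model.parameters(), lr=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

# Resume from an earlier run (directory layout produced by save_checkpoint below).
epoch, step, sample_start_index = load_checkpoint(
    "checkpoints/epoch-0_step-100", booster, model, optimizer, lr_scheduler
)

# ... training steps ...

# Persist sharded model/optimizer states, LR scheduler and the running-state JSON.
save_checkpoint(
    save_dir="checkpoints",
    booster=booster,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epoch=epoch,
    step=step + 1,
    batch_size=8,
    coordinator=coordinator,
)
```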
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
)

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

if get_accelerator().name == "cuda":
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
    from flash_attn.ops.rms_norm import rms_norm

    def _prepare_decoder_attention_mask(
        self: LlamaModel,
        attention_mask: torch.BoolTensor,
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> Optional[torch.Tensor]:
        """
        Decoder attention mask
        """
        if past_key_values_length > 0 and attention_mask is not None:
            attention_mask = torch.cat(
                tensors=(
                    torch.full(
                        size=(input_shape[0], past_key_values_length),
                        fill_value=True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ),
                dim=-1,
            )  # (bsz, past_key_values_length + q_len)
        if attention_mask is not None and torch.all(attention_mask):
            return None  # Faster
        return attention_mask

    def attention_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
        """
        if output_attentions:
            logger.warning(
                "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
                "return `None` instead."
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            q_slicing, kv_slicing = (
                dim // self.config.pretraining_tp
                for dim in (
                    self.num_heads * self.head_dim,
                    self.num_key_value_heads * self.head_dim,
                )
            )  # `Tuple[int, int]`
            q_slices, k_slices, v_slices = (
                proj.weight.split(slicing, dim=0)
                for proj, slicing in (
                    (self.q_proj, q_slicing),
                    (self.k_proj, kv_slicing),
                    (self.v_proj, kv_slicing),
                )
            )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
            q, k, v = (
                torch.cat(
                    [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
                    dim=-1,
                )
                for slices in (q_slices, k_slices, v_slices)
            )
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)
        else:
            q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)

        # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
        q, k, v = (
            states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
            for states, num_heads in (
                (q, self.num_heads),
                (k, self.num_key_value_heads),
                (v, self.num_key_value_heads),
            )
        )
        kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
        past_kv_len = 0
        if past_key_value is not None:
            # if `past_key_value` is not None, `kv_len` > `q_len`.
            past_kv_len = past_key_value[0].shape[-2]
            kv_len += past_kv_len

        # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
        cos, sin = self.rotary_emb(v, seq_len=kv_len)
        # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
        q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
        if past_key_value is not None:
            # reuse k, v, self_attention
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        past_key_value = (k, v) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
        v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)

        key_padding_mask = attention_mask
        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
        q, k, v = (states.transpose(1, 2) for states in (q, k, v))

        if past_kv_len > 0:
            q = torch.cat(
                tensors=(
                    torch.full(
                        size=(bsz, past_kv_len, self.num_heads, self.head_dim),
                        fill_value=0.0,
                        dtype=q.dtype,
                        device=q.device,
                    ),
                    q,
                ),
                dim=1,
            )  # (bsz, past_kv_len + q_len, num_heads, head_dim)

        if key_padding_mask is None:
            # (bsz, past_kv_len + q_len, num_heads, head_dim)
            output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
            output = rearrange(
                output, pattern="... h d -> ... (h d)"
            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
        else:
            q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
            kv, _, cu_kv_lens, max_kv_len = unpad_input(
                hidden_states=torch.stack(tensors=(k, v), dim=2),
                attention_mask=key_padding_mask,
            )
            output_unpad = flash_attn_varlen_kvpacked_func(
                q=q,
                kv=kv,
                cu_seqlens_q=cu_q_lens,
                cu_seqlens_k=cu_kv_lens,
                max_seqlen_q=max_q_len,
                max_seqlen_k=max_kv_len,
                dropout_p=0.0,
                softmax_scale=None,
                causal=True,
            )
            output = pad_input(
                hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
                indices=indices,
                batch=bsz,
                seqlen=past_kv_len + q_len,
            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)

        if past_kv_len > 0:
            # Strip off the zero query outputs.
            output = output[:, past_kv_len:, ...]  # (bsz, q_len, num_heads * head_dim)
        output = self.o_proj(output)  # (bsz, q_len, hidden_size)
        return output, None, past_key_value

    def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward function for RMS Norm
        """
        return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.forward = MethodType(attention_forward, module)
            if isinstance(module, LlamaModel):
                module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
            if isinstance(module, LlamaRMSNorm):
                module.forward = MethodType(rms_norm_forward, module)

elif get_accelerator().name == "npu":
    import torch_npu

    class NPULlamaAttention(LlamaAttention):
        use_flash: bool = True

        def __init__(self, config: LlamaConfig):
            super().__init__(config)
            self.setup()

        def setup(self):
            self._softmax_scale = 1 / math.sqrt(self.head_dim)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            if self.config.pretraining_tp > 1:
                key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
                query_slices = self.q_proj.weight.split(
                    (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
                )
                key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
                value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

                query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
                query_states = torch.cat(query_states, dim=-1)

                key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
                key_states = torch.cat(key_states, dim=-1)

                value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
                value_states = torch.cat(value_states, dim=-1)

            else:
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            kv_seq_len = key_states.shape[-2]
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            if past_key_value is not None:
                # reuse k, v, self_attention
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

            past_key_value = (key_states, value_states) if use_cache else None

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            if not self.use_flash:
                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

                if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                        f" {attn_weights.size()}"
                    )

                if attention_mask is not None:
                    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                        raise ValueError(
                            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                        )
                    attn_weights = attn_weights + attention_mask

                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                attn_output = torch.matmul(attn_weights, value_states)
            else:
                attn_output, *_ = torch_npu.npu_fusion_attention(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads,
                    "BNSD",
                    atten_mask=attention_mask.bool(),
                    scale=self._softmax_scale,
                    padding_mask=None,
                    pre_tockens=65535,
                    next_tockens=0,
                    keep_prob=1.0,
                    inner_precise=0,
                )

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            if self.config.pretraining_tp > 1:
                attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
                o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
                attn_output = sum(
                    [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
                )
            else:
                attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None

            return attn_output, attn_weights, past_key_value

    class NPURMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.__class__ = NPULlamaAttention
                module.setup()
            if isinstance(module, LlamaRMSNorm):
                module.__class__ = NPURMSNorm
18  applications/Colossal-LLaMA/colossal_llama/utils/froze.py  Normal file
@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from transformers.models.llama import LlamaForCausalLM


def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze all parameters except embeddings."""
    for name, params in model.named_parameters():
        if "embed_tokens" not in name and "lm_head" not in name:
            params.requires_grad = False
        else:
            params.requires_grad = True


def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    for name, params in model.named_parameters():
        params.requires_grad = True
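A minimal usage sketch of the helpers above for an embedding-only warmup phase before full fine-tuning; the model name is an assumption and the optimizer wiring is omitted.

```python
# Hypothetical usage sketch (model name is an assumption).
from transformers import LlamaForCausalLM

from colossal_llama.utils.froze import freeze_non_embeds_parameters, unfreeze_parameters

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # assumed checkpoint

# Train only `embed_tokens` and `lm_head` (e.g. while the new embeddings warm up).
freeze_non_embeds_parameters(model)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {trainable:,}")

# ... warmup steps on the embeddings ...

# Re-enable gradients everywhere for the main training phase.
unfreeze_parameters(model)
```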
@@ -0,0 +1,72 @@
# Copyright 2023 The Hugging Face team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def unwrap(model):
    if hasattr(model, "module"):
        return model.unwrap()
    else:
        return model


def neftune_post_forward_hook(module, input, output):
    """
    Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
    layers. This method is slightly adapted from the original source code that can be found here:
    https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
    ```python
    model = ...
    model.embed_tokens.neftune_noise_alpha = 0.1
    model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
    ```
    Args:
        module (`torch.nn.Module`):
            The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
            the desired noise alpha value.
        input (`torch.Tensor`):
            The input tensor to the model.
        output (`torch.Tensor`):
            The output tensor of the model (i.e. the embeddings).
    """
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output


def activate_neftune(model, neftune_noise_alpha=0.1):
    r"""
    Activates NEFTune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
    https://arxiv.org/abs/2310.05914
    """
    embeddings = unwrap(model).get_input_embeddings()

    embeddings.neftune_noise_alpha = neftune_noise_alpha
    hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
    neftune_hook_handle = hook_handle

    return model, neftune_hook_handle


def deactivate_neftune(model, neftune_hook_handle):
    """
    Deactivates the NEFTune method. Make sure to call `activate_neftune` first.
    """
    embeddings = unwrap(model).get_input_embeddings()

    neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
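A minimal sketch of how the activate/deactivate pair above could bracket an SFT run. The import path is an assumption (this hunk does not show the file name); the model name and alpha value are illustrative.

```python
# Hypothetical usage sketch (import path, model name and alpha are assumptions).
from transformers import LlamaForCausalLM

from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune  # assumed path

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # assumed checkpoint

# Register the hook: uniform noise is added to input embeddings only while model.training is True.
model, handle = activate_neftune(model, neftune_noise_alpha=0.1)

# ... supervised fine-tuning steps ...

# Remove the hook and the stored alpha before evaluation or checkpointing.
deactivate_neftune(model, handle)
```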
@@ -0,0 +1,252 @@
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import PreTrainedTokenizer
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

logger = logging.get_logger(__name__)


def get_prompt_template(
    input_query: str,
    history: List[Dict] = None,
    roles: list = ["", "Human", "Assistant"],
) -> str:
    """
    Generates a prompt template for chat models based on input and history.

    Args:
        input_query (str): User's current input query.
        history (List[Dict], optional): List of past conversations, each a dict with 'role' and 'message'.
        roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"].

    Returns:
        str: A formatted prompt including the input query and history.
    """
    prompt = ""
    if history is None:
        new_history = []
    else:
        new_history = deepcopy(history)

    new_history.append({"role": roles[1], "message": input_query.strip()})
    new_history.append({"role": roles[2], "message": None})

    for _, item in enumerate(new_history):
        role = item.get("role")
        message = item.get("message")
        if role == roles[0]:
            prompt += f"<s>{message}\n\n"
        else:
            if message:
                prompt += f"{role}: <s>{message}</s>"
            else:
                prompt += f"{role}: <s>"
    return prompt
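To make the template concrete, a single-turn call produces the following string; the trailing `<s>` leaves the Assistant turn open for generation (the query text is illustrative, and it assumes `get_prompt_template` above is in scope).

```python
# Worked example of the template above.
print(get_prompt_template("Who are you?"))
# Human: <s>Who are you?</s>Assistant: <s>
```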
@torch.inference_mode()
def streaming_chat(
    model: Any,
    tokenizer: PreTrainedTokenizer,
    input_query: str,
    history: List[Dict] = None,
    roles: list = ["", "Human", "Assistant"],
    past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None,
    temperature: float = 0.8,
    top_p: float = 0.95,
    top_k: int = 50,
    do_sample: bool = True,
    length_penalty: float = 1.2,
    max_new_tokens: int = 512,
    logits_processor: LogitsProcessorList = None,
    return_past_key_values: bool = False,
    **kwargs,
):
    """
    Streaming generation of chat responses with a given model and tokenizer.

    Args:
        model (Any): The language model to generate responses.
        tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses.
        input_query (str): The current user input to respond to.
        history (List[Dict], optional): A list of past conversations, where each conversation is a dictionary with keys 'role' and 'message'.
        roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"].
        past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding.
        temperature (float): The temperature value for token sampling, defaults to 0.8.
        top_p (float): Nucleus sampling probability threshold, defaults to 0.95.
        top_k (int): Top-K filtering threshold, defaults to 50.
        do_sample (bool): Whether to sample responses, defaults to True.
        length_penalty (float): Penalty for response length, defaults to 1.2.
        max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512.
        logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None.
        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
        **kwargs: Additional keyword arguments for generation.

    Yields:
        Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and
        optionally the updated past key values if `return_past_key_values` is True.

    Ensures padding is on the left side for the tokenizer.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()

    generation_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "length_penalty": length_penalty,
        "use_cache": True,
        **kwargs,
    }

    prompt_str = get_prompt_template(input_query, history=history, roles=roles)

    eos_token_id = [tokenizer.eos_token_id]
    inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device)
    history.append({"role": roles[1], "message": input_query.strip()})
    history.append({"role": roles[2], "message": None})

    for outputs in stream_generate(
        model,
        **inputs,
        past_key_values=past_key_values,
        eos_token_id=eos_token_id,
        return_past_key_values=return_past_key_values,
        **generation_kwargs,
    ):
        if return_past_key_values:
            outputs, past_key_values = outputs

        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
        response = tokenizer.decode(outputs)

        history[-1]["message"] = response.strip()
        if return_past_key_values:
            yield response, history, past_key_values
        else:
            yield response, history
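A minimal sketch of driving the streaming generator above from a console. The import path, model and tokenizer names are assumptions; the only hard requirement taken from the code above is left padding.

```python
# Hypothetical usage sketch (import path, model and tokenizer names are assumptions).
from transformers import AutoTokenizer, LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", padding_side="left")
model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct").cuda().eval()

history, printed = [], 0
for response, history in streaming_chat(model, tokenizer, "What is continued pre-training?", history=history):
    # Each yielded `response` is the full decoded answer so far; print only the new suffix.
    print(response[printed:], end="", flush=True)
    printed = len(response)
print()
```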
@torch.inference_mode()
def stream_generate(
    model: Any,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    return_past_key_values: bool = False,
    **kwargs,
):
    """
    Generates sequences of token ids using the specified model and generation parameters.
    Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py

    Args:
        model (Any): The model used for generating sequences of token ids.
        input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder.
        generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call.
        logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments
            and generation config.
        stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments
            and a generation config.
        prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation.
        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
        **kwargs: Additional parameters for model generation.

    Yields:
        torch.Tensor: The generated token IDs, updated after each generation step.
        Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True.
    """
    input_ids_len = input_ids.size(1)

    if generation_config is None:
        generation_config = model.generation_config
    generation_config = deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)

    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

    if generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_len

    if input_ids_len >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    # prepare distribution pre_processing samplers
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_len,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    # prepare stopping criteria
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )

    logits_warper = model._get_logits_warper(generation_config)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None

    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        # NOTE: this is correct only in left padding mode
        # pre-process distribution
        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )

        if return_past_key_values:
            yield input_ids, outputs.past_key_values
        else:
            yield input_ids
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
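The `unfinished_sequences` bookkeeping above is the standard Hugging Face trick for tracking, per batch element, whether any EOS id has already been produced. A tiny self-contained illustration (the token ids are made up):

```python
# Toy demo of the EOS-tracking update used in stream_generate (ids are illustrative).
import torch

eos_token_id_tensor = torch.tensor([2, 128009])   # two possible EOS ids
next_tokens = torch.tensor([5, 2, 7])             # newly sampled token per batch element
unfinished = torch.ones(3, dtype=torch.long)

# tile -> (num_eos, bsz); ne -> True where the token differs from that EOS id;
# prod over the EOS dim is 0 only when the token matched *some* EOS id.
unfinished = unfinished.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
print(unfinished)  # tensor([1, 0, 1]) -> the second sequence just emitted an EOS token
```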