[Feature] Support LLaMA-3 CPT and ST (#5619)

* support LLaMA-3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Run pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by Tong Li on 2024-04-23 13:54:05 +08:00, committed by GitHub
parent e094933da1
commit 862fbaaa62
28 changed files with 89 additions and 87 deletions


@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


@@ -0,0 +1,106 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import Enum, auto
from typing import List
class SeparatorStyle(Enum):
ADD_BOS_EOS_TOKEN = auto()
@dataclasses.dataclass
class Conversation:
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle
seps: List[str]
def clear(self):
self.messages = []
def get_prompt(self, length: int = None):
if length is None:
length = len(self.messages)
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
ret = self.system
for role, message in self.messages[0:length]:
if message:
ret += role + ": " + self.seps[0] + message + self.seps[1]
else:
ret += role + ": " + self.seps[0]
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def save_prompt(self):
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
else:
ret += role + ": " + self.seps[0]
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
seps=self.seps,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"seps": self.seps,
}
LLaMA2_Conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<s>", "</s>"],
)
LLaMA3_Conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
)
default_conversation = LLaMA3_Conv
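For illustration, a minimal usage sketch of how the template above renders a prompt (the import path is an assumption, since file names are not shown in this diff view):

```python
# Minimal usage sketch for the Conversation template above; illustrative only.
conv = default_conversation.copy()  # LLaMA-3 template with <|begin_of_text|>/<|end_of_text|> separators
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # empty slot for the assistant's reply
print(conv.get_prompt())
# -> "<system prompt>Human: <|begin_of_text|>What is the capital of France?<|end_of_text|>Assistant: <|begin_of_text|>"
```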


@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def load_tokenized_dataset(
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
"""
Load pre-tokenized dataset.
    Each instance of the dataset is a dictionary in
    `{'input_ids': List[int], 'labels': List[int], 'sequence': str}` format.
"""
mode_map = {"train": "train", "dev": "validation", "test": "test"}
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
if isinstance(dataset_paths, (str, os.PathLike)):
dataset_paths = [dataset_paths]
datasets = [] # `List[datasets.dataset_dict.Dataset]`
for ds_path in dataset_paths:
ds_path = os.path.abspath(ds_path)
        assert os.path.exists(ds_path), f"File path {ds_path} does not exist"
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
if isinstance(ds_dict, HFDataset):
datasets.append(ds_dict)
else:
if mode_map[mode] in ds_dict:
datasets.append(ds_dict[mode_map[mode]])
if len(datasets) == 0:
return None
if len(datasets) == 1:
return datasets.pop()
return ConcatDataset(datasets=datasets)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""
Collate instances for supervised dataset.
    Each instance is a tokenized dictionary with fields
    `input_ids` (List[int]), `labels` (List[int]) and `sequence` (str).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
padding: str = "max_length"
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
for instance in instances
]
if self.tokenizer.padding_side == "right":
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=batch_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
labels = torch.nn.utils.rnn.pad_sequence(
sequences=batch_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
if self.padding == "max_length":
# pad to max
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
reversed_labels = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
else:
raise RuntimeError(
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
f"but now `{self.tokenizer.padding_side}`"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
class StatefulDistributedSampler(DistributedSampler):
"""
Stateful distributed sampler for multi-stage training.
"""
def __init__(
self,
dataset: DatasetType,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
self.start_index = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
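A hedged usage sketch wiring the loader, collator and stateful sampler above together (the dataset path and checkpoint name are illustrative, not from this PR):

```python
# Illustrative only: build a resumable dataloader from the helpers above.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # example checkpoint
tokenizer.pad_token = tokenizer.eos_token  # the collator requires a valid pad_token_id

dataset = load_tokenized_dataset(dataset_paths=["/path/to/tokenized_dataset"], mode="train")
collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
sampler.set_start_index(128)  # resume mid-epoch, e.g. after loading a checkpoint
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=collator)
```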


@@ -0,0 +1,301 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Splicing multiple pre-tokenized sequence data points
"""
import bisect
import random
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers import AutoTokenizer
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .conversation import Conversation, default_conversation
logger = get_dist_logger()
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize_pretrain(
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
"""
    A tokenization function to tokenize an original pretraining data point as follows:
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
"""
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
"add <bos> and <eos> manually later"
)
if ignore_index is None:
ignore_index = IGNORE_INDEX
source_text = data_point["source"] # `str`
target_text = data_point["target"] # `str`
is_null_source = len(source_text) == 0
source_text = tokenizer.bos_token + source_text
target_text += tokenizer.eos_token
sequence_text = source_text + target_text
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
sequence_input_ids = tokenized[1]
sequence_labels = deepcopy(sequence_input_ids)
source_length = len(tokenized[0])
if not is_null_source:
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
# sequence truncation.
if len(sequence_input_ids) > max_length:
sequence_input_ids = sequence_input_ids[:max_length]
sequence_labels = sequence_labels[:max_length]
return dict(
input_ids=sequence_input_ids,
labels=sequence_labels,
seq_length=len(sequence_input_ids),
seq_category=data_point["category"],
)
def supervised_tokenize_sft(
data_point: Dict[str, str],
tokenizer: AutoTokenizer,
conversation_template: Conversation = default_conversation,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
    A tokenization function to tokenize an original supervised data point as follows:
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
"""
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
"add <bos> and <eos> manually later"
)
assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = template.roles[0]
elif from_str.lower() == "assistant":
from_str = template.roles[1]
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 0:
template.messages = template.messages[0:-1]
    # `target_turn_index` is the number of leading turns whose tokenized prompt still fits within `max_length - 1`.
turns = [i for i in range(1, len(messages) // 2 + 1)]
target_turn_index = bisect.bisect_right(
turns,
max_length - 1,
key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]),
)
    # Even the first turn's tokenized prompt already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
template.messages = template.messages[0 : 2 * target_turn]
starts = []
ends = []
    # Track whether the next bos/eos token encountered belongs to an assistant (gpt) turn.
    gpt_bos = template.messages[0][0] != template.roles[0]
    gpt_eos = template.messages[0][0] != template.roles[0]
for i, token_id in enumerate(tokenized):
if token_id == tokenizer.bos_token_id:
if gpt_bos:
starts.append(i)
gpt_bos = not gpt_bos
elif token_id == tokenizer.eos_token_id:
if gpt_eos:
ends.append(i)
gpt_eos = not gpt_eos
if len(starts) != target_turn or len(ends) != target_turn:
logger.info(
"Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`."
)
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2]
labels_decode = deepcopy(labels)
for i, z in enumerate(labels_decode):
if z == ignore_index:
labels_decode[i] = tokenizer.unk_token_id
    # `inputs_decode` and `labels_decode` can be used to check whether the tokenization is correct.
return dict(
input_ids=tokenized,
labels=labels,
inputs_decode=tokenizer.decode(tokenized),
labels_decode=tokenizer.decode(labels_decode),
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
class ClosedToConstantLengthSplicedDataset(IterableDataset):
"""
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
original independent (pre-tokenized) data points.
"""
def __init__(
self,
dataset: DSType,
tokenizer: PreTrainedTokenizer,
max_length: int = 4096,
num_packed_sequences: int = 8,
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
input_ids_field: str = "input_ids",
labels_field: str = "labels",
infinite: bool = False,
shuffle: bool = True,
error_strict: bool = False,
) -> None:
self.tokenizer = tokenizer
self.dataset = dataset
self.max_length = max_length
self.infinite = infinite
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
self.shuffle = shuffle
# Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
        # A function that fetches sequence input_ids and labels from the original data point.
if fetch_sequence_func is None:
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
else:
self.fetch_sequence_func = fetch_sequence_func
self.input_ids_field = input_ids_field
self.labels_field = labels_field
self.error_strict = error_strict
self.current_size = 0 # `int`, current packed data size.
def __len__(self) -> int:
return len(self.dataset)
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
iterator = iter(self.dataset)
more_data_points = True
while more_data_points is True:
buffer, buffer_len = [], 0
while True:
# ending condition.
if buffer_len >= self.max_buffer_size:
break
try:
# `Tuple[List[int], List[int]]`
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
buffer_len += len(buffer[-1][self.input_ids_field])
except StopIteration:
if self.infinite is True:
iterator = iter(self.dataset)
warnings.warn("The dataset reached end and the iterator is reset to the start.")
else:
more_data_points = False
break
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
for i, data_point in enumerate(buffer):
# TODO(2023-09-18) check errors for each unspliced tokenized data point
seq_input_ids = data_point[self.input_ids_field]
seq_labels = data_point[self.labels_field]
# Handle special case:
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
# exceeds `max_length`, truncate it.
if len(seq_input_ids) > self.max_length:
truncated_seq_input_ids = seq_input_ids[: self.max_length]
truncated_label_ids = seq_labels[: self.max_length]
if set(truncated_label_ids) == {IGNORE_INDEX}:
if self.error_strict is True:
raise ValueError(
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
f"with all label values as {IGNORE_INDEX}."
)
else:
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
continue # Skip the current error data point.
spliced_data_point = {
self.input_ids_field: truncated_seq_input_ids,
self.labels_field: truncated_label_ids,
}
examples.append(spliced_data_point)
warnings.warn("Find a data point to be truncated.")
continue
                # Check whether appending this sequence would make the spliced data point exceed `max_length`.
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
spliced_data_point = {
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels,
} # `Dict[str, List[int]]`
# Update.
spliced_input_ids, spliced_labels = [], []
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
examples.append(spliced_data_point)
else:
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
            # Handle the residual spliced data point at the end of the dataset.
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:
# TODO(2023-09-18): check errors for each spliced tokenized data point.
self.current_size += 1
yield spliced_data_point
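As a quick check of the pretraining tokenization path above, a hedged sketch on a single data point (the tokenizer checkpoint is an assumption; any slow LLaMA tokenizer exposing `add_bos_token`/`add_eos_token` works):

```python
# Illustrative sketch of supervised_tokenize_pretrain on a single data point.
from transformers.models.llama.tokenization_llama import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # example checkpoint
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

data_point = {"source": "", "target": "Beijing is the capital of China.", "category": "geography"}
result = supervised_tokenize_pretrain(data_point, tokenizer=tokenizer, max_length=4096)
print(result["seq_length"], result["seq_category"])  # token count and "geography"
```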


@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Initialize a new model with an expanded tokenizer by computing mean embedding values from the original model
"""
import argparse
import numpy as np
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_model_and_tokenizer_path",
type=str,
required=True,
default=None,
help="Source path of model & tokenizer",
)
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
args = parser.parse_args()
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
source_tokenizer.add_bos_token = False
source_tokenizer.add_eos_token = False
if source_tokenizer.pad_token is None:
source_tokenizer.pad_token = source_tokenizer.unk_token
source_vocab = source_tokenizer.get_vocab()
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
target_tokenizer.add_bos_token = False
target_tokenizer.add_eos_token = False
if target_tokenizer.pad_token is None:
target_tokenizer.pad_token = target_tokenizer.unk_token
target_vocab = target_tokenizer.get_vocab()
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
assert len(target_vocab) > len(
source_vocab
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
gpu_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
source_model.eval()
source_model = source_model.to(gpu_device)
source_input_embeddings = source_model.get_input_embeddings()
assert isinstance(source_input_embeddings, torch.nn.Embedding)
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
source_input_embeddings.eval()
source_output_embeddings = source_model.get_output_embeddings()
assert isinstance(source_output_embeddings, torch.nn.Linear)
assert source_output_embeddings.bias is None
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
source_output_embeddings.eval()
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
for i in range(len(source_vocab), len(target_vocab)):
if i % 500 == 0:
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
target_token = target_inverted_vocab[i]
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
target_to_source_input_embedding = (
source_input_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
target_to_source_output_embedding = (
source_output_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
source_model = source_model.to(cpu_device)
assert isinstance(source_model, LlamaForCausalLM)
# expand
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
source_model = source_model.half()
source_model.save_pretrained(save_directory=args.target_model_path)
if __name__ == "__main__":
main()
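The core idea of the script above is mean-initialization of new embedding rows; a toy sketch of just that step (all shapes and ids here are made up for illustration):

```python
# Toy illustration of mean-initializing a new token's embedding row from the
# source-tokenizer pieces it decomposes into (values are made up).
import torch

source_embedding = torch.randn(8, 4)               # (source_vocab_size, hidden)
piece_ids = torch.tensor([2, 5, 7])                # source ids the new token splits into
new_row = source_embedding[piece_ids].mean(dim=0, keepdim=True)
expanded = torch.cat([source_embedding, new_row])  # (source_vocab_size + 1, hidden)
```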


@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
Initialize new tokenizer for continual pre-training
"""
import argparse
import json
import os
from typing import List, Union
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
logger = get_dist_logger()
def expand_vocab_tokenizer(
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
) -> None:
"""Expand tokenizer for continue pre-training."""
if os.path.exists(target_tokenizer_dir):
raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
logger.info(source_tokenizer)
source_sp_processor = source_tokenizer.sp_model
source_spm = sp_pb2_model.ModelProto()
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
# Add new tokens to source tokenizer.
source_spm_tokens = set([p.piece for p in source_spm.pieces])
for piece in new_tokens:
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
if piece in source_spm_tokens:
            # Skip existing tokens.
continue
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
source_spm.pieces.append(new_p)
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
# Save
os.makedirs(target_tokenizer_dir)
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
with open(file=target_tokenizer_model_path, mode="wb") as fp:
fp.write(source_spm.SerializeToString())
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
)
parser.add_argument(
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
)
parser.add_argument(
"--expand_tokens_file",
type=str,
required=True,
default=None,
help="Path of the file containing tokens to be extended",
)
args = parser.parse_args()
expand_tokens = []
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
for line in fp_reader:
item = json.loads(line)
# e.g., {"piece": "你好"}
token = item["piece"]
if token in expand_tokens:
continue
expand_tokens.append(token)
expand_tokens.sort(key=lambda t: len(t), reverse=False)
expand_vocab_tokenizer(
source_tokenizer_dir=args.source_tokenizer_dir,
target_tokenizer_dir=args.target_tokenizer_dir,
new_tokens=expand_tokens,
)
if __name__ == "__main__":
main()
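For reference, a small sketch of producing the `--expand_tokens_file` consumed above (one JSON object per line with a `piece` field; the file name is illustrative):

```python
# Write an expand-tokens file in the JSON-lines format the script above expects.
import json

new_tokens = ["你好", "人工智能"]
with open("expand_tokens.jsonl", "w", encoding="utf-8") as fp:
    for token in new_tokens:
        fp.write(json.dumps({"piece": token}, ensure_ascii=False) + "\n")
```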


@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helper functions for IO
"""
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
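A small, runnable sketch of the running-states round trip used by `save_checkpoint`/`load_checkpoint` above (the file name is illustrative):

```python
# Illustrative round trip of the running-states JSON written at checkpoint time.
states = {"epoch": 1, "step": 500, "sample_start_index": 500 * 32}
save_json(states, "running_states.json")
restored = load_json("running_states.json")
assert restored["sample_start_index"] == 16000
# On resume, `sample_start_index` is what gets passed to StatefulDistributedSampler.set_start_index().
```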


@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
repeat_kv,
)
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
if get_accelerator().name == "cuda":
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
from flash_attn.ops.rms_norm import rms_norm
def _prepare_decoder_attention_mask(
self: LlamaModel,
attention_mask: torch.BoolTensor,
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]:
"""
        Decoder attention mask
"""
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
tensors=(
torch.full(
size=(input_shape[0], past_key_values_length),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
) # (bsz, past_key_values_length + q_len)
if attention_mask is not None and torch.all(attention_mask):
return None # Faster
return attention_mask
def attention_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
        Re-defines the `LlamaAttention` forward method to use flash-attention.
"""
if output_attentions:
logger.warning(
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
"return `None` instead."
)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
q_slicing, kv_slicing = (
dim // self.config.pretraining_tp
for dim in (
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
)
) # `Tuple[int, int]`
q_slices, k_slices, v_slices = (
proj.weight.split(slicing, dim=0)
for proj, slicing in (
(self.q_proj, q_slicing),
(self.k_proj, kv_slicing),
(self.v_proj, kv_slicing),
)
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
q, k, v = (
torch.cat(
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
)
for slices in (q_slices, k_slices, v_slices)
)
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
else:
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
for states, num_heads in (
(q, self.num_heads),
(k, self.num_key_value_heads),
(v, self.num_key_value_heads),
)
)
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
past_kv_len = 0
if past_key_value is not None:
# if `past_key_value` is not None, `kv_len` > `q_len`.
past_kv_len = past_key_value[0].shape[-2]
kv_len += past_kv_len
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
cos, sin = self.rotary_emb(v, seq_len=kv_len)
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
past_key_value = (k, v) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
key_padding_mask = attention_mask
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
if past_kv_len > 0:
q = torch.cat(
tensors=(
torch.full(
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
fill_value=0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
if key_padding_mask is None:
# (bsz, past_kv_len + q_len, num_heads, head_dim)
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
output = rearrange(
output, pattern="... h d -> ... (h d)"
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
else:
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
kv, _, cu_kv_lens, max_kv_len = unpad_input(
hidden_states=torch.stack(tensors=(k, v), dim=2),
attention_mask=key_padding_mask,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q=q,
kv=kv,
cu_seqlens_q=cu_q_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_q_len,
max_seqlen_k=max_kv_len,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
output = pad_input(
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
indices=indices,
batch=bsz,
seqlen=past_kv_len + q_len,
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
if past_kv_len > 0:
# Strip off the zero query outputs.
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
output = self.o_proj(output) # (bsz, q_len, hidden_size)
return output, None, past_key_value
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
"""
        Forward function for RMS Norm
"""
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(attention_forward, module)
if isinstance(module, LlamaModel):
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
if isinstance(module, LlamaRMSNorm):
module.forward = MethodType(rms_norm_forward, module)
elif get_accelerator().name == "npu":
import torch_npu
class NPULlamaAttention(LlamaAttention):
use_flash: bool = True
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.setup()
def setup(self):
self._softmax_scale = 1 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if not self.use_flash:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
attn_output, *_ = torch_npu.npu_fusion_attention(
query_states,
key_states,
value_states,
self.num_heads,
"BNSD",
atten_mask=attention_mask.bool(),
scale=self._softmax_scale,
padding_mask=None,
pre_tockens=65535,
next_tockens=0,
keep_prob=1.0,
inner_precise=0,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum(
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
)
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class NPURMSNorm(LlamaRMSNorm):
def forward(self, hidden_states):
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.__class__ = NPULlamaAttention
module.setup()
if isinstance(module, LlamaRMSNorm):
module.__class__ = NPURMSNorm
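A short, hedged sketch of applying the patch above (the checkpoint name is illustrative; on CUDA this requires flash-attn to be installed):

```python
# Illustrative: monkey-patch a loaded LLaMA model with the fused kernels above.
import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16)
replace_with_flash_attention(model)  # swaps attention forward, mask preparation and RMSNorm in place
```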


@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
"""Freeze all parameters except embeddings."""
for name, params in model.named_parameters():
if "embed_tokens" not in name and "lm_head" not in name:
params.requires_grad = False
else:
params.requires_grad = True
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Unfreeze all parameters."""
    for name, params in model.named_parameters():
        params.requires_grad = True
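A quick, runnable sketch of what remains trainable after freezing (the tiny config is illustrative):

```python
# Illustrative check of freeze_non_embeds_parameters on a tiny random model.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, vocab_size=1024)
model = LlamaForCausalLM(config)
freeze_non_embeds_parameters(model)
print([name for name, p in model.named_parameters() if p.requires_grad])
# -> ['model.embed_tokens.weight', 'lm_head.weight']
```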


@@ -0,0 +1,72 @@
# Copyright 2023 The Hugging Face team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def unwrap(model):
if hasattr(model, "module"):
return model.unwrap()
else:
return model
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
layers. This method is slightly adapted from the original source code that can be found here:
    https://github.com/neelsjain/NEFTune. Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
def activate_neftune(model, neftune_noise_alpha=0.1):
r"""
    Activates NEFTune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
https://arxiv.org/abs/2310.05914
"""
embeddings = unwrap(model).get_input_embeddings()
embeddings.neftune_noise_alpha = neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
neftune_hook_handle = hook_handle
return model, neftune_hook_handle
def deactivate_neftune(model, neftune_hook_handle):
"""
    Deactivates the NEFTune method. Make sure to call `activate_neftune` first.
"""
embeddings = unwrap(model).get_input_embeddings()
neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha
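A hedged sketch of toggling NEFTune around fine-tuning (using a tiny random model purely for illustration):

```python
# Illustrative: enable NEFTune noise for training, then remove the hook.
from transformers import LlamaConfig, LlamaForCausalLM

model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=128,
                                     num_hidden_layers=2, num_attention_heads=4, vocab_size=1024))
model.train()
model, handle = activate_neftune(model, neftune_noise_alpha=0.1)
# ... forward passes now add uniform noise to the input embeddings (training mode only) ...
deactivate_neftune(model, handle)  # detach the hook before evaluation or saving
```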


@@ -0,0 +1,252 @@
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import PreTrainedTokenizer
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging
logger = logging.get_logger(__name__)
def get_prompt_template(
input_query: str,
history: List[Dict] = None,
roles: list = ["", "Human", "Assistant"],
) -> str:
"""
Generates a prompt template for chat models based on input and history.
Args:
input_query (str): User's current input query.
history (List[Dict], optional): List of past conversations, each a dict with 'role' and 'message'.
roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"].
Returns:
str: A formatted prompt including the input query and history.
"""
prompt = ""
if history is None:
new_history = []
else:
new_history = deepcopy(history)
new_history.append({"role": roles[1], "message": input_query.strip()})
new_history.append({"role": roles[2], "message": None})
for _, item in enumerate(new_history):
role = item.get("role")
message = item.get("message")
if role == roles[0]:
prompt += f"<s>{message}\n\n"
else:
if message:
prompt += f"{role}: <s>{message}</s>"
else:
prompt += f"{role}: <s>"
return prompt
@torch.inference_mode()
def streaming_chat(
model: Any,
tokenizer: PreTrainedTokenizer,
input_query: str,
history: List[Dict] = None,
roles: list = ["", "Human", "Assistant"],
past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None,
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 50,
do_sample: bool = True,
length_penalty: float = 1.2,
max_new_tokens: int = 512,
logits_processor: LogitsProcessorList = None,
return_past_key_values: bool = False,
**kwargs,
):
"""
Streaming chat responses generation with a given model and tokenizer.
Args:
model (Any): The language model to generate responses.
tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses.
input_query (str): The current user input to respond to.
history (List[Dict], optional): A list of past conversations, where each conversation is a dictionary with keys 'role' and 'message'.
roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"].
past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding.
temperature (float): The temperature value for token sampling, defaults to 0.8.
top_p (float): Nucleus sampling probability threshold, defaults to 0.95.
top_k (int): Top-K filtering threshold, defaults to 50.
do_sample (bool): Whether to sample responses, defaults to True.
length_penalty (float): Penalty for response length, defaults to 1.2.
max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512.
logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None.
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
**kwargs: Additional keyword arguments for generation.
Yields:
Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and
optionally the updated past key values if `return_past_key_values` is True.
    Note: The tokenizer must use left padding; generation currently only supports left-padded inputs.
"""
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
generation_kwargs = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"do_sample": do_sample,
"max_new_tokens": max_new_tokens,
"length_penalty": length_penalty,
"use_cache": True,
**kwargs,
}
prompt_str = get_prompt_template(input_query, history=history, roles=roles)
eos_token_id = [tokenizer.eos_token_id]
inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device)
history.append({"role": roles[1], "message": input_query.strip()})
history.append({"role": roles[2], "message": None})
for outputs in stream_generate(
model,
**inputs,
past_key_values=past_key_values,
eos_token_id=eos_token_id,
return_past_key_values=return_past_key_values,
**generation_kwargs,
):
if return_past_key_values:
outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
response = tokenizer.decode(outputs)
history[-1]["message"] = response.strip()
if return_past_key_values:
yield response, history, past_key_values
else:
yield response, history
@torch.inference_mode()
def stream_generate(
model: Any,
input_ids: torch.Tensor,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
return_past_key_values: bool = False,
**kwargs,
):
"""
Generates sequences of token ids using the specified model and generation parameters.
Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py
Args:
model (Any): The model used for generating sequences of token ids.
input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder.
generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call.
logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments
and generation config.
stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments
and a generation config.
prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation.
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
**kwargs: Additional parameters for model generation.
Yields:
torch.Tensor: The generated token IDs, updated after each generation step.
Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True, defaults to False.
"""
input_ids_len = input_ids.size(1)
if generation_config is None:
generation_config = model.generation_config
generation_config = deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
if generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_len
if input_ids_len >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# prepare distribution pre_processing samplers
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_len,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
# prepare stopping criteria
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
# NOTE: this is correct only in left padding mode
# pre-process distribution
next_token_logits = outputs.logits[:, -1, :]
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
if return_past_key_values:
yield input_ids, outputs.past_key_values
else:
yield input_ids
        # Stop when every sequence is finished, or when the maximum length is exceeded.
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
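Finally, a hedged sketch of consuming the streaming generator above (paths are illustrative; the tokenizer must use left padding, as asserted in `streaming_chat`):

```python
# Illustrative consumer of streaming_chat; checkpoint paths are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/sft_model", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("path/to/sft_model", torch_dtype=torch.bfloat16).cuda()

history, response = [], ""
for response, history in streaming_chat(model, tokenizer, "Hello, who are you?", history=history):
    pass  # `response` grows as tokens stream in; print it incrementally for a chat UI
print(response)
```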