mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[Feature] Support LLaMA-3 CPT and ST (#5619)
* support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
2
applications/Colossal-LLaMA/colossal_llama/__init__.py
Normal file
2
applications/Colossal-LLaMA/colossal_llama/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
@@ -0,0 +1,106 @@
|
||||
# Copyright 2023 lm-sys@FastChat
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
ADD_BOS_EOS_TOKEN = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle
|
||||
seps: List[str]
|
||||
|
||||
def clear(self):
|
||||
self.messages = []
|
||||
|
||||
def get_prompt(self, length: int = None):
|
||||
if length is None:
|
||||
length = len(self.messages)
|
||||
|
||||
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
|
||||
ret = self.system
|
||||
for role, message in self.messages[0:length]:
|
||||
if message:
|
||||
ret += role + ": " + self.seps[0] + message + self.seps[1]
|
||||
else:
|
||||
ret += role + ": " + self.seps[0]
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def save_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
|
||||
ret = self.system
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
|
||||
else:
|
||||
ret += role + ": " + self.seps[0]
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
seps=self.seps,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"seps": self.seps,
|
||||
}
|
||||
|
||||
|
||||
LLaMA2_Conv = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
|
||||
seps=["<s>", "</s>"],
|
||||
)
|
||||
|
||||
LLaMA3_Conv = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
|
||||
seps=["<|begin_of_text|>", "<|end_of_text|>"],
|
||||
)
|
||||
|
||||
default_conversation = LLaMA3_Conv
|
171
applications/Colossal-LLaMA/colossal_llama/dataset/loader.py
Normal file
171
applications/Colossal-LLaMA/colossal_llama/dataset/loader.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def load_tokenized_dataset(
|
||||
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
|
||||
) -> Optional[DatasetType]:
|
||||
"""
|
||||
Load pre-tokenized dataset.
|
||||
Each instance of dataset is a dictionary with
|
||||
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
|
||||
"""
|
||||
mode_map = {"train": "train", "dev": "validation", "test": "test"}
|
||||
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
|
||||
|
||||
if isinstance(dataset_paths, (str, os.PathLike)):
|
||||
dataset_paths = [dataset_paths]
|
||||
|
||||
datasets = [] # `List[datasets.dataset_dict.Dataset]`
|
||||
for ds_path in dataset_paths:
|
||||
ds_path = os.path.abspath(ds_path)
|
||||
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
|
||||
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
|
||||
if isinstance(ds_dict, HFDataset):
|
||||
datasets.append(ds_dict)
|
||||
else:
|
||||
if mode_map[mode] in ds_dict:
|
||||
datasets.append(ds_dict[mode_map[mode]])
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
if len(datasets) == 1:
|
||||
return datasets.pop()
|
||||
return ConcatDataset(datasets=datasets)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""
|
||||
Collate instances for supervised dataset.
|
||||
Each instance is a tokenized dictionary with fields
|
||||
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
max_length: int = 4096
|
||||
ignore_index: int = -100
|
||||
padding: str = "max_length"
|
||||
|
||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
instances (`Sequence[Dict[str, List[int]]]`):
|
||||
Mini-batch samples, each sample is stored in an individual dictionary.
|
||||
|
||||
Returns:
|
||||
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
|
||||
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
|
||||
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
|
||||
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
|
||||
"""
|
||||
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
|
||||
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
|
||||
f"but now `{self.tokenizer.pad_token_id}`"
|
||||
)
|
||||
|
||||
# `List[torch.Tensor]`
|
||||
batch_input_ids = [
|
||||
torch.LongTensor(instance["input_ids"][: self.max_length])
|
||||
if len(instance["input_ids"]) > self.max_length
|
||||
else torch.LongTensor(instance["input_ids"])
|
||||
for instance in instances
|
||||
]
|
||||
batch_labels = [
|
||||
torch.LongTensor(instance["labels"][: self.max_length])
|
||||
if len(instance["labels"]) > self.max_length
|
||||
else torch.LongTensor(instance["labels"])
|
||||
for instance in instances
|
||||
]
|
||||
|
||||
if self.tokenizer.padding_side == "right":
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
if self.padding == "max_length":
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
elif self.tokenizer.padding_side == "left":
|
||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
|
||||
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
|
||||
reversed_labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
|
||||
f"but now `{self.tokenizer.padding_side}`"
|
||||
)
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
|
||||
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Stateful distributed sampler for multi-stage training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DatasetType,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dataset=dataset,
|
||||
num_replicas=num_replicas,
|
||||
rank=rank,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
self.start_index = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Splicing multiple pre-tokenized sequence data points
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
from datasets import dataset_dict
|
||||
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .conversation import Conversation, default_conversation
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
|
||||
|
||||
def supervised_tokenize_pretrain(
|
||||
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original pretraining data point as following:
|
||||
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
|
||||
"""
|
||||
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
|
||||
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
|
||||
"add <bos> and <eos> manually later"
|
||||
)
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
source_text = data_point["source"] # `str`
|
||||
target_text = data_point["target"] # `str`
|
||||
is_null_source = len(source_text) == 0
|
||||
|
||||
source_text = tokenizer.bos_token + source_text
|
||||
target_text += tokenizer.eos_token
|
||||
sequence_text = source_text + target_text
|
||||
|
||||
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
|
||||
sequence_input_ids = tokenized[1]
|
||||
sequence_labels = deepcopy(sequence_input_ids)
|
||||
|
||||
source_length = len(tokenized[0])
|
||||
if not is_null_source:
|
||||
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
|
||||
|
||||
# sequence truncation.
|
||||
if len(sequence_input_ids) > max_length:
|
||||
sequence_input_ids = sequence_input_ids[:max_length]
|
||||
sequence_labels = sequence_labels[:max_length]
|
||||
|
||||
return dict(
|
||||
input_ids=sequence_input_ids,
|
||||
labels=sequence_labels,
|
||||
seq_length=len(sequence_input_ids),
|
||||
seq_category=data_point["category"],
|
||||
)
|
||||
|
||||
|
||||
def supervised_tokenize_sft(
|
||||
data_point: Dict[str, str],
|
||||
tokenizer: AutoTokenizer,
|
||||
conversation_template: Conversation = default_conversation,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original supervised data point as following:
|
||||
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||
"""
|
||||
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
|
||||
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
|
||||
"add <bos> and <eos> manually later"
|
||||
)
|
||||
|
||||
assert (
|
||||
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
|
||||
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
|
||||
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
messages = data_point["messages"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
|
||||
for mess in messages:
|
||||
from_str = mess["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = template.roles[0]
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = template.roles[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
|
||||
template.append_message(from_str, mess["content"])
|
||||
|
||||
if len(template.messages) % 2 != 0:
|
||||
template.messages = template.messages[0:-1]
|
||||
|
||||
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
|
||||
turns = [i for i in range(1, len(messages) // 2 + 1)]
|
||||
target_turn_index = bisect.bisect_right(
|
||||
turns,
|
||||
max_length - 1,
|
||||
key=lambda x: len(tokenizer([template.get_prompt(2 * x)], add_special_tokens=False)["input_ids"][0]),
|
||||
)
|
||||
|
||||
# The tokenized length for first turn already exceeds `max_length - 1`.
|
||||
if target_turn_index - 1 < 0:
|
||||
return dict(
|
||||
input_ids=None,
|
||||
labels=None,
|
||||
inputs_decode=None,
|
||||
labels_decode=None,
|
||||
seq_length=None,
|
||||
seq_category=None,
|
||||
)
|
||||
|
||||
target_turn = turns[target_turn_index - 1]
|
||||
prompt = template.get_prompt(2 * target_turn)
|
||||
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
|
||||
|
||||
template.messages = template.messages[0 : 2 * target_turn]
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
gpt_bos = False if template.messages[0][0] == template.roles[0] else True
|
||||
gpt_eos = False if template.messages[0][0] == template.roles[0] else True
|
||||
|
||||
for i, token_id in enumerate(tokenized):
|
||||
if token_id == tokenizer.bos_token_id:
|
||||
if gpt_bos:
|
||||
starts.append(i)
|
||||
gpt_bos = not gpt_bos
|
||||
elif token_id == tokenizer.eos_token_id:
|
||||
if gpt_eos:
|
||||
ends.append(i)
|
||||
gpt_eos = not gpt_eos
|
||||
|
||||
if len(starts) != target_turn or len(ends) != target_turn:
|
||||
logger.info(
|
||||
"Please check whether the tokenizer add additional `bos_token` and `eos_token`.\n\nOr the original message contains `bos_token` or `eos_token`."
|
||||
)
|
||||
return dict(
|
||||
input_ids=None,
|
||||
labels=None,
|
||||
inputs_decode=None,
|
||||
labels_decode=None,
|
||||
seq_length=None,
|
||||
seq_category=None,
|
||||
)
|
||||
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
labels = [ignore_index] * len(tokenized)
|
||||
for start, end in zip(starts, ends):
|
||||
labels[start + 1 : end + 2] = tokenized[start + 1 : end + 2]
|
||||
|
||||
labels_decode = deepcopy(labels)
|
||||
for i, z in enumerate(labels_decode):
|
||||
if z == ignore_index:
|
||||
labels_decode[i] = tokenizer.unk_token_id
|
||||
|
||||
# `inputs_decode` and `labels_decode` can be used to check whether the tokenization method is true.
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
labels=labels,
|
||||
inputs_decode=tokenizer.decode(tokenized),
|
||||
labels_decode=tokenizer.decode(labels_decode),
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
)
|
||||
|
||||
|
||||
class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||
"""
|
||||
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
|
||||
original independent (pre-tokenized) data points.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DSType,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 4096,
|
||||
num_packed_sequences: int = 8,
|
||||
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
|
||||
input_ids_field: str = "input_ids",
|
||||
labels_field: str = "labels",
|
||||
infinite: bool = False,
|
||||
shuffle: bool = True,
|
||||
error_strict: bool = False,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = dataset
|
||||
self.max_length = max_length
|
||||
self.infinite = infinite
|
||||
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
|
||||
self.shuffle = shuffle
|
||||
|
||||
# Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
|
||||
# A function that fetch sequence input_ids and labels from the original data point
|
||||
if fetch_sequence_func is None:
|
||||
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
|
||||
else:
|
||||
self.fetch_sequence_func = fetch_sequence_func
|
||||
self.input_ids_field = input_ids_field
|
||||
self.labels_field = labels_field
|
||||
|
||||
self.error_strict = error_strict
|
||||
self.current_size = 0 # `int`, current packed data size.
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
|
||||
iterator = iter(self.dataset)
|
||||
more_data_points = True
|
||||
while more_data_points is True:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
# ending condition.
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
try:
|
||||
# `Tuple[List[int], List[int]]`
|
||||
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
|
||||
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
|
||||
buffer_len += len(buffer[-1][self.input_ids_field])
|
||||
except StopIteration:
|
||||
if self.infinite is True:
|
||||
iterator = iter(self.dataset)
|
||||
warnings.warn("The dataset reached end and the iterator is reset to the start.")
|
||||
else:
|
||||
more_data_points = False
|
||||
break
|
||||
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
|
||||
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
|
||||
for i, data_point in enumerate(buffer):
|
||||
# TODO(2023-09-18) check errors for each unspliced tokenized data point
|
||||
seq_input_ids = data_point[self.input_ids_field]
|
||||
seq_labels = data_point[self.labels_field]
|
||||
# Handle special case:
|
||||
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
|
||||
# exceeds `max_length`, truncate it.
|
||||
if len(seq_input_ids) > self.max_length:
|
||||
truncated_seq_input_ids = seq_input_ids[: self.max_length]
|
||||
truncated_label_ids = seq_labels[: self.max_length]
|
||||
if set(truncated_label_ids) == {IGNORE_INDEX}:
|
||||
if self.error_strict is True:
|
||||
raise ValueError(
|
||||
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
|
||||
f"with all label values as {IGNORE_INDEX}."
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
|
||||
continue # Skip the current error data point.
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: truncated_seq_input_ids,
|
||||
self.labels_field: truncated_label_ids,
|
||||
}
|
||||
examples.append(spliced_data_point)
|
||||
warnings.warn("Find a data point to be truncated.")
|
||||
continue
|
||||
|
||||
# Pre action judgment.
|
||||
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels,
|
||||
} # `Dict[str, List[int]]`
|
||||
# Update.
|
||||
spliced_input_ids, spliced_labels = [], []
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
examples.append(spliced_data_point)
|
||||
else:
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
# For residual spliced data point at the end of the data set
|
||||
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
||||
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
|
||||
if self.shuffle:
|
||||
random.shuffle(examples)
|
||||
for spliced_data_point in examples:
|
||||
# TODO(2023-09-18): check errors for each spliced tokenized data point.
|
||||
self.current_size += 1
|
||||
yield spliced_data_point
|
110
applications/Colossal-LLaMA/colossal_llama/model/init_model.py
Normal file
110
applications/Colossal-LLaMA/colossal_llama/model/init_model.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new model with updated tokenizer by calculating the mean values from original model
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_model_and_tokenizer_path",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Source path of model & tokenizer",
|
||||
)
|
||||
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
|
||||
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
|
||||
args = parser.parse_args()
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_tokenizer.add_bos_token = False
|
||||
source_tokenizer.add_eos_token = False
|
||||
if source_tokenizer.pad_token is None:
|
||||
source_tokenizer.pad_token = source_tokenizer.unk_token
|
||||
source_vocab = source_tokenizer.get_vocab()
|
||||
|
||||
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
|
||||
target_tokenizer.add_bos_token = False
|
||||
target_tokenizer.add_eos_token = False
|
||||
if target_tokenizer.pad_token is None:
|
||||
target_tokenizer.pad_token = target_tokenizer.unk_token
|
||||
target_vocab = target_tokenizer.get_vocab()
|
||||
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
|
||||
|
||||
assert len(target_vocab) > len(
|
||||
source_vocab
|
||||
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
|
||||
|
||||
gpu_device = torch.device("cuda:0")
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_model.eval()
|
||||
source_model = source_model.to(gpu_device)
|
||||
|
||||
source_input_embeddings = source_model.get_input_embeddings()
|
||||
assert isinstance(source_input_embeddings, torch.nn.Embedding)
|
||||
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_input_embeddings.eval()
|
||||
|
||||
source_output_embeddings = source_model.get_output_embeddings()
|
||||
assert isinstance(source_output_embeddings, torch.nn.Linear)
|
||||
assert source_output_embeddings.bias is None
|
||||
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_output_embeddings.eval()
|
||||
|
||||
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
|
||||
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
|
||||
for i in range(len(source_vocab), len(target_vocab)):
|
||||
if i % 500 == 0:
|
||||
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
|
||||
target_token = target_inverted_vocab[i]
|
||||
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
|
||||
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
|
||||
|
||||
target_to_source_input_embedding = (
|
||||
source_input_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
target_to_source_output_embedding = (
|
||||
source_output_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
|
||||
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
|
||||
|
||||
source_model = source_model.to(cpu_device)
|
||||
assert isinstance(source_model, LlamaForCausalLM)
|
||||
|
||||
# expand
|
||||
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
|
||||
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
|
||||
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
|
||||
|
||||
source_model = source_model.half()
|
||||
source_model.save_pretrained(save_directory=args.target_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new tokenizer for continual pre-training
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def expand_vocab_tokenizer(
|
||||
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
|
||||
) -> None:
|
||||
"""Expand tokenizer for continue pre-training."""
|
||||
if os.path.exists(target_tokenizer_dir):
|
||||
raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
|
||||
logger.info(source_tokenizer)
|
||||
source_sp_processor = source_tokenizer.sp_model
|
||||
source_spm = sp_pb2_model.ModelProto()
|
||||
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
|
||||
|
||||
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
|
||||
|
||||
# Add new tokens to source tokenizer.
|
||||
source_spm_tokens = set([p.piece for p in source_spm.pieces])
|
||||
for piece in new_tokens:
|
||||
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
|
||||
if piece in source_spm_tokens:
|
||||
# Skip existed token.
|
||||
continue
|
||||
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
||||
new_p.piece = piece
|
||||
new_p.score = 0
|
||||
source_spm.pieces.append(new_p)
|
||||
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
|
||||
|
||||
# Save
|
||||
os.makedirs(target_tokenizer_dir)
|
||||
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
|
||||
with open(file=target_tokenizer_model_path, mode="wb") as fp:
|
||||
fp.write(source_spm.SerializeToString())
|
||||
|
||||
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
|
||||
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
|
||||
logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expand_tokens_file",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Path of the file containing tokens to be extended",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
expand_tokens = []
|
||||
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
|
||||
for line in fp_reader:
|
||||
item = json.loads(line)
|
||||
# e.g., {"piece": "你好"}
|
||||
token = item["piece"]
|
||||
if token in expand_tokens:
|
||||
continue
|
||||
expand_tokens.append(token)
|
||||
expand_tokens.sort(key=lambda t: len(t), reverse=False)
|
||||
|
||||
expand_vocab_tokenizer(
|
||||
source_tokenizer_dir=args.source_tokenizer_dir,
|
||||
target_tokenizer_dir=args.target_tokenizer_dir,
|
||||
new_tokens=expand_tokens,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
88
applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py
Normal file
88
applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Helper functions for IO
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Save as JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="w", encoding="utf-8") as fp:
|
||||
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
save_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
|
||||
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
|
||||
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
load_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
|
||||
return (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
@@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
if get_accelerator().name == "cuda":
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attetion mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(
|
||||
output, pattern="... h d -> ... (h d)"
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Formard function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
|
||||
elif get_accelerator().name == "npu":
|
||||
import torch_npu
|
||||
|
||||
class NPULlamaAttention(LlamaAttention):
|
||||
use_flash: bool = True
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.use_flash:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
attn_output, *_ = torch_npu.npu_fusion_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self.num_heads,
|
||||
"BNSD",
|
||||
atten_mask=attention_mask.bool(),
|
||||
scale=self._softmax_scale,
|
||||
padding_mask=None,
|
||||
pre_tockens=65535,
|
||||
next_tockens=0,
|
||||
keep_prob=1.0,
|
||||
inner_precise=0,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum(
|
||||
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
class NPURMSNorm(LlamaRMSNorm):
|
||||
def forward(self, hidden_states):
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.__class__ = NPULlamaAttention
|
||||
module.setup()
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.__class__ = NPURMSNorm
|
18
applications/Colossal-LLaMA/colossal_llama/utils/froze.py
Normal file
18
applications/Colossal-LLaMA/colossal_llama/utils/froze.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from transformers.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
|
||||
"""Freeze all parameters except embeddings."""
|
||||
for name, params in model.named_parameters():
|
||||
if "embed_tokens" not in name and "lm_head" not in name:
|
||||
params.requires_grad = False
|
||||
else:
|
||||
params.requires_grad = True
|
||||
|
||||
|
||||
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
|
||||
for name, params in model.named_parameters():
|
||||
params.requires_grad = False
|
@@ -0,0 +1,72 @@
|
||||
# Copyright 2023 The Hugging Face team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def unwrap(model):
|
||||
if hasattr(model, "module"):
|
||||
return model.unwrap()
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def neftune_post_forward_hook(module, input, output):
|
||||
"""
|
||||
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
|
||||
layers. This method is slightly adapted from the original source code that can be found here:
|
||||
https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
|
||||
```python
|
||||
model = ...
|
||||
model.embed_tokens.neftune_noise_alpha = 0.1
|
||||
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
|
||||
```
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
|
||||
the desired noise alpha value.
|
||||
input (`torch.Tensor`):
|
||||
The input tensor to the model.
|
||||
output (`torch.Tensor`):
|
||||
The output tensor of the model (i.e. the embeddings).
|
||||
"""
|
||||
if module.training:
|
||||
dims = torch.tensor(output.size(1) * output.size(2))
|
||||
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
|
||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||
return output
|
||||
|
||||
|
||||
def activate_neftune(model, neftune_noise_alpha=0.1):
|
||||
r"""
|
||||
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
|
||||
https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
embeddings = unwrap(model).get_input_embeddings()
|
||||
|
||||
embeddings.neftune_noise_alpha = neftune_noise_alpha
|
||||
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
|
||||
neftune_hook_handle = hook_handle
|
||||
|
||||
return model, neftune_hook_handle
|
||||
|
||||
|
||||
def deactivate_neftune(model, neftune_hook_handle):
|
||||
"""
|
||||
Deactivates the neftune method. Make sure to call `_activate_neftune` first.
|
||||
"""
|
||||
embeddings = unwrap(model).get_input_embeddings()
|
||||
|
||||
neftune_hook_handle.remove()
|
||||
del embeddings.neftune_noise_alpha
|
@@ -0,0 +1,252 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_prompt_template(
|
||||
input_query: str,
|
||||
history: List[Dict] = None,
|
||||
roles: list = ["", "Human", "Assistant"],
|
||||
) -> str:
|
||||
"""
|
||||
Generates a prompt template for chat models based on input and history.
|
||||
|
||||
Args:
|
||||
input_query (str): User's current input query.
|
||||
history (List[Dict], optional): List of past conversations, each a dict with 'role' and 'message'.
|
||||
roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"].
|
||||
|
||||
Returns:
|
||||
str: A formatted prompt including the input query and history.
|
||||
"""
|
||||
prompt = ""
|
||||
if history is None:
|
||||
new_history = []
|
||||
else:
|
||||
new_history = deepcopy(history)
|
||||
|
||||
new_history.append({"role": roles[1], "message": input_query.strip()})
|
||||
new_history.append({"role": roles[2], "message": None})
|
||||
|
||||
for _, item in enumerate(new_history):
|
||||
role = item.get("role")
|
||||
message = item.get("message")
|
||||
if role == roles[0]:
|
||||
prompt += f"<s>{message}\n\n"
|
||||
else:
|
||||
if message:
|
||||
prompt += f"{role}: <s>{message}</s>"
|
||||
else:
|
||||
prompt += f"{role}: <s>"
|
||||
return prompt
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def streaming_chat(
|
||||
model: Any,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
input_query: str,
|
||||
history: List[Dict] = None,
|
||||
roles: list = ["", "Human", "Assistant"],
|
||||
past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 50,
|
||||
do_sample: bool = True,
|
||||
length_penalty: float = 1.2,
|
||||
max_new_tokens: int = 512,
|
||||
logits_processor: LogitsProcessorList = None,
|
||||
return_past_key_values: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Streaming chat responses generation with a given model and tokenizer.
|
||||
|
||||
Args:
|
||||
model (Any): The language model to generate responses.
|
||||
tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses.
|
||||
input_query (str): The current user input to respond to.
|
||||
history (List[Dict], optional): A list of past conversations, where each conversation is a dictionary with keys 'role' and 'message'.
|
||||
roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"].
|
||||
past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding.
|
||||
temperature (float): The temperature value for token sampling, defaults to 0.8.
|
||||
top_p (float): Nucleus sampling probability threshold, defaults to 0.95.
|
||||
top_k (int): Top-K filtering threshold, defaults to 50.
|
||||
do_sample (bool): Whether to sample responses, defaults to True.
|
||||
length_penalty (float): Penalty for response length, defaults to 1.2.
|
||||
max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512.
|
||||
logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None.
|
||||
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
|
||||
**kwargs: Additional keyword arguments for generation.
|
||||
|
||||
Yields:
|
||||
Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and
|
||||
optionally the updated past key values if `return_past_key_values` is True.
|
||||
|
||||
Ensures padding is on the left side for the tokenizer.
|
||||
"""
|
||||
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
|
||||
if history is None:
|
||||
history = []
|
||||
if logits_processor is None:
|
||||
logits_processor = LogitsProcessorList()
|
||||
|
||||
generation_kwargs = {
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"do_sample": do_sample,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"length_penalty": length_penalty,
|
||||
"use_cache": True,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
prompt_str = get_prompt_template(input_query, history=history, roles=roles)
|
||||
|
||||
eos_token_id = [tokenizer.eos_token_id]
|
||||
inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device)
|
||||
history.append({"role": roles[1], "message": input_query.strip()})
|
||||
history.append({"role": roles[2], "message": None})
|
||||
|
||||
for outputs in stream_generate(
|
||||
model,
|
||||
**inputs,
|
||||
past_key_values=past_key_values,
|
||||
eos_token_id=eos_token_id,
|
||||
return_past_key_values=return_past_key_values,
|
||||
**generation_kwargs,
|
||||
):
|
||||
if return_past_key_values:
|
||||
outputs, past_key_values = outputs
|
||||
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
|
||||
response = tokenizer.decode(outputs)
|
||||
|
||||
history[-1]["message"] = response.strip()
|
||||
if return_past_key_values:
|
||||
yield response, history, past_key_values
|
||||
else:
|
||||
yield response, history
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_generate(
|
||||
model: Any,
|
||||
input_ids: torch.Tensor,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
return_past_key_values: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generates sequences of token ids using the specified model and generation parameters.
|
||||
Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py
|
||||
|
||||
Args:
|
||||
model (Any): The model used for generating sequences of token ids.
|
||||
input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder.
|
||||
generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call.
|
||||
logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments
|
||||
and generation config.
|
||||
stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments
|
||||
and a generation config.
|
||||
prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation.
|
||||
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
|
||||
**kwargs: Additional parameters for model generation.
|
||||
|
||||
Yields:
|
||||
torch.Tensor: The generated token IDs, updated after each generation step.
|
||||
Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True, defaults to False.
|
||||
"""
|
||||
input_ids_len = input_ids.size(1)
|
||||
|
||||
if generation_config is None:
|
||||
generation_config = model.generation_config
|
||||
generation_config = deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||
|
||||
if generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_len
|
||||
|
||||
if input_ids_len >= generation_config.max_length:
|
||||
input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
" increasing `max_new_tokens`."
|
||||
)
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
# prepare distribution pre_processing samplers
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_len,
|
||||
encoder_input_ids=input_ids,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
# prepare stopping criteria
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||
)
|
||||
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
scores = None
|
||||
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
# forward pass to get next token
|
||||
outputs = model(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
|
||||
# NOTE: this is correct only in left padding mode
|
||||
# pre-process distribution
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
if generation_config.do_sample:
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else:
|
||||
next_tokens = torch.argmax(probs, dim=-1)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
model_kwargs = model._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
|
||||
)
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
if return_past_key_values:
|
||||
yield input_ids, outputs.past_key_values
|
||||
else:
|
||||
yield input_ids
|
||||
# stop when each sentence is finished, or if exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
break
|
Reference in New Issue
Block a user