initial commit: add colossal llama 2 (#4784)

Tong Li
2023-09-24 23:12:26 +08:00
committed by GitHub
parent 4146f1c0ce
commit 74aa7d964a
19 changed files with 2162 additions and 2 deletions

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
import torch
from datasets import dataset_dict, load_from_disk
from datasets import Dataset as HFDataset
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
import torch.nn.functional as F
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def load_tokenized_dataset(
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
"""
Load pre-tokenized dataset.
    Each dataset instance is a dictionary in the format
    `{'input_ids': List[int], 'labels': List[int], 'sequence': str}`.
"""
mode_map = {"train": "train", "dev": "validation", "test": "test"}
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
if isinstance(dataset_paths, (str, os.PathLike)):
dataset_paths = [dataset_paths]
datasets = [] # `List[datasets.dataset_dict.Dataset]`
for ds_path in dataset_paths:
ds_path = os.path.abspath(ds_path)
        assert os.path.exists(ds_path), f"File path {ds_path} does not exist"
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
if isinstance(ds_dict, HFDataset):
datasets.append(ds_dict)
else:
if mode_map[mode] in ds_dict:
datasets.append(ds_dict[mode_map[mode]])
if len(datasets) == 0:
return None
if len(datasets) == 1:
return datasets.pop()
return ConcatDataset(datasets=datasets)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""
Collate instances for supervised dataset.
Each instance is a tokenized dictionary with fields
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
for instance in instances
]
if self.tokenizer.padding_side == "right":
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=batch_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
labels = torch.nn.utils.rnn.pad_sequence(
sequences=batch_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
# pad to max
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
reversed_labels = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
else:
raise RuntimeError(
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
f"but now `{self.tokenizer.padding_side}`"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
class StatefulDistributedSampler(DistributedSampler):
"""
Stateful distributed sampler for multi-stage training.
"""
def __init__(
self,
dataset: DatasetType,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
self.start_index = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def setup_distributed_dataloader(
dataset: DatasetType,
batch_size: int = 1,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
process_group: Optional[ProcessGroup] = None,
**kwargs,
) -> DataLoader:
"""
    Set up the dataloader for distributed training.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset=dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
# Deterministic dataloader
def seed_worker(worker_id: int) -> None:
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
worker_init_fn=seed_worker,
**_kwargs,
)
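# A minimal usage sketch of the helpers above (illustrative only, not part of the training entry
# point): the tokenizer name, dataset path and batch size are assumptions, and `torch.distributed`
# is assumed to be initialized already (e.g. via `colossalai.launch_from_torch`).
def _example_build_dataloader() -> DataLoader:
    from transformers import LlamaTokenizer  # local import to keep the sketch self-contained

    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tokenizer.pad_token = tokenizer.unk_token  # the collator requires a valid, non-negative pad_token_id
    dataset = load_tokenized_dataset(dataset_paths=["/path/to/tokenized_dataset"], mode="train")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)
    dataloader = setup_distributed_dataloader(
        dataset=dataset,
        batch_size=4,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
    )
    # Resume mid-epoch by skipping samples that were already consumed before a restart.
    dataloader.sampler.set_start_index(start_index=0)
    return dataloader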

View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Splicing multiple pre-tokenized sequence data points
"""
import random
import warnings
from copy import deepcopy
from datasets import dataset_dict
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize(
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
) -> Dict[str, Union[int, str, List[int]]]:
"""
    A tokenization function that tokenizes an original pretraining data point of the following format:
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
"""
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
"add <bos> and <eos> manually later"
)
if ignore_index is None:
ignore_index = IGNORE_INDEX
source_text = data_point["source"] # `str`
target_text = data_point["target"] # `str`
is_null_source = len(source_text) == 0
source_text = tokenizer.bos_token + source_text
target_text += tokenizer.eos_token
sequence_text = source_text + target_text
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
sequence_input_ids = tokenized[1]
sequence_labels = deepcopy(sequence_input_ids)
source_length = len(tokenized[0])
if not is_null_source:
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
# sequence truncation.
if len(sequence_input_ids) > max_length:
sequence_input_ids = sequence_input_ids[:max_length]
sequence_labels = sequence_labels[:max_length]
return dict(
input_ids=sequence_input_ids,
labels=sequence_labels,
seq_length=len(sequence_input_ids),
seq_category=data_point["category"],
)
class ClosedToConstantLengthSplicedDataset(IterableDataset):
"""
    An iterable dataset that yields (close to) constant-length data points spliced from multiple
    original, independent (pre-tokenized) data points.
"""
def __init__(
self,
dataset: DSType,
tokenizer: PreTrainedTokenizer,
max_length: int = 4096,
num_packed_sequences: int = 8,
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
input_ids_field: str = "input_ids",
labels_field: str = "labels",
infinite: bool = False,
shuffle: bool = True,
error_strict: bool = False,
) -> None:
self.tokenizer = tokenizer
self.dataset = dataset
self.max_length = max_length
self.infinite = infinite
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
self.shuffle = shuffle
        # `Callable[[Dict[str, Any]], Tuple[List[int], List[int]]]`:
        # a function that fetches the sequence `input_ids` and `labels` from an original data point.
if fetch_sequence_func is None:
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
else:
self.fetch_sequence_func = fetch_sequence_func
self.input_ids_field = input_ids_field
self.labels_field = labels_field
self.error_strict = error_strict
self.current_size = 0 # `int`, current packed data size.
def __len__(self) -> int:
return len(self.dataset)
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
iterator = iter(self.dataset)
more_data_points = True
while more_data_points is True:
buffer, buffer_len = [], 0
while True:
# ending condition.
if buffer_len >= self.max_buffer_size:
break
try:
# `Tuple[List[int], List[int]]`
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
buffer_len += len(buffer[-1][self.input_ids_field])
except StopIteration:
if self.infinite is True:
iterator = iter(self.dataset)
warnings.warn("The dataset reached end and the iterator is reset to the start.")
else:
more_data_points = False
break
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
for i, data_point in enumerate(buffer):
# TODO(2023-09-18) check errors for each unspliced tokenized data point
seq_input_ids = data_point[self.input_ids_field]
seq_labels = data_point[self.labels_field]
# Handle special case:
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
# exceeds `max_length`, truncate it.
if len(seq_input_ids) > self.max_length:
truncated_seq_input_ids = seq_input_ids[: self.max_length]
truncated_label_ids = seq_labels[: self.max_length]
if set(truncated_label_ids) == {IGNORE_INDEX}:
if self.error_strict is True:
raise ValueError(
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
f"with all label values as {IGNORE_INDEX}."
)
else:
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
continue # Skip the current error data point.
spliced_data_point = {
self.input_ids_field: truncated_seq_input_ids,
self.labels_field: truncated_label_ids,
}
examples.append(spliced_data_point)
warnings.warn("Find a data point to be truncated.")
continue
            # Check whether the next sequence still fits into the current spliced data point.
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
spliced_data_point = {
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels,
} # `Dict[str, List[int]]`
# Update.
spliced_input_ids, spliced_labels = [], []
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
examples.append(spliced_data_point)
else:
spliced_input_ids.extend(seq_input_ids)
spliced_labels.extend(seq_labels)
# For residual spliced data point at the end of the data set
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
examples.append(
{
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels
}
)
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:
# TODO(2023-09-18): check errors for each spliced tokenized data point.
self.current_size += 1
yield spliced_data_point
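# A minimal usage sketch of the two utilities above (illustrative only): `load_dataset`, the JSONL
# path and the tokenizer name are assumptions about how the raw data is supplied.
def _example_build_spliced_dataset() -> ClosedToConstantLengthSplicedDataset:
    from datasets import load_dataset  # local import to keep the sketch self-contained

    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tokenizer.add_bos_token = False  # <bos>/<eos> are added manually inside `supervised_tokenize`
    tokenizer.add_eos_token = False
    # Each raw line is assumed to look like:
    # {"source": "", "target": "Beijing, the capital of ...", "category": "geography"}
    raw_dataset = load_dataset("json", data_files="/path/to/raw_data.jsonl", split="train")
    tokenized_dataset = raw_dataset.map(
        lambda data_point: supervised_tokenize(data_point, tokenizer=tokenizer, max_length=4096),
        keep_in_memory=False,
    )
    return ClosedToConstantLengthSplicedDataset(
        dataset=tokenized_dataset,
        tokenizer=tokenizer,
        max_length=4096,
        infinite=False,
        shuffle=True,
    )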

View File

@@ -0,0 +1,111 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Initialize a new model for an expanded tokenizer by averaging embeddings from the original model
"""
import argparse
import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_model_and_tokenizer_path",
type=str,
required=True,
default=None,
help="Source path of model & tokenizer",
)
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
args = parser.parse_args()
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
source_tokenizer.add_bos_token = False
source_tokenizer.add_eos_token = False
if source_tokenizer.pad_token is None:
source_tokenizer.pad_token = source_tokenizer.unk_token
source_vocab = source_tokenizer.get_vocab()
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
target_tokenizer.add_bos_token = False
target_tokenizer.add_eos_token = False
if target_tokenizer.pad_token is None:
target_tokenizer.pad_token = target_tokenizer.unk_token
target_vocab = target_tokenizer.get_vocab()
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
assert len(target_vocab) > len(
source_vocab
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
gpu_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
source_model.eval()
source_model = source_model.to(gpu_device)
source_input_embeddings = source_model.get_input_embeddings()
assert isinstance(source_input_embeddings, torch.nn.Embedding)
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
source_input_embeddings.eval()
source_output_embeddings = source_model.get_output_embeddings()
assert isinstance(source_output_embeddings, torch.nn.Linear)
assert source_output_embeddings.bias is None
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
source_output_embeddings.eval()
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
for i in range(len(source_vocab), len(target_vocab)):
if i % 500 == 0:
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
target_token = target_inverted_vocab[i]
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
target_to_source_input_embedding = (
source_input_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
target_to_source_output_embedding = (
source_output_embeddings.weight[target_to_source_token_ids]
.mean(dim=0)
.unsqueeze(dim=0)
.cpu()
.detach()
.numpy()
)
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
source_model = source_model.to(cpu_device)
assert isinstance(source_model, LlamaForCausalLM)
# expand
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
source_model = source_model.half()
source_model.save_pretrained(save_directory=args.target_model_path)
if __name__ == "__main__":
main()
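# Example invocation (illustrative; the script name and all paths are assumptions):
#   python init_model_with_expanded_tokenizer.py \
#       --source_model_and_tokenizer_path /path/to/Llama-2-7b-hf \
#       --target_tokenizer_path /path/to/expanded_tokenizer \
#       --target_model_path /path/to/initialized_expanded_model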

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
Initialize new tokenizer for continual pre-training
"""
import argparse
import os
import json
from typing import List, Union
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from colossalai.logging import get_dist_logger
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
logger = get_dist_logger()
def expand_vocab_tokenizer(
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
) -> None:
"""Expand tokenizer for continue pre-training."""
if os.path.exists(target_tokenizer_dir):
raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
logger.info(source_tokenizer)
source_sp_processor = source_tokenizer.sp_model
source_spm = sp_pb2_model.ModelProto()
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
# Add new tokens to source tokenizer.
source_spm_tokens = set([p.piece for p in source_spm.pieces])
for piece in new_tokens:
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
if piece in source_spm_tokens:
            # Skip tokens that already exist.
continue
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
source_spm.pieces.append(new_p)
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
# Save
os.makedirs(target_tokenizer_dir)
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
with open(file=target_tokenizer_model_path, mode="wb") as fp:
fp.write(source_spm.SerializeToString())
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
)
parser.add_argument(
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
)
parser.add_argument(
"--expand_tokens_file",
type=str,
required=True,
default=None,
help="Path of the file containing tokens to be extended",
)
args = parser.parse_args()
expand_tokens = []
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
for line in fp_reader:
item = json.loads(line)
# e.g., {"piece": "你好"}
token = item["piece"]
if token in expand_tokens:
continue
expand_tokens.append(token)
expand_tokens.sort(key=lambda t: len(t), reverse=False)
expand_vocab_tokenizer(
source_tokenizer_dir=args.source_tokenizer_dir,
target_tokenizer_dir=args.target_tokenizer_dir,
new_tokens=expand_tokens,
)
if __name__ == "__main__":
main()
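# Example invocation (illustrative; the script name and all paths are assumptions):
#   python expand_tokenizer_vocab.py \
#       --source_tokenizer_dir /path/to/Llama-2-7b-hf \
#       --target_tokenizer_dir /path/to/expanded_tokenizer \
#       --expand_tokens_file /path/to/new_tokens.jsonl
# where every line of new_tokens.jsonl is a JSON object such as {"piece": "你好"}.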

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helper functions for IO
"""
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
    Save data in JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
    Save the model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
    Load the model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
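# A minimal resume sketch built on the helpers above (illustrative only): `dataloader` is assumed
# to use the `StatefulDistributedSampler` from this commit so already-consumed samples can be skipped.
def _example_resume_training_states(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    dataloader: Any,
) -> Tuple[int, int]:
    start_epoch, start_step, sample_start_index = load_checkpoint(
        load_dir=load_dir,
        booster=booster,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    # Skip the samples already consumed within the interrupted epoch.
    dataloader.sampler.set_start_index(start_index=sample_start_index)
    return start_epoch, start_step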

View File

@@ -0,0 +1,216 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaAttention,
LlamaModel,
LlamaForCausalLM,
apply_rotary_pos_emb,
repeat_kv,
)
from colossalai.logging import get_dist_logger
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_kvpacked_func,
)
from flash_attn.ops.rms_norm import rms_norm
logger = get_dist_logger()
def _prepare_decoder_attention_mask(
self: LlamaModel,
attention_mask: torch.BoolTensor,
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]:
"""
    Prepare the decoder attention mask.
"""
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
tensors=(
torch.full(
size=(input_shape[0], past_key_values_length),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
) # (bsz, past_key_values_length + q_len)
if attention_mask is not None and torch.all(attention_mask):
return None # Faster
return attention_mask
def attention_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
"""
if output_attentions:
logger.warning(
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
"return `None` instead."
)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
q_slicing, kv_slicing = (
dim // self.config.pretraining_tp
for dim in (
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
)
) # `Tuple[int, int]`
q_slices, k_slices, v_slices = (
proj.weight.split(slicing, dim=0)
for proj, slicing in (
(self.q_proj, q_slicing),
(self.k_proj, kv_slicing),
(self.v_proj, kv_slicing),
)
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
q, k, v = (
torch.cat(
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
)
for slices in (q_slices, k_slices, v_slices)
)
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
else:
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
for states, num_heads in (
(q, self.num_heads),
(k, self.num_key_value_heads),
(v, self.num_key_value_heads),
)
)
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
past_kv_len = 0
if past_key_value is not None:
# if `past_key_value` is not None, `kv_len` > `q_len`.
past_kv_len = past_key_value[0].shape[-2]
kv_len += past_kv_len
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
cos, sin = self.rotary_emb(v, seq_len=kv_len)
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
past_key_value = (k, v) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
key_padding_mask = attention_mask
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
if past_kv_len > 0:
q = torch.cat(
tensors=(
torch.full(
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
fill_value=0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
if key_padding_mask is None:
# (bsz, past_kv_len + q_len, num_heads, head_dim)
        output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
else:
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
kv, _, cu_kv_lens, max_kv_len = unpad_input(
hidden_states=torch.stack(tensors=(k, v), dim=2),
attention_mask=key_padding_mask,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q=q,
kv=kv,
cu_seqlens_q=cu_q_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_q_len,
max_seqlen_k=max_kv_len,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
output = pad_input(
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
indices=indices,
batch=bsz,
seqlen=past_kv_len + q_len,
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
if past_kv_len > 0:
# Strip off the zero query outputs.
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
output = self.o_proj(output) # (bsz, q_len, hidden_size)
return output, None, past_key_value
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
"""
    Forward function for RMSNorm.
"""
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(attention_forward, module)
if isinstance(module, LlamaModel):
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
if isinstance(module, LlamaRMSNorm):
module.forward = MethodType(rms_norm_forward, module)
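# A minimal usage sketch of the patch above (illustrative only; the checkpoint path is an assumption).
# FlashAttention kernels require fp16/bf16 tensors on a CUDA device.
def _example_load_patched_llama(pretrained_path: str = "/path/to/Llama-2-7b-hf") -> LlamaForCausalLM:
    model = LlamaForCausalLM.from_pretrained(pretrained_path, torch_dtype=torch.float16)
    model = model.to("cuda")
    # Swap the attention, decoder-mask preparation and RMSNorm forwards in place.
    replace_with_flash_attention(model=model)
    return model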

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
"""Freeze all parameters except embeddings."""
for name, params in model.named_parameters():
if "embed_tokens" not in name and "lm_head" not in name:
params.requires_grad = False
else:
params.requires_grad = True
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Unfreeze all parameters of the model."""
    for params in model.parameters():
        params.requires_grad = True
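# A minimal two-stage sketch (illustrative only): first train only the (expanded) embeddings,
# then unfreeze everything for full continual pre-training.
def _example_two_stage_schedule(model: LlamaForCausalLM) -> None:
    freeze_non_embeds_parameters(model=model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Stage 1 (embeddings only): {trainable} trainable parameters")
    # ... run the embedding-only warm-up stage here ...
    unfreeze_parameters(model=model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Stage 2 (full model): {trainable} trainable parameters")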