mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
initial commit: add colossal llama 2 (#4784)
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
219
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
Normal file
219
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
||||
|
||||
import torch
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from datasets import Dataset as HFDataset
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import torch.nn.functional as F
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def load_tokenized_dataset(
|
||||
dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
|
||||
) -> Optional[DatasetType]:
|
||||
"""
|
||||
Load pre-tokenized dataset.
|
||||
Each instance of dataset is a dictionary with
|
||||
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
|
||||
"""
|
||||
mode_map = {"train": "train", "dev": "validation", "test": "test"}
|
||||
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
|
||||
|
||||
if isinstance(dataset_paths, (str, os.PathLike)):
|
||||
dataset_paths = [dataset_paths]
|
||||
|
||||
datasets = [] # `List[datasets.dataset_dict.Dataset]`
|
||||
for ds_path in dataset_paths:
|
||||
ds_path = os.path.abspath(ds_path)
|
||||
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
|
||||
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
|
||||
if isinstance(ds_dict, HFDataset):
|
||||
datasets.append(ds_dict)
|
||||
else:
|
||||
if mode_map[mode] in ds_dict:
|
||||
datasets.append(ds_dict[mode_map[mode]])
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
if len(datasets) == 1:
|
||||
return datasets.pop()
|
||||
return ConcatDataset(datasets=datasets)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""
|
||||
Collate instances for supervised dataset.
|
||||
Each instance is a tokenized dictionary with fields
|
||||
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
max_length: int = 4096
|
||||
ignore_index: int = -100
|
||||
|
||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
instances (`Sequence[Dict[str, List[int]]]`):
|
||||
Mini-batch samples, each sample is stored in an individual dictionary.
|
||||
|
||||
Returns:
|
||||
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
|
||||
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
|
||||
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
|
||||
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
|
||||
"""
|
||||
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
|
||||
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
|
||||
f"but now `{self.tokenizer.pad_token_id}`"
|
||||
)
|
||||
|
||||
# `List[torch.Tensor]`
|
||||
batch_input_ids = [
|
||||
torch.LongTensor(instance["input_ids"][: self.max_length])
|
||||
if len(instance["input_ids"]) > self.max_length
|
||||
else torch.LongTensor(instance["input_ids"])
|
||||
for instance in instances
|
||||
]
|
||||
batch_labels = [
|
||||
torch.LongTensor(instance["labels"][: self.max_length])
|
||||
if len(instance["labels"]) > self.max_length
|
||||
else torch.LongTensor(instance["labels"])
|
||||
for instance in instances
|
||||
]
|
||||
|
||||
if self.tokenizer.padding_side == "right":
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=batch_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
elif self.tokenizer.padding_side == "left":
|
||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
) # (bsz, max_len)
|
||||
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
|
||||
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
|
||||
reversed_labels = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=reversed_labels,
|
||||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
|
||||
f"but now `{self.tokenizer.padding_side}`"
|
||||
)
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
|
||||
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Stateful distributed sampler for multi-stage training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DatasetType,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dataset=dataset,
|
||||
num_replicas=num_replicas,
|
||||
rank=rank,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
self.start_index = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def setup_distributed_dataloader(
|
||||
dataset: DatasetType,
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Setup dataloader for distributed training.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset=dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id: int) -> None:
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=drop_last,
|
||||
worker_init_fn=seed_worker,
|
||||
**_kwargs,
|
||||
)
|
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Splicing multiple pre-tokenized sequence data points
|
||||
"""
|
||||
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from datasets import dataset_dict
|
||||
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
|
||||
|
||||
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
|
||||
|
||||
def supervised_tokenize(
|
||||
data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original pretraining data point as following:
|
||||
{"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
|
||||
"""
|
||||
assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
|
||||
"Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
|
||||
"add <bos> and <eos> manually later"
|
||||
)
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
source_text = data_point["source"] # `str`
|
||||
target_text = data_point["target"] # `str`
|
||||
is_null_source = len(source_text) == 0
|
||||
|
||||
source_text = tokenizer.bos_token + source_text
|
||||
target_text += tokenizer.eos_token
|
||||
sequence_text = source_text + target_text
|
||||
|
||||
tokenized = tokenizer([source_text, sequence_text])["input_ids"]
|
||||
sequence_input_ids = tokenized[1]
|
||||
sequence_labels = deepcopy(sequence_input_ids)
|
||||
|
||||
source_length = len(tokenized[0])
|
||||
if not is_null_source:
|
||||
sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
|
||||
|
||||
# sequence truncation.
|
||||
if len(sequence_input_ids) > max_length:
|
||||
sequence_input_ids = sequence_input_ids[:max_length]
|
||||
sequence_labels = sequence_labels[:max_length]
|
||||
|
||||
return dict(
|
||||
input_ids=sequence_input_ids,
|
||||
labels=sequence_labels,
|
||||
seq_length=len(sequence_input_ids),
|
||||
seq_category=data_point["category"],
|
||||
)
|
||||
|
||||
|
||||
class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||
"""
|
||||
Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
|
||||
original independent (pre-tokenized) data points.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DSType,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 4096,
|
||||
num_packed_sequences: int = 8,
|
||||
fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
|
||||
input_ids_field: str = "input_ids",
|
||||
labels_field: str = "labels",
|
||||
infinite: bool = False,
|
||||
shuffle: bool = True,
|
||||
error_strict: bool = False,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = dataset
|
||||
self.max_length = max_length
|
||||
self.infinite = infinite
|
||||
self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
|
||||
self.shuffle = shuffle
|
||||
|
||||
# Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
|
||||
# A function that fetch sequence input_ids and labels from the original data point
|
||||
if fetch_sequence_func is None:
|
||||
self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
|
||||
else:
|
||||
self.fetch_sequence_func = fetch_sequence_func
|
||||
self.input_ids_field = input_ids_field
|
||||
self.labels_field = labels_field
|
||||
|
||||
self.error_strict = error_strict
|
||||
self.current_size = 0 # `int`, current packed data size.
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self) -> Iterable[Dict[str, List[int]]]:
|
||||
iterator = iter(self.dataset)
|
||||
more_data_points = True
|
||||
while more_data_points is True:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
# ending condition.
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
try:
|
||||
# `Tuple[List[int], List[int]]`
|
||||
seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
|
||||
buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
|
||||
buffer_len += len(buffer[-1][self.input_ids_field])
|
||||
except StopIteration:
|
||||
if self.infinite is True:
|
||||
iterator = iter(self.dataset)
|
||||
warnings.warn("The dataset reached end and the iterator is reset to the start.")
|
||||
else:
|
||||
more_data_points = False
|
||||
break
|
||||
examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
|
||||
spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
|
||||
for i, data_point in enumerate(buffer):
|
||||
# TODO(2023-09-18) check errors for each unspliced tokenized data point
|
||||
seq_input_ids = data_point[self.input_ids_field]
|
||||
seq_labels = data_point[self.labels_field]
|
||||
# Handle special case:
|
||||
# If the length of an original data point (i.e., input_ids length of a data point before splicing)
|
||||
# exceeds `max_length`, truncate it.
|
||||
if len(seq_input_ids) > self.max_length:
|
||||
truncated_seq_input_ids = seq_input_ids[: self.max_length]
|
||||
truncated_label_ids = seq_labels[: self.max_length]
|
||||
if set(truncated_label_ids) == {IGNORE_INDEX}:
|
||||
if self.error_strict is True:
|
||||
raise ValueError(
|
||||
f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
|
||||
f"with all label values as {IGNORE_INDEX}."
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
|
||||
continue # Skip the current error data point.
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: truncated_seq_input_ids,
|
||||
self.labels_field: truncated_label_ids,
|
||||
}
|
||||
examples.append(spliced_data_point)
|
||||
warnings.warn("Find a data point to be truncated.")
|
||||
continue
|
||||
|
||||
# Pre action judgment.
|
||||
if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
|
||||
spliced_data_point = {
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels,
|
||||
} # `Dict[str, List[int]]`
|
||||
# Update.
|
||||
spliced_input_ids, spliced_labels = [], []
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
examples.append(spliced_data_point)
|
||||
else:
|
||||
spliced_input_ids.extend(seq_input_ids)
|
||||
spliced_labels.extend(seq_labels)
|
||||
# For residual spliced data point at the end of the data set
|
||||
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
||||
examples.append(
|
||||
{
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels
|
||||
}
|
||||
)
|
||||
if self.shuffle:
|
||||
random.shuffle(examples)
|
||||
for spliced_data_point in examples:
|
||||
# TODO(2023-09-18): check errors for each spliced tokenized data point.
|
||||
self.current_size += 1
|
||||
yield spliced_data_point
|
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new model with updated tokenizer by calculating the mean values from original model
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_model_and_tokenizer_path",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Source path of model & tokenizer",
|
||||
)
|
||||
parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
|
||||
parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
|
||||
args = parser.parse_args()
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_tokenizer.add_bos_token = False
|
||||
source_tokenizer.add_eos_token = False
|
||||
if source_tokenizer.pad_token is None:
|
||||
source_tokenizer.pad_token = source_tokenizer.unk_token
|
||||
source_vocab = source_tokenizer.get_vocab()
|
||||
|
||||
target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
|
||||
target_tokenizer.add_bos_token = False
|
||||
target_tokenizer.add_eos_token = False
|
||||
if target_tokenizer.pad_token is None:
|
||||
target_tokenizer.pad_token = target_tokenizer.unk_token
|
||||
target_vocab = target_tokenizer.get_vocab()
|
||||
target_inverted_vocab = {v: k for k, v in target_vocab.items()}
|
||||
|
||||
assert len(target_vocab) > len(
|
||||
source_vocab
|
||||
), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
|
||||
|
||||
gpu_device = torch.device("cuda:0")
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
|
||||
source_model.eval()
|
||||
source_model = source_model.to(gpu_device)
|
||||
|
||||
source_input_embeddings = source_model.get_input_embeddings()
|
||||
assert isinstance(source_input_embeddings, torch.nn.Embedding)
|
||||
assert source_input_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_input_embeddings.eval()
|
||||
|
||||
source_output_embeddings = source_model.get_output_embeddings()
|
||||
assert isinstance(source_output_embeddings, torch.nn.Linear)
|
||||
assert source_output_embeddings.bias is None
|
||||
assert source_output_embeddings.weight.shape[0] == len(source_vocab)
|
||||
source_output_embeddings.eval()
|
||||
|
||||
input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
|
||||
output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
|
||||
for i in range(len(source_vocab), len(target_vocab)):
|
||||
if i % 500 == 0:
|
||||
logger.info(f"processing {i}/{len(target_vocab)} target tokens")
|
||||
target_token = target_inverted_vocab[i]
|
||||
target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
|
||||
target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
|
||||
|
||||
target_to_source_input_embedding = (
|
||||
source_input_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
target_to_source_output_embedding = (
|
||||
source_output_embeddings.weight[target_to_source_token_ids]
|
||||
.mean(dim=0)
|
||||
.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
|
||||
output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
|
||||
|
||||
source_model = source_model.to(cpu_device)
|
||||
assert isinstance(source_model, LlamaForCausalLM)
|
||||
|
||||
# expand
|
||||
source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
|
||||
source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
|
||||
source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
|
||||
|
||||
source_model = source_model.half()
|
||||
source_model.save_pretrained(save_directory=args.target_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Initialize new tokenizer for continual pre-training
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def expand_vocab_tokenizer(
|
||||
source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
|
||||
) -> None:
|
||||
"""Expand tokenizer for continue pre-training."""
|
||||
if os.path.exists(target_tokenizer_dir):
|
||||
raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
|
||||
|
||||
source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
|
||||
logger.info(source_tokenizer)
|
||||
source_sp_processor = source_tokenizer.sp_model
|
||||
source_spm = sp_pb2_model.ModelProto()
|
||||
source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
|
||||
|
||||
logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
|
||||
|
||||
# Add new tokens to source tokenizer.
|
||||
source_spm_tokens = set([p.piece for p in source_spm.pieces])
|
||||
for piece in new_tokens:
|
||||
assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
|
||||
if piece in source_spm_tokens:
|
||||
# Skip existed token.
|
||||
continue
|
||||
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
||||
new_p.piece = piece
|
||||
new_p.score = 0
|
||||
source_spm.pieces.append(new_p)
|
||||
logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
|
||||
|
||||
# Save
|
||||
os.makedirs(target_tokenizer_dir)
|
||||
target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
|
||||
with open(file=target_tokenizer_model_path, mode="wb") as fp:
|
||||
fp.write(source_spm.SerializeToString())
|
||||
|
||||
target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
|
||||
target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
|
||||
logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expand_tokens_file",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="Path of the file containing tokens to be extended",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
expand_tokens = []
|
||||
with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
|
||||
for line in fp_reader:
|
||||
item = json.loads(line)
|
||||
# e.g., {"piece": "你好"}
|
||||
token = item["piece"]
|
||||
if token in expand_tokens:
|
||||
continue
|
||||
expand_tokens.append(token)
|
||||
expand_tokens.sort(key=lambda t: len(t), reverse=False)
|
||||
|
||||
expand_vocab_tokenizer(
|
||||
source_tokenizer_dir=args.source_tokenizer_dir,
|
||||
target_tokenizer_dir=args.target_tokenizer_dir,
|
||||
new_tokens=expand_tokens,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Helper functions for IO
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Save as JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="w", encoding="utf-8") as fp:
|
||||
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
save_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
|
||||
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
|
||||
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
load_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
|
||||
return (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
@@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaRMSNorm,
|
||||
LlamaAttention,
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
)
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attetion mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Formard function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
18
applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
Normal file
18
applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from transformers.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
|
||||
"""Freeze all parameters except embeddings."""
|
||||
for name, params in model.named_parameters():
|
||||
if "embed_tokens" not in name and "lm_head" not in name:
|
||||
params.requires_grad = False
|
||||
else:
|
||||
params.requires_grad = True
|
||||
|
||||
|
||||
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
|
||||
for name, params in model.named_parameters():
|
||||
params.requires_grad = False
|
Reference in New Issue
Block a user