Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature] Support LLaMA-3 CPT and ST (#5619)
* support LLaMA-3
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Run pre-commit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
Initialize a new tokenizer for continual pre-training.
"""

import argparse
import json
import os
from typing import List, Union

from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from colossalai.logging import get_dist_logger

# Use the pure-Python protobuf implementation so the SentencePiece model proto can be parsed.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

logger = get_dist_logger()


def expand_vocab_tokenizer(
    source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
) -> None:
    """Expand the tokenizer vocabulary for continual pre-training."""
    if os.path.exists(target_tokenizer_dir):
        raise RuntimeError(f"Target tokenizer directory already exists: {target_tokenizer_dir}")

    # Load the source tokenizer and its underlying SentencePiece model proto.
    source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
    logger.info(source_tokenizer)
    source_sp_processor = source_tokenizer.sp_model
    source_spm = sp_pb2_model.ModelProto()
    source_spm.ParseFromString(source_sp_processor.serialized_model_proto())

    logger.info(f"Source tokenizer size: {len(source_sp_processor)}")

    # Add new tokens to the source tokenizer.
    source_spm_tokens = {p.piece for p in source_spm.pieces}
    for piece in new_tokens:
        assert isinstance(piece, str), f"Invalid token ({piece}) of type {type(piece)}"
        if piece in source_spm_tokens:
            # Skip tokens that already exist in the vocabulary.
            continue
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        source_spm.pieces.append(new_p)
    logger.info(f"Expanded vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")

    # Save the expanded SentencePiece model and wrap it in a LlamaTokenizer.
    os.makedirs(target_tokenizer_dir)
    target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
    with open(file=target_tokenizer_model_path, mode="wb") as fp:
        fp.write(source_spm.SerializeToString())

    target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
    target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
    logger.info(f"Successfully saved the expanded tokenizer to {target_tokenizer_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
    )
    parser.add_argument(
        "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
    )
    parser.add_argument(
        "--expand_tokens_file",
        type=str,
        required=True,
        default=None,
        help="Path of the JSON-lines file containing the tokens to be added",
    )
    args = parser.parse_args()

    # Collect unique tokens from the JSON-lines file; each line looks like {"piece": "你好"}.
    expand_tokens = []
    with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
        for line in fp_reader:
            item = json.loads(line)
            token = item["piece"]
            if token in expand_tokens:
                continue
            expand_tokens.append(token)
    # Append shorter pieces first.
    expand_tokens.sort(key=len)

    expand_vocab_tokenizer(
        source_tokenizer_dir=args.source_tokenizer_dir,
        target_tokenizer_dir=args.target_tokenizer_dir,
        new_tokens=expand_tokens,
    )


if __name__ == "__main__":
    main()
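For context, the script expects --expand_tokens_file to be a JSON-lines file with one {"piece": ...} object per line. The sketch below is not part of the commit; the file names, example pieces, and the script name used in the invocation are placeholders, since the diff does not show the file path.

# A minimal sketch of preparing the expand-tokens file; "expand_tokens.jsonl" and the
# example pieces are arbitrary placeholders.
import json

new_pieces = ["你好", "世界"]  # replace with the tokens you actually want to add
with open("expand_tokens.jsonl", mode="w", encoding="utf-8") as fp:
    for piece in new_pieces:
        fp.write(json.dumps({"piece": piece}, ensure_ascii=False) + "\n")

# The committed script (its file name is not shown in this diff) would then be run
# along the lines of:
#   python init_tokenizer.py \
#       --source_tokenizer_dir /path/to/source/tokenizer \
#       --target_tokenizer_dir /path/to/expanded/tokenizer \
#       --expand_tokens_file expand_tokens.jsonl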
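The script only writes the expanded tokenizer; before continual pre-training, the base model's token embeddings generally have to be resized to match the new vocabulary. This step is not part of the commit shown here. A minimal sketch, assuming standard transformers APIs and placeholder paths:

# Sanity-check the expanded tokenizer and resize the model embeddings; all paths are
# placeholders and this snippet is not part of the committed script.
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/path/to/expanded/tokenizer")
print(f"Expanded vocab size: {len(tokenizer)}")
print("你好" in tokenizer.get_vocab())  # an added piece should now appear in the vocabulary

model = AutoModelForCausalLM.from_pretrained("/path/to/source/model")
model.resize_token_embeddings(len(tokenizer))  # new embedding rows are randomly initialized
model.save_pretrained("/path/to/resized/model")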