From 913c920ecc61285b9c47a1539bcb2cb6fabf94a0 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 15 May 2024 10:52:11 +0800 Subject: [PATCH] [Colossal-LLaMA] Fix sft issue for llama2 (#5719) * fix minor issue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/Colossal-LLaMA/prepare_sft_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py index be5f9bcca..a857d6c0c 100644 --- a/applications/Colossal-LLaMA/prepare_sft_dataset.py +++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py @@ -10,7 +10,7 @@ import math import os from multiprocessing import cpu_count -from colossal_llama.dataset.conversation import default_conversation +from colossal_llama.dataset.conversation import LLaMA2_Conv from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft from datasets import dataset_dict, load_dataset from transformers import AddedToken, AutoTokenizer @@ -78,6 +78,7 @@ def main(): # Fix split issue: https://github.com/huggingface/transformers/issues/23833 if args.llama_version == 2: tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) + default_conversation = LLaMA2_Conv tokenizer.add_bos_token = False tokenizer.add_eos_token = False