Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
Commit message: fix wrong EOS token in ColossalChat
Committed by: Zian(Andy) Zheng
Parent commit: 70885d707d
This commit: 43ad0d9ef0
@@ -118,7 +118,7 @@ def main(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == "llama":
         tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
|
@@ -68,7 +68,7 @@ def train(args):
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -39,7 +39,7 @@ def eval(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == "llama":
         tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -125,7 +125,7 @@ def main(args):
         tokenizer = LlamaTokenizer.from_pretrained(
             "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
         )
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -72,7 +72,7 @@ def train(args):
         tokenizer = LlamaTokenizer.from_pretrained(
             "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
         )
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -75,7 +75,7 @@ def train(args):
         tokenizer = LlamaTokenizer.from_pretrained(
             "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
         )
-        tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
         tokenizer.pad_token = tokenizer.unk_token
     elif args.model == "chatglm":
         tokenizer = ChatGLMTokenizer.from_pretrained(
[hunk truncated in the captured page — remainder of the diff not shown]
Reference in New Issue
Block a user