Mirror of https://github.com/hpcaitech/ColossalAI.git
[coati] add chatglm model (#4539)
* update configuration of ChatGLM and add support in coati
* add unit test & update ChatGLM default config & fix bos index issue
* remove ChatGLM due to OOM
* add dataset pkg in requirement-text
* fix parameter issue in test_models
* add ref in tokenize & remove unnecessary parts
* separate source & target tokenization in ChatGLM
* add unit test for ChatGLM
* fix test dataset issue
* update truncation of ChatGLM
* fix ColossalAI version
* fix ColossalAI version in test
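The "separate source & target tokenization" item refers to the usual SFT pattern of encoding the prompt (source) and the reply (target) independently, so the loss can be masked on prompt tokens. A minimal sketch of that pattern, with an illustrative function name and the conventional -100 ignore index (neither taken verbatim from the coati source):

    IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss

    def tokenize_example(tokenizer, source: str, target: str, max_length: int = 512):
        # Encode prompt and reply separately so prompt tokens can be masked.
        source_ids = tokenizer.encode(source, add_special_tokens=False)
        target_ids = tokenizer.encode(target, add_special_tokens=False)
        # Truncate from the right, mirroring the "update truncation" item above.
        input_ids = (source_ids + target_ids)[:max_length]
        labels = ([IGNORE_INDEX] * len(source_ids) + target_ids)[:max_length]
        return {"input_ids": input_ids, "labels": labels}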
@@ -9,13 +9,15 @@ from coati.models.bloom import BLOOMActor
 from coati.models.gpt import GPTActor
 from coati.models.llama import LlamaActor
 from coati.models.opt import OPTActor
+from coati.models.chatglm import ChatGLMActor
 from coati.trainer import SFTTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
 from datasets import load_dataset
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 from transformers.trainer import get_scheduler
@@ -58,6 +60,8 @@ def train(args):
             model = LlamaActor(pretrained=args.pretrain,
                                lora_rank=args.lora_rank,
                                checkpoint=args.grad_checkpoint)
+        elif args.model == 'chatglm':
+            model = ChatGLMActor(pretrained=args.pretrain)
         else:
             raise ValueError(f'Unsupported model "{args.model}"')
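Note that the new branch passes only pretrained: unlike the LlamaActor call above, lora_rank and gradient checkpointing are not wired through to ChatGLMActor. In isolation the construction looks like this (a sketch; the checkpoint name mirrors the tokenizer default added below):

    from coati.models.chatglm import ChatGLMActor

    # Load ChatGLM-6B as a coati actor; no LoRA or checkpointing kwargs here.
    model = ChatGLMActor(pretrained="THUDM/chatglm-6b")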
@@ -81,6 +85,9 @@ def train(args):
             "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
         tokenizer.eos_token = '<\s>'
         tokenizer.pad_token = tokenizer.unk_token
+    elif args.model == 'chatglm':
+        tokenizer = ChatGLMTokenizer.from_pretrained(
+            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
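The bundled ChatGLMTokenizer follows the standard transformers tokenizer interface, so it can be exercised on its own; a quick usage sketch (the sample text and the kwargs beyond those shown in the diff are assumptions):

    from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

    tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    # Standard transformers-style call; truncation bounds the sequence length.
    batch = tokenizer("Hello, ChatGLM!", max_length=64, truncation=True, return_tensors="pt")
    print(batch["input_ids"].shape)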
@@ -99,7 +106,6 @@ def train(args):
         optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
     else:
         optim = Adam(model.parameters(), lr=args.lr)

-    logger = get_dist_logger()

     # configure dataset
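This hunk also drops an unused get_dist_logger() handle. For reference, HybridAdam is ColossalAI's fused Adam that can host optimizer states on CPU and GPU, which is why the script prefers it under the ColossalAI strategies; a standalone sketch (the learning rate here is illustrative):

    from colossalai.nn.optimizer import HybridAdam

    # Fused Adam with built-in gradient clipping, as used in the script above.
    optim = HybridAdam(model.parameters(), lr=5e-6, clipping_norm=1.0)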
@@ -185,7 +191,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
                         default='colossalai_zero2')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default=None)
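With 'chatglm' registered as a choice, the new path is selected entirely from the command line; a single-GPU launch of the SFT example might look like `torchrun --standalone --nproc_per_node=1 train_sft.py --model chatglm --strategy colossalai_zero2 --pretrain THUDM/chatglm-6b --dataset <path>` (the script name and the flag values beyond --model are illustrative; --tokenizer can stay unset to pick up the THUDM/chatglm-6b default shown above).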