llms: add cpu support

csunny 2023-05-21 16:05:53 +08:00
parent f52c7523b5
commit 89970bd71c
3 changed files with 16 additions and 5 deletions

View File

@@ -9,6 +9,8 @@ from transformers import (
     AutoModel
 )
 
+from pilot.configs.model_config import DEVICE
+
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT """
@@ -61,13 +63,20 @@ class ChatGLMAdapater(BaseLLMAdaper):
     """LLM Adatpter for THUDM/chatglm-6b"""
 
     def match(self, model_path: str):
         return "chatglm" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = AutoModel.from_pretrained(
-            model_path, trust_remote_code=True, **from_pretrained_kwargs
-        ).half().cuda()
-        return model, tokenizer
+        if DEVICE != "cuda":
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).float()
+            return model, tokenizer
+        else:
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).half().cuda()
+            return model, tokenizer
 
 class CodeGenAdapter(BaseLLMAdaper):
     pass
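For context, DEVICE is imported above but not defined in this diff. A minimal sketch of how it is presumably derived in pilot/configs/model_config.py (an assumption, not shown in this commit), so that CPU-only hosts automatically take the new float32 branch:

    # Assumed sketch, not part of this commit: prefer CUDA when a GPU is visible,
    # otherwise fall back to "cpu" so ChatGLMAdapater.loader() returns a float32 model.
    import torch

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"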

View File

@ -155,6 +155,7 @@ if __name__ == "__main__":
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
print(model_path, DEVICE)
worker = ModelWorker(
model_path=model_path,
model_name=CFG.LLM_MODEL,
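A hedged sketch (not the project's actual ModelWorker internals) of how the worker side can pick this adapter and end up on the CPU or CUDA path; the import path pilot.model.adapter and the load_model helper are assumptions for illustration:

    # Illustrative only: shows how match()/loader() from the adapter diff above could be used.
    from pilot.configs.model_config import DEVICE
    from pilot.model.adapter import ChatGLMAdapater  # module path assumed

    def load_model(model_path: str):
        adapter = ChatGLMAdapater()
        if adapter.match(model_path):
            # On a CPU-only host DEVICE != "cuda", so loader() returns a float32 model.
            return adapter.loader(model_path, from_pretrained_kwargs={})
        raise ValueError(f"no adapter matched {model_path}")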

View File

@@ -42,6 +42,7 @@ tenacity==8.2.2
 peft
 pycocoevalcap
 sentence-transformers
+cpm_kernels
 umap-learn
 notebook
 gradio==3.23
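A quick smoke test for the new CPU path (hedged: the model id and prompt are illustrative; chatglm-6b's remote code provides a chat() helper, and cpm_kernels added above is presumably pulled in for ChatGLM's kernels):

    # Hedged CPU smoke test: load chatglm-6b in float32 with no CUDA device and run one turn.
    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm-6b"  # illustrative; the server resolves this via LLM_MODEL_CONFIG
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval()
    response, _history = model.chat(tokenizer, "Hello", history=[])
    print(response)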