chore: bump requirements.txt for guanaco support

LBYPatrick 2023-06-16 11:45:42 +08:00
parent 3ccd939fe3
commit 1580aaa8eb
4 changed files with 14 additions and 6 deletions

View File

@@ -43,6 +43,9 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    # TODO Support baichuan-7b
+    # "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
     "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
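For orientation, a minimal sketch of how an entry in this registry might be resolved at runtime; get_model_path and the MODEL_PATH value are illustrative assumptions, not part of this commit:

import os

MODEL_PATH = "./models"  # assumed value; the real path comes from project configuration

LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "proxyllm": "proxyllm",
}

def get_model_path(name: str) -> str:
    # Hypothetical helper: look the model name up, fail loudly on typos.
    try:
        return LLM_MODEL_CONFIG[name]
    except KeyError:
        raise ValueError(f"Unknown model name: {name}") from None

print(get_model_path("guanaco-33b-merged"))  # ./models/guanaco-33b-merged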

View File

@@ -32,9 +32,9 @@ class BaseLLMAdaper:
         return True

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+            model_path, low_cpu_mem_usage=True, trust_remote_code=True, **from_pretrained_kwargs
         )
         return model, tokenizer
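The trust_remote_code=True flag is what lets transformers execute the custom modeling code that some checkpoints (falcon, baichuan, and similar) ship alongside their weights. A self-contained sketch of the loader pattern above, assuming transformers>=4.30 and an illustrative local checkpoint path:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./models/guanaco-33b-merged"  # assumed local checkpoint directory

# Without trust_remote_code=True, loading a repo that defines its own
# modeling classes raises an error instead of importing that code.
tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_fast=False, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, trust_remote_code=True
)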
@@ -57,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
     raise ValueError(f"Invalid model adapter for {model_path}")

-# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
+# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?

 class VicunaLLMAdapater(BaseLLMAdaper):
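For context, a hedged sketch of the adapter-lookup pattern that get_llm_model_adapter implies; the llm_model_adapters registry name and the demo call are assumptions, not code from this commit:

class BaseLLMAdaper:
    # Fallback adapter; subclasses override match() to claim specific models.
    def match(self, model_path: str) -> bool:
        return True

llm_model_adapters = [BaseLLMAdaper()]  # assumed registry; the real list is populated elsewhere

def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    # First adapter whose match() accepts the path wins.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")

print(type(get_llm_model_adapter("./models/guanaco-33b-merged")).__name__)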

View File

@@ -5,6 +5,7 @@ import asyncio
 import json
 import os
 import sys
+import traceback

 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
@@ -89,10 +90,14 @@ class ModelWorker:
             ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
             yield json.dumps(ret).encode() + b"\0"
         except Exception as e:
+            msg = "{}: {}".format(str(e), traceback.format_exc())
             ret = {
-                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
                 "error_code": 0,
             }
             yield json.dumps(ret).encode() + b"\0"

     def get_embeddings(self, prompt):
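The added traceback import pays off here: the streamed error chunk now carries the full stack trace instead of just str(e). A standalone sketch of that pattern, with illustrative names throughout:

import json
import traceback

def generate_stream():
    try:
        raise RuntimeError("boom")  # stand-in for the real model.generate(...) call
    except Exception as e:
        # Attach the full traceback so the client sees where the failure came from.
        msg = "{}: {}".format(str(e), traceback.format_exc())
        ret = {
            "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
            "error_code": 0,
        }
        yield json.dumps(ret).encode() + b"\0"  # null byte delimits streamed chunks

for chunk in generate_stream():
    print(chunk)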

View File

@@ -4,7 +4,7 @@ aiohttp==3.8.4
 aiosignal==1.3.1
 async-timeout==4.0.2
 attrs==22.2.0
-bitsandbytes==0.37.0
+bitsandbytes==0.39.0
 cchardet==2.1.7
 chardet==5.1.0
 contourpy==1.0.7
@@ -27,7 +27,7 @@ python-dateutil==2.8.2
 pyyaml==6.0
 tokenizers==0.13.2
 tqdm==4.64.1
-transformers==4.28.0
+transformers==4.30.0
 timm==0.6.13
 spacy==3.5.3
 webdataset==0.2.48
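A quick sanity check that the two bumped pins actually landed in the environment; this sketch assumes Python 3.8+ (for importlib.metadata) and is not part of the commit:

from importlib.metadata import PackageNotFoundError, version

for pkg, expected in (("bitsandbytes", "0.39.0"), ("transformers", "4.30.0")):
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not installed"
    print(f"{pkg}: pinned {expected}, found {installed}")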