From 0bb0246fb1fc8800f6bccf013351400ea4d3f768 Mon Sep 17 00:00:00 2001
From: LBYPatrick
Date: Fri, 16 Jun 2023 12:23:28 +0800
Subject: [PATCH] chore: run black against modified code

---
 pilot/configs/model_config.py | 3 +--
 pilot/model/adapter.py        | 9 +++++++--
 pilot/server/llmserver.py     | 3 +--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e82a459a3..f6d59e4e1 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -43,9 +43,8 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
-    # TODO Support baichuan-7b
-    #"baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
+    # "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
     "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 2b93c5c9c..9ea80fb7a 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -32,9 +32,14 @@ class BaseLLMAdaper:
         return True
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False,trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True
+        )
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, low_cpu_mem_usage=True, trust_remote_code=True, **from_pretrained_kwargs
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
         )
         return model, tokenizer
 
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 68a3545a3..51339f322 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -90,8 +90,7 @@ class ModelWorker:
             ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
             yield json.dumps(ret).encode() + b"\0"
         except Exception as e:
-
-            msg = "{}: {}".format(str(e),traceback.format_exc())
+            msg = "{}: {}".format(str(e), traceback.format_exc())
             ret = {
                 "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
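
For context, a minimal usage sketch of the loader method reformatted above, assuming only the BaseLLMAdaper class and loader(model_path, from_pretrained_kwargs) signature shown in the diff; the model path and the torch_dtype kwarg are illustrative placeholders, not part of this patch:

    import torch

    from pilot.model.adapter import BaseLLMAdaper

    # Hypothetical example: load one of the models listed in LLM_MODEL_CONFIG.
    # The path below is a placeholder; the dict argument is forwarded verbatim
    # to AutoModelForCausalLM.from_pretrained as **from_pretrained_kwargs.
    adapter = BaseLLMAdaper()
    model, tokenizer = adapter.loader(
        "/data/models/guanaco-33b-merged",   # placeholder model path
        {"torch_dtype": torch.float16},      # assumed kwarg for illustration
    )

Since the patch only reflows these calls through black, the sketch behaves the same before and after the change.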