From 1bbed02a3c641f0653faf056a35ec04cb4bc84f8 Mon Sep 17 00:00:00 2001
From: csunny
Date: Tue, 30 May 2023 18:55:53 +0800
Subject: [PATCH 1/2] fix: token

---
 pilot/model/proxy_llm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/pilot/model/proxy_llm.py b/pilot/model/proxy_llm.py
index 09c7a82ed..78425bd83 100644
--- a/pilot/model/proxy_llm.py
+++ b/pilot/model/proxy_llm.py
@@ -18,7 +18,8 @@ def proxyllm_generate_stream(
     stop = params.get("stop", "###")
 
     headers = {
-        "Authorization": "Bearer " + CFG.proxy_api_key
+        "Authorization": "Bearer " + CFG.proxy_api_key,
+        "Token": CFG.proxy_api_key
     }
 
     messages = prompt.split(stop)
@@ -44,9 +45,13 @@ def proxyllm_generate_stream(
         "max_tokens": params.get("max_new_tokens"),
     }
 
+    print(payloads)
+    print(headers)
     res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
 
     text = ""
+    print("====================================res================")
+    print(res)
     for line in res.iter_lines():
         if line:
             decoded_line = line.decode('utf-8')

From 16c6986666d7141ee1707d59a9a0ddafcf2e5bd3 Mon Sep 17 00:00:00 2001
From: csunny
Date: Tue, 30 May 2023 19:11:34 +0800
Subject: [PATCH 2/2] fix: lint

---
 pilot/model/adapter.py       |  8 ++++++--
 pilot/model/loader.py        |  6 ++++--
 pilot/model/proxy_llm.py     | 33 +++++++++++++++++----------------
 pilot/server/chat_adapter.py |  9 +++++----
 pilot/server/llmserver.py    |  2 +-
 pilot/server/webserver.py    |  2 +-
 6 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 8640257b6..0ff368c70 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
+
     pass
 
+
 class CodeGenAdapter(BaseLLMAdaper):
     pass
 
@@ -124,12 +126,14 @@ class GPT4AllAdapter(BaseLLMAdaper):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
+        # TODO
         pass
 
+
 class ProxyllmAdapter(BaseLLMAdaper):
-
+    """The model adapter for local proxy"""
+
     def match(self, model_path: str):
         return "proxyllm" in model_path

diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index f08d487f3..a3d443da8 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)
 
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)
 
         if debug:

diff --git a/pilot/model/proxy_llm.py b/pilot/model/proxy_llm.py
index 78425bd83..3242603d3 100644
--- a/pilot/model/proxy_llm.py
+++ b/pilot/model/proxy_llm.py
@@ -8,10 +8,8 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 CFG = Config()
 
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
 
     prompt = params["prompt"]
     stop = params.get("stop", "###")
 
     headers = {
         "Authorization": "Bearer " + CFG.proxy_api_key,
-        "Token": CFG.proxy_api_key
+        "Token": CFG.proxy_api_key,
     }
-    
+
     messages = prompt.split(stop)
 
     # Add history conversation
     for i in range(1, len(messages) - 2, 2):
         history.append(
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )
-    
-    # Add user query 
+
+    # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
 
     payloads = {
-        "model": "gpt-3.5-turbo", # just for test, remove this later
-        "messages": history, 
+        "model": "gpt-3.5-turbo",  # just for test, remove this later
+        "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
     }
 
     print(payloads)
     print(headers)
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
 
     text = ""
     print("====================================res================")
     print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
-            yield text
\ No newline at end of file
+            yield text

diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 71ebcae85..17d2f95a8 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -85,13 +85,13 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
-    
+    """Model chat adapter for Guanaco"""
+
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        # TODO
+        # TODO
         pass
 
@@ -101,7 +101,8 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
-        return proxyllm_generate_stream
+
+        return proxyllm_generate_stream
 
 
 register_llm_model_chat_adapter(VicunaChatAdapter)

diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 18c1de611..376e27852 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -37,7 +37,7 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
-        
+
         if not isinstance(self.model, str):
             if hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length

diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 1d0425590..6b4013647 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -434,7 +434,7 @@ def http_bot(
     """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
     """
     if data["error_code"] == 0:
-        print("****************:",data)
+        print("****************:", data)
        if "vicuna" in CFG.LLM_MODEL:
             output = data["text"][skip_echo_len:].strip()
         else:
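
For context, a minimal sketch (not part of the patch series) of how the patched proxyllm_generate_stream generator might be exercised. The preamble text, sampling values, and device argument below are assumptions; CFG.proxy_server_url and CFG.proxy_api_key must already be configured. The function only reads the params keys shown and never touches the model/tokenizer arguments on the proxy path.

    # Illustrative sketch: drive the streaming proxy generator end to end.
    from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
    from pilot.model.proxy_llm import proxyllm_generate_stream

    stop = "###"
    # Build a prompt in the layout the function splits apart: segment 0 is a
    # system preamble, then alternating "ROLE: content" segments, with a
    # trailing separator so that messages[-2] is the latest user turn.
    prompt = stop.join(
        [
            "A chat between a user and an assistant.",  # hypothetical preamble
            ROLE_USER + ": Hello",
            ROLE_ASSISTANT + ": Hi, how can I help?",
            ROLE_USER + ": Tell me a joke",
            "",
        ]
    )

    params = {
        "prompt": prompt,
        "stop": stop,
        "temperature": 0.7,  # assumed values; the function only forwards them
        "max_new_tokens": 256,
    }

    # model/tokenizer are unused on the proxy path, so None placeholders suffice.
    for partial in proxyllm_generate_stream(None, None, params, device="cpu"):
        print(partial)  # each yield is the accumulated response text so far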