diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 8640257b6..0ff368c70 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
+
     pass
 
+
 class CodeGenAdapter(BaseLLMAdaper):
     pass
 
@@ -124,12 +126,14 @@ class GPT4AllAdapter(BaseLLMAdaper):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO 
+        # TODO
         pass
 
+
 class ProxyllmAdapter(BaseLLMAdaper):
-
+    """The model adapter for local proxy"""
+
     def match(self, model_path: str):
         return "proxyllm" in model_path
 
 
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index f08d487f3..a3d443da8 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)
 
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)
 
         if debug:
diff --git a/pilot/model/proxy_llm.py b/pilot/model/proxy_llm.py
index 09c7a82ed..3242603d3 100644
--- a/pilot/model/proxy_llm.py
+++ b/pilot/model/proxy_llm.py
@@ -8,19 +8,18 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 CFG = Config()
 
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
 
     prompt = params["prompt"]
     stop = params.get("stop", "###")
 
     headers = {
-        "Authorization": "Bearer " + CFG.proxy_api_key
+        "Authorization": "Bearer " + CFG.proxy_api_key,
+        "Token": CFG.proxy_api_key,
     }
-    
+
     messages = prompt.split(stop)
 
     # Add history conversation
@@ -29,28 +28,35 @@
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )
-    
-    # Add user query 
+
+    # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
     payloads = {
-        "model": "gpt-3.5-turbo", # just for test, remove this later
-        "messages": history, 
+        "model": "gpt-3.5-turbo",  # just for test, remove this later
+        "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
     }
 
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    print(payloads)
+    print(headers)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
 
     text = ""
+    print("====================================res================")
+    print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
-            yield text
\ No newline at end of file
+            yield text
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 71ebcae85..17d2f95a8 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -85,13 +85,13 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
-    
+    """Model chat adapter for Guanaco"""
+
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        # TODO 
+        # TODO
         pass
 
 
@@ -101,7 +101,8 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
-        return proxyllm_generate_stream
+
+        return proxyllm_generate_stream
 
 
 register_llm_model_chat_adapter(VicunaChatAdapter)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 18c1de611..376e27852 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -37,7 +37,7 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
-        
+
         if not isinstance(self.model, str):
             if hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length