fix: lint

csunny 2023-05-30 19:11:34 +08:00
parent 1bbed02a3c
commit 16c6986666
6 changed files with 34 additions and 26 deletions

View File

@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
     pass
 class CodeGenAdapter(BaseLLMAdaper):
     pass
@@ -127,9 +129,11 @@ class GPT4AllAdapter(BaseLLMAdaper):
     # TODO
     pass
 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""
     def match(self, model_path: str):
         return "proxyllm" in model_path

View File

@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)
         if debug:
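A quick aside on the reformatted condition in the hunk above: black only re-wraps the expression, and because Python binds "and" tighter than "or", the behaviour is unchanged. A minimal, self-contained check (variable values are invented for illustration):

# Precedence check: "A or B and C" parses as "A or (B and C)", so the
# reformatted condition keeps the original meaning.
a, b, c = False, True, False
assert (a or b and c) == (a or (b and c))

# Read back onto the loader: move the model when it runs on a single CUDA GPU
# without CPU offloading, OR when the device is "mps" and a tokenizer exists.
device, num_gpus, cpu_offloading, tokenizer = "mps", 1, False, True
should_move = bool(
    (device == "cuda" and num_gpus == 1 and not cpu_offloading)
    or device == "mps"
    and tokenizer
)
print(should_move)  # True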

View File

@@ -8,10 +8,8 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 CFG = Config()
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
     prompt = params["prompt"]
@@ -19,7 +17,7 @@ def proxyllm_generate_stream(
     headers = {
         "Authorization": "Bearer " + CFG.proxy_api_key,
-        "Token": CFG.proxy_api_key
+        "Token": CFG.proxy_api_key,
     }
     messages = prompt.split(stop)
@@ -30,14 +28,15 @@ def proxyllm_generate_stream(
                 {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
             )
             history.append(
-                {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+                {
+                    "role": "system",
+                    "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+                }
             )
     # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
     payloads = {
         "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
@@ -47,14 +46,16 @@ def proxyllm_generate_stream(
     print(payloads)
     print(headers)
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
     text = ""
     print("====================================res================")
     print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
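For orientation, here is a standalone sketch of the prompt-to-messages conversion that the hunks above reformat. The role markers, stop token, sample prompt, and loop structure are assumptions made for illustration; the real values come from pilot.conversation and the request params, which this diff does not show.

# Hypothetical stand-ins for pilot.conversation constants and the stop token.
ROLE_USER = "USER"
ROLE_ASSISTANT = "ASSISTANT"
stop = "###"

prompt = (
    "USER: What is DB-GPT?###"
    "ASSISTANT: A database-focused GPT project.###"
    "USER: Does it support a proxy LLM?###"
)

history = []
messages = prompt.split(stop)

# Pair up past user/assistant turns into OpenAI-style message dicts.
for i in range(0, len(messages) - 2, 2):
    history.append(
        {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]}
    )
    history.append(
        {
            "role": "system",
            "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
        }
    )

# Add the current user query (second-to-last chunk before the trailing stop).
query = messages[-2].split(ROLE_USER + ":")[1]
history.append({"role": "user", "content": query})
print(history)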

View File

@@ -85,7 +85,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
+    """Model chat adapter for Guanaco"""
     def match(self, model_path: str):
         return "guanaco" in model_path
@@ -101,6 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
         return proxyllm_generate_stream
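The chat adapters touched above all follow the same pattern: claim a model by substring match on its path, then hand back the stream function to call. Below is a rough sketch of that pattern; only the class names and match logic come from the diff, while the registry, selector, and stub stream function are illustrative assumptions.

from typing import Callable, List, Optional


class BaseChatAdpter:
    """Base pattern: an adapter claims a model path and exposes a stream function."""

    def match(self, model_path: str) -> bool:
        return False

    def get_generate_stream_func(self) -> Callable:
        raise NotImplementedError


class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "proxyllm" in model_path

    def get_generate_stream_func(self) -> Callable:
        def proxyllm_generate_stream(*args, **kwargs):
            yield "stub"  # stand-in for pilot.model.proxy_llm

        return proxyllm_generate_stream


# Hypothetical registry and selector, not the project's actual implementation.
ADAPTERS: List[BaseChatAdpter] = [ProxyllmChatAdapter()]


def get_chat_adapter(model_path: str) -> Optional[BaseChatAdpter]:
    # Return the first adapter whose match() accepts the model path.
    for adapter in ADAPTERS:
        if adapter.match(model_path):
            return adapter
    return None


print(get_chat_adapter("proxyllm"))  # ProxyllmChatAdapter instance (or None)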

View File

@@ -434,7 +434,7 @@ def http_bot(
     """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
     """
     if data["error_code"] == 0:
-        print("****************:",data)
+        print("****************:", data)
         if "vicuna" in CFG.LLM_MODEL:
             output = data["text"][skip_echo_len:].strip()
         else:
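As a small illustration of the skip_echo_len slice in this last hunk: vicuna-style workers echo the prompt at the front of data["text"], so the handler drops that many characters before displaying the answer. The values below are invented for the example.

# Minimal sketch of the echo-stripping step; prompt and response are made up.
prompt = "USER: ping ASSISTANT:"
data = {"error_code": 0, "text": prompt + " pong"}
skip_echo_len = len(prompt)

if data["error_code"] == 0:
    output = data["text"][skip_echo_len:].strip()
    print(output)  # -> "pong"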