fix: lint

csunny 2023-05-30 19:11:34 +08:00
parent 1bbed02a3c
commit 16c6986666
6 changed files with 34 additions and 26 deletions

View File

@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
     pass
 class CodeGenAdapter(BaseLLMAdaper):
     pass
@@ -127,9 +129,11 @@ class GPT4AllAdapter(BaseLLMAdaper):
     # TODO
     pass
 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""
     def match(self, model_path: str):
         return "proxyllm" in model_path

View File

@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)
         if debug:
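
Note: black only re-wraps this condition; it does not change how it evaluates. Because `and` binds tighter than `or`, the test still reads as "(cuda single-GPU case) or (mps case and tokenizer)". A small standalone check, with the loader's attributes replaced by local variables for illustration:

# Precedence check: the condition groups as A or (B and C), both before and
# after the reformat.  Variable names are stand-ins for the loader's attributes.
device = "mps"
num_gpus = 2
cpu_offloading = False
tokenizer = object()  # stands in for a loaded tokenizer

should_move = (
    (device == "cuda" and num_gpus == 1 and not cpu_offloading)
    or device == "mps"
    and tokenizer
)
assert bool(should_move)  # mps branch: True and <tokenizer> is truthy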

View File

@@ -8,10 +8,8 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 CFG = Config()
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
     prompt = params["prompt"]
@@ -19,7 +17,7 @@ def proxyllm_generate_stream(
     headers = {
         "Authorization": "Bearer " + CFG.proxy_api_key,
-        "Token": CFG.proxy_api_key
+        "Token": CFG.proxy_api_key,
     }
     messages = prompt.split(stop)
@@ -30,16 +28,17 @@ def proxyllm_generate_stream(
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )
     # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
     payloads = {
-        "model": "gpt-3.5-turbo", # just for test, remove this later
+        "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
@@ -47,14 +46,16 @@ def proxyllm_generate_stream(
     print(payloads)
     print(headers)
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
     text = ""
     print("====================================res================")
     print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
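
Note: the parsing loop above expects one JSON object per response line, each with an OpenAI-style choices/message/content shape; that format is inferred from the code, not documented here. A self-contained sketch of the same loop run against canned bytes instead of a live requests.Response:

import json

# Made-up payload lines, shaped the way the loop above expects them.
fake_lines = [
    b'{"choices": [{"message": {"content": "Hello"}}]}',
    b"",  # iter_lines() can yield empty keep-alive lines; `if line:` skips them
    b'{"choices": [{"message": {"content": ", world"}}]}',
]

text = ""
for line in fake_lines:
    if line:
        decoded_line = line.decode("utf-8")
        json_line = json.loads(decoded_line)
        text += json_line["choices"][0]["message"]["content"]

print(text)  # Hello, world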

View File

@@ -85,7 +85,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
+    """Model chat adapter for Guanaco"""
     def match(self, model_path: str):
         return "guanaco" in model_path
@@ -101,6 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
         return proxyllm_generate_stream
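
Note: a chat adapter hands its caller a generate-stream callable rather than generating text itself. A hypothetical consumer (the EchoChatAdapter below is made up; only get_generate_stream_func and the (model, tokenizer, params, device, context_len) signature come from the diffs in this commit):

class EchoChatAdapter:
    """Made-up adapter that returns a trivial generate-stream function."""

    def match(self, model_path: str):
        return "echo" in model_path

    def get_generate_stream_func(self):
        def echo_generate_stream(model, tokenizer, params, device, context_len=2048):
            yield params["prompt"]

        return echo_generate_stream


adapter = EchoChatAdapter()
stream_func = adapter.get_generate_stream_func()
for chunk in stream_func(None, None, {"prompt": "hi"}, device="cpu"):
    print(chunk)  # hi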

View File

@@ -434,7 +434,7 @@ def http_bot(
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
-                    print("****************:",data)
+                    print("****************:", data)
                     if "vicuna" in CFG.LLM_MODEL:
                         output = data["text"][skip_echo_len:].strip()
                     else:
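
Note: the skip_echo_len slice above exists because vicuna-style servers echo the prompt back in front of the completion; the models handled by the else branch do not need it. A tiny illustration of the slicing (the strings are made up; how skip_echo_len is computed is not part of this diff):

prompt = "USER: say hi ASSISTANT:"
echoed_response = prompt + " hi there"  # vicuna-style: prompt echoed back

skip_echo_len = len(prompt)
output = echoed_response[skip_echo_len:].strip()
print(output)  # hi there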