fix: lint

csunny 2023-05-30 19:11:34 +08:00
parent 1bbed02a3c
commit 16c6986666
6 changed files with 34 additions and 26 deletions

@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
+
     pass
 
+
 class CodeGenAdapter(BaseLLMAdaper):
     pass
@@ -124,12 +126,14 @@ class GPT4AllAdapter(BaseLLMAdaper):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         # TODO
         pass
 
+
 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""
+
     def match(self, model_path: str):
         return "proxyllm" in model_path

@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)
 
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
            model.to(self.device)
 
         if debug:
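
Note: the re-wrapped condition is semantically identical to the original, since `and` binds tighter than `or`; black only adds parentheses and line breaks. A quick sanity check:

```python
# Verify the re-wrapped condition parses the same as the original:
# `and` binds tighter than `or`, so `(a) or b and c` == `a or (b and c)`.
for a in (False, True):
    for b in (False, True):
        for c in (False, True):
            assert ((a) or b and c) == (a or (b and c))
print("equivalent for all inputs")
```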

@@ -8,10 +8,8 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 CFG = Config()
 
 
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
     prompt = params["prompt"]
@@ -19,9 +17,9 @@ def proxyllm_generate_stream(
     headers = {
         "Authorization": "Bearer " + CFG.proxy_api_key,
-        "Token": CFG.proxy_api_key
+        "Token": CFG.proxy_api_key,
     }
 
     messages = prompt.split(stop)
 
     # Add history conversation
@@ -30,32 +28,35 @@ def proxyllm_generate_stream(
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )
 
     # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
 
     payloads = {
         "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
     }
 
     print(payloads)
     print(headers)
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
     text = ""
     print("====================================res================")
     print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
             yield text
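
For context, the history-building code in this file pairs alternating user/assistant segments split out of the prompt, then appends the final user query from the second-to-last segment. A minimal standalone sketch of that parsing; the stop token, role strings, and the loop's start index below are assumptions for illustration, not taken from this diff:

```python
# Standalone sketch of the history parsing in proxyllm_generate_stream.
# The stop token, role strings, and loop bounds here are assumed values.
ROLE_USER, ROLE_ASSISTANT = "USER", "ASSISTANT"
stop = "###"

prompt = "USER: hello###ASSISTANT: hi###USER: what is 2 + 2?###ASSISTANT:"
messages = prompt.split(stop)

history = []
for i in range(0, len(messages) - 2, 2):
    history.append({"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]})
    history.append(
        {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
    )

# The final user query is the second-to-last segment.
query = messages[-2].split(ROLE_USER + ":")[1]
history.append({"role": "user", "content": query})
print(history)  # user / system / user
```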

@@ -85,13 +85,13 @@ class CodeGenChatAdapter(BaseChatAdpter):
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
+    """Model chat adapter for Guanaco"""
 
     def match(self, model_path: str):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
         # TODO
         pass
@@ -101,7 +101,8 @@ class ProxyllmChatAdapter(BaseChatAdpter):
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
-        return proxyllm_generate_stream
+
+        return proxyllm_generate_stream
 
 
 register_llm_model_chat_adapter(VicunaChatAdapter)
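
The `register_llm_model_chat_adapter(...)` call at the end of this file implies a simple registry: adapter classes are appended to a list, and a lookup walks that list calling `match` on the model path. A minimal sketch of the pattern; the registry list and the lookup helper are assumed names for illustration, not this file's exact implementation:

```python
# Sketch of the registry implied by register_llm_model_chat_adapter;
# llm_model_chat_adapters and get_llm_chat_adapter are assumed names.
llm_model_chat_adapters = []


def register_llm_model_chat_adapter(cls):
    llm_model_chat_adapters.append(cls())


def get_llm_chat_adapter(model_path: str):
    # Return the first registered adapter whose match() accepts the path.
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")
```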

@@ -37,7 +37,7 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
 
         if not isinstance(self.model, str):
             if hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length
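
This hunk only re-indents the guard; the logic it touches reads the loaded model's maximum sequence length from its config when one is exposed. A sketch of that probe as a standalone helper; the 2048 default and the `max_position_embeddings` fallback are assumptions for illustration, not taken from this diff:

```python
# Standalone sketch of ModelWorker's context-length probe; the default
# and the max_position_embeddings fallback are assumed, not in this diff.
def probe_context_len(model, default: int = 2048) -> int:
    if isinstance(model, str):  # proxy-style entries are plain strings
        return default
    config = getattr(model, "config", None)
    if config is not None and hasattr(config, "max_sequence_length"):
        return config.max_sequence_length
    if config is not None and hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    return default
```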

@@ -434,7 +434,7 @@ def http_bot(
             """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
             """
             if data["error_code"] == 0:
-                print("****************:",data)
+                print("****************:", data)
                 if "vicuna" in CFG.LLM_MODEL:
                     output = data["text"][skip_echo_len:].strip()
                 else:
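
The TODO in this hunk points toward the chat-adapter registry touched earlier in this commit. A hedged sketch of what "adapter mode" output handling could look like; every name below beyond `data` and `skip_echo_len` is hypothetical:

```python
# Hypothetical only: route output post-processing through the adapter
# registry instead of per-model string checks on CFG.LLM_MODEL.
def adapter_mode_output(data: dict, model_name: str, skip_echo_len: int) -> str:
    adapter = get_llm_chat_adapter(model_name)  # assumed lookup, sketched above
    if hasattr(adapter, "parse_output"):  # hypothetical per-adapter hook
        return adapter.parse_output(data, skip_echo_len)
    return data["text"][skip_echo_len:].strip()  # current default behavior
```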