Merge branch 'llm_proxy' into dev

# Conflicts:
#	pilot/server/webserver.py
yhjun1026 2023-05-30 19:32:06 +08:00
commit 973bcce03c
5 changed files with 38 additions and 25 deletions

View File

@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
     pass

 class CodeGenAdapter(BaseLLMAdaper):
     pass
@@ -124,12 +126,14 @@ class GPT4AllAdapter(BaseLLMAdaper):
         return "gpt4all" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         # TODO
         pass

 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""

     def match(self, model_path: str):
         return "proxyllm" in model_path

View File

@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
             compress_module(model, self.device)

         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)

         if debug:
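
Note: the reformatted condition is equivalent to the old one. Python's "and" binds tighter than "or", so both versions group as (cuda and single GPU and not cpu_offloading) or (mps and tokenizer); the added parentheses only affect readability. A self-contained check of that grouping (the variable values below are made up for illustration):

    # Illustrative values only; "and" binds tighter than "or", so the condition
    # groups as A or (B and C) in both the old and the new formatting.
    device, num_gpus, cpu_offloading, tokenizer = "mps", 1, False, None

    as_written = (
        (device == "cuda" and num_gpus == 1 and not cpu_offloading)
        or device == "mps"
        and tokenizer
    )
    explicit = (device == "cuda" and num_gpus == 1 and not cpu_offloading) or (
        device == "mps" and tokenizer
    )
    assert bool(as_written) == bool(explicit)  # both falsy here: tokenizer is None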

View File

@@ -8,19 +8,18 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

 CFG = Config()


-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []

     prompt = params["prompt"]
     stop = params.get("stop", "###")

     headers = {
-        "Authorization": "Bearer " + CFG.proxy_api_key
+        "Authorization": "Bearer " + CFG.proxy_api_key,
+        "Token": CFG.proxy_api_key,
     }

     messages = prompt.split(stop)

     # Add history conversation
@@ -29,28 +28,35 @@ def proxyllm_generate_stream(
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )

     # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})

     payloads = {
         "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
     }

-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    print(payloads)
+    print(headers)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )

     text = ""
+    print("====================================res================")
+    print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
             yield text
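
For context, proxyllm_generate_stream converts the stop-delimited prompt into an OpenAI-style chat history, POSTs it to CFG.proxy_server_url, and yields the accumulated text (not a delta) after each streamed line. A hedged usage sketch, assuming ROLE_USER is "USER", ROLE_ASSISTANT is "ASSISTANT", and that CFG.proxy_api_key / CFG.proxy_server_url are already configured; the prompt layout is illustrative only:

    from pilot.model.proxy_llm import proxyllm_generate_stream

    # Hypothetical caller; model, tokenizer and device are unused on the proxy path.
    params = {
        "prompt": "###USER: Hello###ASSISTANT: Hi there.###USER: What is DB-GPT?###",
        "stop": "###",
        "temperature": 0.7,
        "max_new_tokens": 512,
    }

    answer = ""
    for answer in proxyllm_generate_stream(None, None, params, device=None):
        pass  # each yield is the full text so far, so only the last value is kept
    print(answer)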

View File

@@ -85,13 +85,13 @@ class CodeGenChatAdapter(BaseChatAdpter):
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
+    """Model chat adapter for Guanaco"""

     def match(self, model_path: str):
         return "guanaco" in model_path

     def get_generate_stream_func(self):
         # TODO
         pass
@@ -101,7 +101,8 @@ class ProxyllmChatAdapter(BaseChatAdpter):
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
-        return proxyllm_generate_stream
+
+        return proxyllm_generate_stream


 register_llm_model_chat_adapter(VicunaChatAdapter)
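
ProxyllmChatAdapter only defers to proxyllm_generate_stream; the surrounding dispatch follows the register-and-match pattern visible in register_llm_model_chat_adapter(VicunaChatAdapter). A minimal sketch of that pattern, assuming a module-level registry and a lookup helper (the helper name below is illustrative, not taken from this diff):

    # Sketch of the register/match dispatch implied by the adapters above.
    llm_model_chat_adapters = []

    def register_llm_model_chat_adapter(cls):
        llm_model_chat_adapters.append(cls())

    def get_llm_chat_adapter(model_path: str):
        # The first registered adapter whose match() accepts the path wins.
        for adapter in llm_model_chat_adapters:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model path for chat adapter: {model_path}")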

View File

@@ -37,7 +37,7 @@ class ModelWorker:
             self.model, self.tokenizer = self.ml.loader(
                 num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
             )

         if not isinstance(self.model, str):
             if hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length
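
This last hunk is context only: ModelWorker derives its context length from model.config.max_sequence_length when the loaded model is a real model object rather than a plain string (the proxy path presumably returns a string placeholder). A compact variant of that lookup, assuming a fallback default; the 2048 value mirrors context_len=2048 in proxyllm_generate_stream and is an assumption, since the diff does not show the missing-attribute case:

    # Assumed-default variant of the hasattr check above.
    def resolve_context_len(model, default: int = 2048) -> int:
        return getattr(model.config, "max_sequence_length", default)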