Merge remote-tracking branch 'origin/feature-xuyuan-openai-proxy' into ty_test

# Conflicts:
#	pilot/model/llm_out/proxy_llm.py
yhjun1026 committed 2023-06-13 15:17:06 +08:00
5 changed files with 52 additions and 24 deletions


@@ -17,8 +17,9 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
-        # Gradio language version: en, cn
+        # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
+        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
         self.debug_mode = False
         self.skip_reprompt = False
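
The new WEB_SERVER_PORT option follows the same env-with-fallback pattern as LANGUAGE above; note that int() raises if the variable is set to a non-numeric value. A minimal self-contained sketch of the pattern (the Settings name is illustrative, not from the repo):

import os

class Settings:
    """Illustrative stand-in for Config: env vars with defaults."""

    def __init__(self) -> None:
        self.LANGUAGE = os.getenv("LANGUAGE", "en")
        # 7860 is used when the env var is unset; a non-numeric value
        # such as WEB_SERVER_PORT=abc raises ValueError here.
        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

settings = Settings()
print(settings.LANGUAGE, settings.WEB_SERVER_PORT)  # en 7860 by default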

pilot/model/llm_out/proxy_llm.py

@@ -62,10 +62,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     history.append(last_user_input)
     payloads = {
-        "model": "gpt-3.5-turbo", # just for test_py, remove this later
+        "model": "gpt-3.5-turbo", # just for test, remove this later
         "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
+        "stream": True
     }
     res = requests.post(
@@ -75,14 +76,32 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     text = ""
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode("utf-8")
-            try:
-                json_line = json.loads(decoded_line)
-                print(json_line)
-                text += json_line["choices"][0]["message"]["content"]
-                yield text
-            except Exception as e:
-                text += decoded_line
-                yield json.loads(text)["choices"][0]["message"]["content"]
+            json_data = line.split(b': ', 1)[1]
+            decoded_line = json_data.decode("utf-8")
+            if decoded_line.lower() != '[DONE]'.lower():
+                obj = json.loads(json_data)
+                if obj['choices'][0]['delta'].get('content') is not None:
+                    content = obj['choices'][0]['delta']['content']
+                    text += content
+                    yield text
+
+    # native result.
+    # payloads = {
+    #     "model": "gpt-3.5-turbo", # just for test, remove this later
+    #     "messages": history,
+    #     "temperature": params.get("temperature"),
+    #     "max_tokens": params.get("max_new_tokens"),
+    # }
+    #
+    # res = requests.post(
+    #     CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    # )
+    #
+    # text = ""
+    # line = res.content
+    # if line:
+    #     decoded_line = line.decode("utf-8")
+    #     json_line = json.loads(decoded_line)
+    #     print(json_line)
+    #     text += json_line["choices"][0]["message"]["content"]
+    # yield text
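
The rewritten loop assumes the proxy now answers in OpenAI's server-sent-events format, which "stream": True requests: each non-empty line looks like data: {json}, the stream ends with a data: [DONE] sentinel, and incremental tokens sit under choices[0].delta.content instead of message.content. A self-contained sketch of that parsing with fabricated sample lines (parse_sse_chunks is an illustrative helper, not code from this commit):

import json

def parse_sse_chunks(lines):
    """Illustrative parser for OpenAI-style 'data: {...}' stream lines."""
    text = ""
    for line in lines:
        if not line:
            continue  # skip keep-alive blank lines
        # Strip the 'data: ' prefix; split once so colons inside the JSON survive.
        json_data = line.split(b': ', 1)[1]
        if json_data.decode("utf-8").strip() == "[DONE]":
            break  # end-of-stream sentinel, not JSON
        obj = json.loads(json_data)
        delta = obj["choices"][0]["delta"]
        if delta.get("content") is not None:
            text += delta["content"]
            yield text  # accumulated text so far, as the diff does

sample = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}}]}',
    b'data: [DONE]',
]
print(list(parse_sse_chunks(sample)))  # ['Hel', 'Hello']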


@@ -84,6 +84,11 @@ class ModelWorker:
         return get_embeddings(self.model, self.tokenizer, prompt)
+
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+worker = ModelWorker(
+    model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
+)
 app = FastAPI()
@@ -157,11 +162,4 @@ def embeddings(prompt_request: EmbeddingRequest):
 if __name__ == "__main__":
-    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
-    print(model_path, DEVICE)
-    worker = ModelWorker(
-        model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
-    )
     uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
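
Moving the ModelWorker construction to module scope means the model loads when the module is imported, not only when the file is run directly, so a launcher such as "uvicorn module:app" would also get an initialized worker. A minimal sketch of the pattern with a dummy worker (all names below are illustrative):

from fastapi import FastAPI
import uvicorn

class DummyWorker:
    """Stand-in for ModelWorker; 'loads' its model at construction time."""

    def __init__(self, model_name: str, device: str) -> None:
        self.model_name = model_name
        self.device = device

# Module-level construction: runs on import, so the worker exists for any
# launcher that imports the app, not only for "python thisfile.py".
worker = DummyWorker(model_name="dummy-llm", device="cpu")
app = FastAPI()

@app.get("/model")
def model_info():
    return {"model": worker.model_name, "device": worker.device}

if __name__ == "__main__":
    # Only the server launch remains under __main__, mirroring the diff.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")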


@@ -658,7 +658,7 @@ def signal_handler(sig, frame):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int)
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument(
         "--model-list-mode", type=str, default="once", choices=["once", "reload"]