diff --git a/.env.template b/.env.template index 2fb5ff649..88d780ef8 100644 --- a/.env.template +++ b/.env.template @@ -7,6 +7,10 @@ ## For example, to disable coding related features, uncomment the next line # DISABLED_COMMAND_CATEGORIES= +#*******************************************************************# +#** Webserver Port **# +#*******************************************************************# +WEB_SERVER_PORT=7860 #*******************************************************************# #*** LLM PROVIDER ***# @@ -17,6 +21,7 @@ #*******************************************************************# #** LLM MODELS **# #*******************************************************************# +# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG LLM_MODEL=vicuna-13b MODEL_SERVER=http://127.0.0.1:8000 LIMIT_MODEL_CONCURRENCY=5 @@ -98,15 +103,20 @@ VECTOR_STORE_TYPE=Chroma #MILVUS_SECURE= +#*******************************************************************# +#** WebServer Language Support **# +#*******************************************************************# LANGUAGE=en #LANGUAGE=zh #*******************************************************************# -# ** PROXY_SERVER +# ** PROXY_SERVER (openai interface | chatGPT proxy service), use chatGPT as your LLM. +# ** if your server can visit openai, please set PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions +# ** else if you have a chatgpt proxy server, you can set PROXY_SERVER_URL={your-proxy-serverip:port/xxx} #*******************************************************************# -PROXY_API_KEY= -PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address +PROXY_API_KEY={your-openai-sk} +PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions #*******************************************************************# diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 334dc8459..450cb6901 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -17,8 +17,9 @@ class Config(metaclass=Singleton): def __init__(self) -> None: """Initialize the Config class""" - # Gradio language version: en, cn + # Gradio language version: en, zh self.LANGUAGE = os.getenv("LANGUAGE", "en") + self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860)) self.debug_mode = False self.skip_reprompt = False diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index ef40d45dc..6dd1bfc2b 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -62,10 +62,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) history.append(last_user_input) payloads = { - "model": "gpt-3.5-turbo", # just for test_py, remove this later + "model": "gpt-3.5-turbo", # just for test, remove this later "messages": history, "temperature": params.get("temperature"), "max_tokens": params.get("max_new_tokens"), + "stream": True } res = requests.post( @@ -75,14 +76,32 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) text = "" for line in res.iter_lines(): if line: - decoded_line = line.decode("utf-8") - try: - json_line = json.loads(decoded_line) - print(json_line) - text += json_line["choices"][0]["message"]["content"] - yield text - except Exception as e: - text += decoded_line - yield json.loads(text)["choices"][0]["message"]["content"] - - + json_data = line.split(b': ', 1)[1] + decoded_line = json_data.decode("utf-8") + if decoded_line.lower() != '[DONE]'.lower(): + obj = json.loads(json_data) + if obj['choices'][0]['delta'].get('content') is not None: + content = obj['choices'][0]['delta']['content'] + text += content + yield text + + # native result. + # payloads = { + # "model": "gpt-3.5-turbo", # just for test, remove this later + # "messages": history, + # "temperature": params.get("temperature"), + # "max_tokens": params.get("max_new_tokens"), + # } + # + # res = requests.post( + # CFG.proxy_server_url, headers=headers, json=payloads, stream=True + # ) + # + # text = "" + # line = res.content + # if line: + # decoded_line = line.decode("utf-8") + # json_line = json.loads(decoded_line) + # print(json_line) + # text += json_line["choices"][0]["message"]["content"] + # yield text \ No newline at end of file diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index d2730e0d5..a1dba135b 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -84,6 +84,11 @@ class ModelWorker: return get_embeddings(self.model, self.tokenizer, prompt) +model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] +worker = ModelWorker( + model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1 +) + app = FastAPI() @@ -157,11 +162,4 @@ def embeddings(prompt_request: EmbeddingRequest): if __name__ == "__main__": - model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] - print(model_path, DEVICE) - - worker = ModelWorker( - model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1 - ) - uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info") diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 7f09339ae..ceb52e4d0 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -658,7 +658,7 @@ def signal_handler(sig, frame): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int) + parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument( "--model-list-mode", type=str, default="once", choices=["once", "reload"]