diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index eee05f3e1..da21e4ac8 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -27,7 +27,7 @@ LLM_MODEL_CONFIG = {
 VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
-MAX_POSITION_EMBEDDINGS = 2048
+MAX_POSITION_EMBEDDINGS = 4096
 VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 66766b3b3..532be9c33 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -5,13 +5,13 @@ import torch
 
 @torch.inference_mode()
 def generate_stream(model, tokenizer, params, device,
-                    context_len=2048, stream_interval=2):
+                    context_len=4096, stream_interval=2):
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
     prompt = params["prompt"]
     l_prompt = len(prompt)
     temperature = float(params.get("temperature", 1.0))
-    max_new_tokens = int(params.get("max_new_tokens", 256))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)
 
     input_ids = tokenizer(prompt).input_ids
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 1ca3cee20..5c09d7b85 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -301,8 +301,8 @@ def build_single_model_ui():
 
                 max_output_tokens = gr.Slider(
                     minimum=0,
-                    maximum=1024,
-                    value=512,
+                    maximum=4096,
+                    value=2048,
                     step=64,
                     interactive=True,
                     label="最大输出Token数",
diff --git a/requirements.txt b/requirements.txt
index 3e86cd311..d3bcf1bec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -51,4 +51,5 @@ wandb
 llama-index==0.5.27
 pymysql
 unstructured==0.6.3
-pytesseract==0.3.10
\ No newline at end of file
+pytesseract==0.3.10
+chromadb
\ No newline at end of file
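
Note on how the bumped defaults fit together: the 4096-token window has to hold both the prompt and the generated output. The sketch below is illustrative arithmetic only, assuming the fork keeps FastChat's truncation rule of reserving max_new_tokens plus a small 8-token margin (check pilot/model/inference.py for the exact rule); the constant names here are ours, not the repo's.

    # Illustrative only: the constants mirror the new defaults in this diff, and
    # the truncation rule is assumed to match the FastChat code that
    # pilot/model/inference.py forks, so verify against that file.
    CONTEXT_LEN = 4096      # context_len / MAX_POSITION_EMBEDDINGS after this change (was 2048)
    MAX_NEW_TOKENS = 2048   # max_new_tokens default after this change (was 256)

    # FastChat-style truncation keeps only the most recent prompt tokens so that
    # prompt + generated tokens (+ margin) fit inside the context window.
    max_src_len = CONTEXT_LEN - MAX_NEW_TOKENS - 8
    print(f"prompt tokens kept after truncation: {max_src_len}")                 # 2040
    print(f"worst-case total: {max_src_len + MAX_NEW_TOKENS} <= {CONTEXT_LEN}")  # 4088 <= 4096

Under the same assumed rule, the old defaults left 2048 - 256 - 8 = 1784 tokens for the prompt, so most of the doubled window goes to the larger max_new_tokens default rather than to longer prompts.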