diff --git a/pilot/__init__.py b/pilot/__init__.py index a1531040e..b6865c44c 100644 --- a/pilot/__init__.py +++ b/pilot/__init__.py @@ -1,7 +1,12 @@ from pilot.source_embedding import (SourceEmbedding, register) +import os +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) __all__ = [ "SourceEmbedding", "register" -] \ No newline at end of file +] diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index fa5803d3a..9afd2c01f 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -24,7 +24,7 @@ class BaseLLMAdaper: return model, tokenizer -llm_model_adapters = List[BaseLLMAdaper] = [] +llm_model_adapters: List[BaseLLMAdaper] = [] # Register llm models to adapters, by this we can use multi models. def register_llm_model_adapters(cls): diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 2b29949a3..e341cc457 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -1,14 +1,23 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import os import uvicorn import asyncio import json +import sys from typing import Optional, List from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse -from pilot.model.inference import generate_stream from pydantic import BaseModel + +global_counter = 0 +model_semaphore = None + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +from pilot.model.inference import generate_stream from pilot.model.inference import generate_output, get_embeddings from pilot.model.loader import ModelLoader @@ -19,10 +28,6 @@ from pilot.configs.config import Config CFG = Config() model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] - -global_counter = 0 -model_semaphore = None - ml = ModelLoader(model_path=model_path) model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG) #model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 0f19bc354..3ed989b07 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -13,6 +13,11 @@ import requests from urllib.parse import urljoin from langchain import PromptTemplate +import os +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG from pilot.server.vectordb_qa import KnownLedgeBaseQA @@ -30,6 +35,8 @@ from pilot.prompts.generator import PromptGenerator from pilot.commands.exception_not_commands import NotCommands + + from pilot.conversation import ( default_conversation, conv_templates,