fix(model): Fix remote embedding model error in some cases (#587)

This commit is contained in:
Aries-ckt 2023-09-14 20:20:20 +08:00 committed by GitHub
commit 465c065ef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 26 deletions

View File

@ -152,6 +152,7 @@ class SystemApp(LifeCycle):
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(self.async_after_start())
self.after_start()

View File

@ -18,9 +18,11 @@ class EmbeddingFactory(BaseComponet, ABC):
class DefaultEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
def __init__(
self, system_app=None, default_model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self._default_model_name = default_model_name
self.kwargs = kwargs
def init_app(self, system_app):
@ -31,9 +33,13 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
) -> "Embeddings":
if not model_name:
model_name = self._default_model_name
new_kwargs = {k: v for k, v in self.kwargs.items()}
new_kwargs["model_name"] = model_name
if embedding_cls:
return embedding_cls(model_name=model_name, **self.kwargs)
return embedding_cls(**new_kwargs)
else:
from langchain.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
return HuggingFaceEmbeddings(**new_kwargs)

View File

@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Type
from typing import Dict, List, Type, Optional
from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
@ -45,21 +45,12 @@ class EmbeddingsModelWorker(ModelWorker):
self, command_args: List[str] = None
) -> EmbeddingModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
return _parse_embedding_params(
model_name=self.model_name,
model_path=self.model_path,
command_args=command_args,
param_cls=param_cls,
)
if not model_params.device:
model_params.device = get_device()
logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params
def start(
self,
@ -100,3 +91,26 @@ class EmbeddingsModelWorker(ModelWorker):
logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
return self._embeddings_impl.embed_documents(input)
def _parse_embedding_params(
    model_name: str,
    model_path: str,
    command_args: List[str] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
):
    """Parse embedding model parameters and make sure a device is set.

    An ``EnvArgumentParser`` is created and asked to parse into ``param_cls``
    using the environment prefix derived from ``model_name``. ``model_name``
    and ``model_path`` are forwarded to the parser as keyword arguments
    (presumably as values for the matching dataclass fields -- confirm
    against ``EnvArgumentParser.parse_args_into_dataclass``).

    Args:
        model_name: Name of the embedding model; also used to build the
            environment variable prefix.
        model_path: Path of the embedding model, forwarded to the parser.
        command_args: Optional command line arguments forwarded to the parser.
        param_cls: Parameter dataclass to parse into.

    Returns:
        The parsed parameter object. When its ``device`` is falsy after
        parsing, it is replaced by the result of ``get_device()``.
    """
    parser = EnvArgumentParser()
    model_params: EmbeddingModelParameters = parser.parse_args_into_dataclass(
        param_cls,
        env_prefix=EnvArgumentParser.get_env_prefix(model_name),
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
    )

    # Nothing to fix up when the parser already produced a device.
    if model_params.device:
        return model_params

    # No device configured: fall back to the detected one and log the choice.
    model_params.device = get_device()
    logger.info(
        f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
    )
    return model_params

View File

@ -635,6 +635,7 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
# ASGI "startup" handler: schedules worker_manager.start() as an asyncio task.
# The task returned by create_task is neither awaited nor stored here, so a
# failure inside worker_manager.start() is not observed by this handler.
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(worker_manager.start())
@app.on_event("shutdown")

View File

@ -102,6 +102,12 @@ class WebWerverParameters(BaseParameters):
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
},
)
remote_embedding: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
},
)
log_level: Optional[str] = field(
default="INFO",
metadata={

View File

@ -1,21 +1,65 @@
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
import logging
from pilot.configs.model_config import get_device
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from pilot.server.base import WebWerverParameters
from pilot.utils.parameter_utils import EnvArgumentParser
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
from pilot.model.cluster import worker_manager
logger = logging.getLogger(__name__)
def initialize_componets(
param: WebWerverParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
from pilot.model.cluster.controller.controller import controller
system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
system_app.register_instance(controller)
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
def _initialize_embedding_model(
    param: WebWerverParameters,
    system_app: SystemApp,
    embedding_model_name: str,
    embedding_model_path: str,
):
    """Register the embedding factory component on ``system_app``.

    When ``param.remote_embedding`` is truthy a ``RemoteEmbeddingFactory`` is
    registered with the cluster ``worker_manager`` and the model name.
    Otherwise embedding parameters are parsed with
    ``_parse_embedding_params`` and a ``DefaultEmbeddingFactory`` is
    registered with the model path as its default model name.

    Args:
        param: Web server parameters; only ``remote_embedding`` is read here.
        system_app: Application object the factory is registered on.
        embedding_model_name: Name of the embedding model.
        embedding_model_path: Path of the embedding model.
    """
    # Imported up front in both modes, exactly as before.
    from pilot.model.cluster import worker_manager

    if not param.remote_embedding:
        from pilot.model.parameter import EmbeddingModelParameters
        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params

        model_params: EmbeddingModelParameters = _parse_embedding_params(
            model_name=embedding_model_name,
            model_path=embedding_model_path,
            param_cls=EmbeddingModelParameters,
        )
        # The model *path* is passed as ``model_name`` for the local factory.
        kwargs = model_params.build_kwargs(model_name=embedding_model_path)
        logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
        system_app.register(
            DefaultEmbeddingFactory,
            default_model_name=embedding_model_path,
            **kwargs,
        )
        return

    logger.info("Register remote RemoteEmbeddingFactory")
    system_app.register(
        RemoteEmbeddingFactory,
        worker_manager,
        model_name=embedding_model_name,
    )
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(

View File

@ -109,20 +109,27 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
# Before start
system_app.before_start()
print(param)
embedding_model_name = CFG.EMBEDDING_MODEL
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app)
initialize_componets(system_app, CFG.EMBEDDING_MODEL)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
embedding_model_name, embedding_model_path = None, None
initialize_worker_manager_in_client(
app=app,
model_name=CFG.LLM_MODEL,
model_path=model_path,
local_port=param.port,
embedding_model_name=CFG.EMBEDDING_MODEL,
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path,
start_listener=model_start_listener,
)