feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters):
def build_kwargs(self, **kwargs) -> Dict:
    """Placeholder hook for assembling model keyword arguments.

    This base version ignores ``kwargs`` and performs no work, so the
    call evaluates to ``None``.
    """
    return None
def is_rerank_model(self) -> bool:
    """Report whether these parameters describe a rerank model.

    Returns:
        Always ``False`` in this base implementation.
    """
    return False
@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
@@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
},
)
# Flag marking the model as a reranker; defaults to False.
rerank: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether the model is a rerank model",
    },
)
# Optional cap on input sequence length; defaults to None.
max_length: Optional[int] = field(
    default=None,
    metadata={
        "help": (
            "Max length for input sequences. Longer sequences will be truncated. "
            "If None, max length of the model will be used, just for rerank model "
            "now."
        ),
    },
)
def build_kwargs(self, **kwargs) -> Dict:
model_kwargs, encode_kwargs = None, None
if self.device:
@@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
if model_kwargs:
kwargs["model_kwargs"] = model_kwargs
if self.is_rerank_model():
kwargs["max_length"] = self.max_length
if encode_kwargs:
kwargs["encode_kwargs"] = encode_kwargs
return kwargs
def is_rerank_model(self) -> bool:
    """Report whether these parameters describe a rerank model.

    Returns:
        ``True`` when ``self.rerank`` is truthy, otherwise ``False``.
    """
    # ``rerank`` is declared Optional[bool]; coerce so an unset (None) value
    # still honours the ``-> bool`` return annotation.
    return bool(self.rerank)
@dataclass
class ModelParameters(BaseModelParameters):
@@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
)
# Flag marking the proxied model as a reranker; defaults to False.
rerank: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether the model is a rerank model",
    },
)
def build_kwargs(self, **kwargs) -> Dict:
    """Build the keyword arguments for the proxied embedding client.

    Args:
        **kwargs: Extra entries merged into the result; they take precedence
            over the values derived from this instance.

    Returns:
        A dict with the ``openai_api_*`` settings, ``model`` and
        ``deployment``, plus any caller-supplied overrides.
    """
    params = {
        "openai_api_base": self.proxy_server_url,
        "openai_api_key": self.proxy_api_key,
        # Empty/unset values are normalised to None.
        "openai_api_type": self.proxy_api_type or None,
        "openai_api_version": self.proxy_api_version or None,
        "model": self.proxy_backend,
        # Fall back to the backend name when no deployment name is given.
        "deployment": self.proxy_deployment or self.proxy_backend,
    }
    # Iterating ``kwargs`` directly yields only keys, so unpacking each one
    # into ``k, v`` fails; merge the key/value pairs instead.
    params.update(kwargs)
    return params
def is_rerank_model(self) -> bool:
    """Report whether the proxied model is a rerank model.

    Returns:
        ``True`` when ``self.rerank`` is truthy, otherwise ``False``.
    """
    # ``rerank`` is declared Optional[bool]; coerce so an unset (None) value
    # still honours the ``-> bool`` return annotation.
    return bool(self.rerank)
# Maps a parameter class to the comma-separated model names it handles.
# The same key used to be listed twice, so the first (shorter) value was
# silently overwritten; only the effective entry is kept.
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
    "proxy_ollama,rerank_proxy_http_openapi",
}

# Starts empty. NOTE(review): presumably populated from the mapping above
# further down the file — not visible here, confirm.
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}