feat(datasource):add oceanbase support (#1622)

Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
明天
2024-06-13 15:13:50 +08:00
committed by GitHub
parent 58d08780d6
commit 0541d1494c
37 changed files with 117 additions and 36 deletions

View File

@@ -90,9 +90,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
) from exc
kwargs["client"] = sentence_transformers.SentenceTransformer(
kwargs.get("model_name"),
kwargs.get("model_name") or DEFAULT_MODEL_NAME,
cache_folder=kwargs.get("cache_folder"),
**kwargs.get("model_kwargs"),
**(kwargs.get("model_kwargs") or {}),
)
super().__init__(**kwargs)

View File

@@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional, cast
import aiohttp
import numpy as np
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
@@ -16,7 +17,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
client: Any #: :meta private:
model_name: str = "BAAI/bge-reranker-base"
max_length: Optional[int] = None
max_length: int
"""Max length for input sequences. Longer sequences will be truncated. If None, max
length of the model will be used"""
"""Model name to use."""
@@ -33,9 +34,9 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
)
kwargs["client"] = CrossEncoder(
kwargs.get("model_name"),
kwargs.get("model_name", "BAAI/bge-reranker-base"),
max_length=kwargs.get("max_length"),
**kwargs.get("model_kwargs"),
**(kwargs.get("model_kwargs") or {}),
)
super().__init__(**kwargs)
@@ -54,7 +55,9 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
query_content_pairs = [[query, candidate] for candidate in candidates]
_model = cast(CrossEncoder, self.client)
rank_scores = _model.predict(sentences=query_content_pairs)
return rank_scores.tolist()
if isinstance(rank_scores, np.ndarray):
rank_scores = rank_scores.tolist()
return rank_scores
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):