mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
feat(datasource):add oceanbase support (#1622)
Co-authored-by: csunny <cfqsunny@163.com> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
@@ -90,9 +90,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
) from exc
|
||||
|
||||
kwargs["client"] = sentence_transformers.SentenceTransformer(
|
||||
kwargs.get("model_name"),
|
||||
kwargs.get("model_name") or DEFAULT_MODEL_NAME,
|
||||
cache_folder=kwargs.get("cache_folder"),
|
||||
**kwargs.get("model_kwargs"),
|
||||
**(kwargs.get("model_kwargs") or {}),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
|
||||
@@ -16,7 +17,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "BAAI/bge-reranker-base"
|
||||
max_length: Optional[int] = None
|
||||
max_length: int
|
||||
"""Max length for input sequences. Longer sequences will be truncated. If None, max
|
||||
length of the model will be used"""
|
||||
"""Model name to use."""
|
||||
@@ -33,9 +34,9 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
)
|
||||
|
||||
kwargs["client"] = CrossEncoder(
|
||||
kwargs.get("model_name"),
|
||||
kwargs.get("model_name", "BAAI/bge-reranker-base"),
|
||||
max_length=kwargs.get("max_length"),
|
||||
**kwargs.get("model_kwargs"),
|
||||
**(kwargs.get("model_kwargs") or {}),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -54,7 +55,9 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
query_content_pairs = [[query, candidate] for candidate in candidates]
|
||||
_model = cast(CrossEncoder, self.client)
|
||||
rank_scores = _model.predict(sentences=query_content_pairs)
|
||||
return rank_scores.tolist()
|
||||
if isinstance(rank_scores, np.ndarray):
|
||||
rank_scores = rank_scores.tolist()
|
||||
return rank_scores
|
||||
|
||||
|
||||
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
|
Reference in New Issue
Block a user