mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 23:01:38 +00:00
35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
"""Wraps the third-party language model embeddings to the common interface."""
|
|
|
|
from typing import TYPE_CHECKING, List
|
|
|
|
from dbgpt.core import Embeddings
|
|
|
|
if TYPE_CHECKING:
|
|
from langchain.embeddings.base import (
|
|
Embeddings as LangChainEmbeddings, # mypy: ignore
|
|
)
|
|
|
|
|
|
class WrappedEmbeddings(Embeddings):
    """Wraps the third-party language model embeddings to the common interface.

    Adapts a LangChain embeddings object to DB-GPT's ``Embeddings`` interface
    by delegating every call to the wrapped object.
    """

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Create a new WrappedEmbeddings.

        Args:
            embeddings: The LangChain embeddings object to delegate to.
        """
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: The texts to embed.

        Returns:
            The embeddings returned by the wrapped object's ``embed_documents``.
        """
        return self._embeddings.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The query text to embed.

        Returns:
            The embedding returned by the wrapped object's ``embed_query``.
        """
        return self._embeddings.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs.

        Delegates to the wrapped object's ``aembed_documents``. If that raises
        ``NotImplementedError`` (some LangChain versions only provide the sync
        API), the synchronous ``embed_documents`` is run in the event loop's
        default executor instead, so the event loop is not blocked.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per input text.
        """
        try:
            return await self._embeddings.aembed_documents(texts)
        except NotImplementedError:
            import asyncio

            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self._embeddings.embed_documents, texts
            )

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text.

        Delegates to the wrapped object's ``aembed_query``. If that raises
        ``NotImplementedError``, the synchronous ``embed_query`` is run in the
        event loop's default executor instead.

        Args:
            text: The query text to embed.

        Returns:
            The embedding of the query text.
        """
        try:
            return await self._embeddings.aembed_query(text)
        except NotImplementedError:
            import asyncio

            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self._embeddings.embed_query, text
            )
|