mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(model): support ollama as an optional llm & embedding proxy (#1475)
Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from .embeddings import ( # noqa: F401
|
||||
HuggingFaceInferenceAPIEmbeddings,
|
||||
HuggingFaceInstructEmbeddings,
|
||||
JinaEmbeddings,
|
||||
OllamaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
)
|
||||
|
||||
@@ -23,6 +24,7 @@ __ALL__ = [
|
||||
"HuggingFaceInstructEmbeddings",
|
||||
"JinaEmbeddings",
|
||||
"OpenAPIEmbeddings",
|
||||
"OllamaEmbeddings",
|
||||
"DefaultEmbeddingFactory",
|
||||
"EmbeddingFactory",
|
||||
"WrappedEmbeddingFactory",
|
||||
|
@@ -736,3 +736,94 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
|
||||
"""Asynchronous Embed query text."""
|
||||
embeddings = await self.aembed_documents([text])
|
||||
return embeddings[0]
|
||||
|
||||
|
||||
class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
"""Ollama proxy embeddings.
|
||||
|
||||
This class is used to get embeddings for a list of texts using the Ollama API.
|
||||
It requires a proxy server url `api_url` and a model name `model_name`.
|
||||
The default model name is "llama2".
|
||||
"""
|
||||
|
||||
api_url: str = Field(
|
||||
default="http://localhost:11434",
|
||||
description="The URL of the embeddings API.",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="llama2", description="The name of the model to use."
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the OllamaEmbeddings."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Get the embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts (Documents): A list of texts to get embeddings for.
|
||||
|
||||
Returns:
|
||||
Embedded texts as List[List[float]], where each inner List[float]
|
||||
corresponds to a single input text.
|
||||
"""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a OpenAPI embedding model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
try:
|
||||
import ollama
|
||||
from ollama import Client
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"Could not import python package: ollama "
|
||||
"Please install ollama by command `pip install ollama"
|
||||
) from e
|
||||
try:
|
||||
return (
|
||||
Client(self.api_url).embeddings(model=self.model_name, prompt=text)
|
||||
)["embedding"]
|
||||
except ollama.ResponseError as e:
|
||||
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to get embeddings for.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: Embedded texts as List[List[float]], where each inner
|
||||
List[float] corresponds to a single input text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embedding = await self.aembed_query(text)
|
||||
embeddings.append(embedding)
|
||||
return embeddings
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
try:
|
||||
import ollama
|
||||
from ollama import AsyncClient
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The ollama python package is not installed. "
|
||||
"Please install it with `pip install ollama`"
|
||||
)
|
||||
try:
|
||||
embedding = await AsyncClient(host=self.api_url).embeddings(
|
||||
model=self.model_name, prompt=text
|
||||
)
|
||||
return embedding["embedding"]
|
||||
except ollama.ResponseError as e:
|
||||
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
|
||||
|
Reference in New Issue
Block a user