feat(model): support ollama as an optional llm & embedding proxy (#1475)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Author: GITHUBear
Date: 2024-04-28 18:36:45 +08:00
Committed by: GitHub
Parent commit: 0f8188b152
Commit: 744b3e4933

10 changed files with 231 additions and 1 deletions


@@ -12,6 +12,7 @@ from .embeddings import (  # noqa: F401
    HuggingFaceInferenceAPIEmbeddings,
    HuggingFaceInstructEmbeddings,
    JinaEmbeddings,
    OllamaEmbeddings,
    OpenAPIEmbeddings,
)
@@ -23,6 +24,7 @@ __ALL__ = [
    "HuggingFaceInstructEmbeddings",
    "JinaEmbeddings",
    "OpenAPIEmbeddings",
    "OllamaEmbeddings",
    "DefaultEmbeddingFactory",
    "EmbeddingFactory",
    "WrappedEmbeddingFactory",


@@ -736,3 +736,94 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
        """Asynchronous Embed query text."""
        embeddings = await self.aembed_documents([text])
        return embeddings[0]


class OllamaEmbeddings(BaseModel, Embeddings):
    """Ollama proxy embeddings.

    This class is used to get embeddings for a list of texts using the Ollama API.
    It requires a proxy server url `api_url` and a model name `model_name`.
    The default model name is "llama2".
    """

    api_url: str = Field(
        default="http://localhost:11434",
        description="The URL of the embeddings API.",
    )
    model_name: str = Field(
        default="llama2", description="The name of the model to use."
    )

    def __init__(self, **kwargs):
        """Initialize the OllamaEmbeddings."""
        super().__init__(**kwargs)
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
                corresponds to a single input text.
        """
        return [self.embed_query(text) for text in texts]
    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using an Ollama embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        try:
            import ollama
            from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e
        try:
            return (
                Client(self.api_url).embeddings(model=self.model_name, prompt=text)
            )["embedding"]
        except ollama.ResponseError as e:
            raise ValueError(
                f"Ollama response error, please check the error info: {e}"
            ) from e
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            List[List[float]]: Embedded texts as List[List[float]], where each inner
                List[float] corresponds to a single input text.
        """
        embeddings = []
        for text in texts:
            embedding = await self.aembed_query(text)
            embeddings.append(embedding)
        return embeddings
    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        try:
            import ollama
            from ollama import AsyncClient
        except ImportError as e:
            raise ValueError(
                "The ollama python package is not installed. "
                "Please install it with `pip install ollama`."
            ) from e
        try:
            embedding = await AsyncClient(host=self.api_url).embeddings(
                model=self.model_name, prompt=text
            )
            return embedding["embedding"]
        except ollama.ResponseError as e:
            raise ValueError(
                f"Ollama response error, please check the error info: {e}"
            ) from e