feat(model): support ollama as an optional llm & embedding proxy (#1475)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
GITHUBear 2024-04-28 18:36:45 +08:00 committed by GitHub
parent 0f8188b152
commit 744b3e4933
10 changed files with 231 additions and 1 deletion

@@ -100,3 +100,6 @@ ignore_missing_imports = True
[mypy-rich.*]
ignore_missing_imports = True

[mypy-ollama.*]
ignore_missing_imports = True

@@ -69,6 +69,7 @@ LLM_MODEL_CONFIG = {
    "yi_proxyllm": "yi_proxyllm",
    # https://platform.moonshot.cn/docs/
    "moonshot_proxyllm": "moonshot_proxyllm",
    "ollama_proxyllm": "ollama_proxyllm",
    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),

@@ -200,6 +201,7 @@ EMBEDDING_MODEL_CONFIG = {
    "proxy_azure": "proxy_azure",
    # Common HTTP embedding model
    "proxy_http_openapi": "proxy_http_openapi",
    "proxy_ollama": "proxy_ollama",
}
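
Unlike the local llama-2 entries, the two new proxy entries map the name to itself rather than to a path under MODEL_PATH; a small sketch of what that implies (the config module path below is an assumption, since the file name is not visible in this extract):

# Illustration only; not part of this commit. The module path for the config
# dicts is assumed, not shown in this hunk.
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG

# For proxy backends the registry value is simply the proxy name itself,
# not a local model path under MODEL_PATH.
assert LLM_MODEL_CONFIG["ollama_proxyllm"] == "ollama_proxyllm"
assert EMBEDDING_MODEL_CONFIG["proxy_ollama"] == "proxy_ollama"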

@@ -50,6 +50,16 @@ class EmbeddingLoader:
            if proxy_param.proxy_backend:
                openapi_param["model_name"] = proxy_param.proxy_backend
            return OpenAPIEmbeddings(**openapi_param)
        elif model_name in ["proxy_ollama"]:
            from dbgpt.rag.embedding import OllamaEmbeddings

            proxy_param = cast(ProxyEmbeddingParameters, param)
            ollama_param = {}
            if proxy_param.proxy_server_url:
                ollama_param["api_url"] = proxy_param.proxy_server_url
            if proxy_param.proxy_backend:
                ollama_param["model_name"] = proxy_param.proxy_backend
            return OllamaEmbeddings(**ollama_param)
        else:
            from dbgpt.rag.embedding import HuggingFaceEmbeddings
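
The new branch only forwards two optional settings, the server URL and the backend model name, into the OllamaEmbeddings constructor. A minimal sketch of the object this branch ends up building (the values shown are placeholders, not defaults taken from a real deployment):

# Illustration only; not part of this commit.
from dbgpt.rag.embedding import OllamaEmbeddings

# Equivalent to what the proxy_ollama branch constructs when both
# proxy_server_url and proxy_backend are set on ProxyEmbeddingParameters.
embeddings = OllamaEmbeddings(
    api_url="http://localhost:11434",  # placeholder Ollama server URL
    model_name="nomic-embed-text",     # placeholder embedding model name
)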

@@ -114,6 +114,23 @@ class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
        return tongyi_generate_stream


class OllamaLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "ollama_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.ollama import OllamaLLMClient

        return OllamaLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.ollama import ollama_generate_stream

        return ollama_generate_stream


class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
    support_system_message = False

@@ -279,6 +296,7 @@ class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
register_model_adapter(WenxinProxyLLMModelAdapter)
register_model_adapter(GeminiProxyLLMModelAdapter)
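
Adapter selection is purely a lower-cased name match, so the "ollama_proxyllm" key registered in the model config above is what routes a worker to this adapter. A tiny sketch of that contract:

# Illustration only; not part of this commit.
adapter = OllamaLLMModelAdapter()
assert adapter.do_match("ollama_proxyllm") is True
assert adapter.do_match("chatgpt_proxyllm") is False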

@@ -556,7 +556,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
 }
 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}

@@ -11,6 +11,7 @@ def __lazy_import(name):
        "ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
        "YiLLMClient": "dbgpt.model.proxy.llms.yi",
        "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
        "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
    }

    if name in module_path:

@@ -33,4 +34,5 @@ __all__ = [
    "SparkLLMClient",
    "YiLLMClient",
    "MoonshotLLMClient",
    "OllamaLLMClient",
]
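
With the table entry and the __all__ addition, the new client is importable from the package root while the underlying module is only loaded on first use:

# Illustration only; not part of this commit.
# Resolved lazily to dbgpt.model.proxy.llms.ollama via __lazy_import.
from dbgpt.model.proxy import OllamaLLMClient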

@@ -0,0 +1,101 @@
import logging
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)


def ollama_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    client: OllamaLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r


class OllamaLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        host: Optional[str] = None,
        model_alias: Optional[str] = "ollama_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        if not model:
            model = "llama2"
        if not host:
            host = "http://localhost:11434"
        self._model = model
        self._host = host

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "OllamaLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            host=model_params.proxy_server_url,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        try:
            import ollama
            from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e

        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages()
        model = request.model or self._model
        client = Client(self._host)
        try:
            stream = client.chat(
                model=model,
                messages=messages,
                stream=True,
            )
            content = ""
            for chunk in stream:
                # Accumulate the incremental deltas and emit the full text so far.
                content = content + chunk["message"]["content"]
                yield ModelOutput(text=content, error_code=0)
        except ollama.ResponseError as e:
            yield ModelOutput(
                text=f"**Ollama Response Error, Please Check Error Info.**: {e}",
                error_code=-1,
            )
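
A minimal direct-usage sketch of the client above. It assumes a local Ollama server at the default address and that ModelRequest.build_request accepts plain role/content message dicts, as the params["messages"] pass-through in ollama_generate_stream suggests; outputs are cumulative, so only the last one is needed for the full answer.

# Illustration only; not part of this commit.
from dbgpt.core import ModelRequest
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient

client = OllamaLLMClient(model="llama2", host="http://localhost:11434")
request = ModelRequest.build_request(
    client.default_model,
    # Assumption: role/content dicts are accepted here, mirroring the
    # params["messages"] pass-through in ollama_generate_stream.
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0.7,
    max_new_tokens=128,
)

last_output = None
for output in client.sync_generate_stream(request):
    last_output = output  # each ModelOutput carries the accumulated text so far
print(last_output.text if last_output else "")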

@@ -12,6 +12,7 @@ from .embeddings import (  # noqa: F401
    HuggingFaceInferenceAPIEmbeddings,
    HuggingFaceInstructEmbeddings,
    JinaEmbeddings,
    OllamaEmbeddings,
    OpenAPIEmbeddings,
)

@@ -23,6 +24,7 @@ __ALL__ = [
    "HuggingFaceInstructEmbeddings",
    "JinaEmbeddings",
    "OpenAPIEmbeddings",
    "OllamaEmbeddings",
    "DefaultEmbeddingFactory",
    "EmbeddingFactory",
    "WrappedEmbeddingFactory",

@@ -736,3 +736,94 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
        """Asynchronous Embed query text."""
        embeddings = await self.aembed_documents([text])
        return embeddings[0]


class OllamaEmbeddings(BaseModel, Embeddings):
    """Ollama proxy embeddings.

    This class is used to get embeddings for a list of texts using the Ollama API.
    It requires a proxy server url `api_url` and a model name `model_name`.
    The default model name is "llama2".
    """

    api_url: str = Field(
        default="http://localhost:11434",
        description="The URL of the embeddings API.",
    )
    model_name: str = Field(
        default="llama2", description="The name of the model to use."
    )

    def __init__(self, **kwargs):
        """Initialize the OllamaEmbeddings."""
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
                corresponds to a single input text.
        """
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using an Ollama embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        try:
            import ollama
            from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e
        try:
            return (
                Client(self.api_url).embeddings(model=self.model_name, prompt=text)
            )["embedding"]
        except ollama.ResponseError as e:
            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: A list of texts to get embeddings for.

        Returns:
            List[List[float]]: Embedded texts as List[List[float]], where each inner
                List[float] corresponds to a single input text.
        """
        embeddings = []
        for text in texts:
            embedding = await self.aembed_query(text)
            embeddings.append(embedding)
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        try:
            import ollama
            from ollama import AsyncClient
        except ImportError as e:
            raise ValueError(
                "The ollama python package is not installed. "
                "Please install it with `pip install ollama`."
            ) from e
        try:
            embedding = await AsyncClient(host=self.api_url).embeddings(
                model=self.model_name, prompt=text
            )
            return embedding["embedding"]
        except ollama.ResponseError as e:
            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")
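
A short usage sketch of the embeddings class above, again assuming a local Ollama server and that the chosen model supports the embeddings endpoint:

# Illustration only; not part of this commit.
from dbgpt.rag.embedding import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    api_url="http://localhost:11434",  # placeholder server URL
    model_name="llama2",               # any model already pulled into the Ollama server
)

vectors = embeddings.embed_documents(["DB-GPT", "Ollama proxy embeddings"])
query_vector = embeddings.embed_query("What is DB-GPT?")
print(len(vectors), len(query_vector))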

@@ -658,6 +658,7 @@ def default_requires():
        "dashscope",
        "chardet",
        "sentencepiece",
        "ollama",
    ]
    setup_spec.extras["default"] += setup_spec.extras["framework"]
    setup_spec.extras["default"] += setup_spec.extras["rag"]