Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-28 14:27:20 +00:00)

feat(model): support ollama as an optional llm & embedding proxy (#1475)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

parent 0f8188b152
commit 744b3e4933
@@ -100,3 +100,6 @@ ignore_missing_imports = True
 [mypy-rich.*]
 ignore_missing_imports = True
 
+
+[mypy-ollama.*]
+ignore_missing_imports = True

@@ -69,6 +69,7 @@ LLM_MODEL_CONFIG = {
     "yi_proxyllm": "yi_proxyllm",
     # https://platform.moonshot.cn/docs/
     "moonshot_proxyllm": "moonshot_proxyllm",
+    "ollama_proxyllm": "ollama_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),

@@ -200,6 +201,7 @@ EMBEDDING_MODEL_CONFIG = {
     "proxy_azure": "proxy_azure",
     # Common HTTP embedding model
     "proxy_http_openapi": "proxy_http_openapi",
+    "proxy_ollama": "proxy_ollama",
 }
 
 
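With these two registry entries, "ollama_proxyllm" and "proxy_ollama" become selectable model names (typically chosen via LLM_MODEL and EMBEDDING_MODEL in the .env file). A minimal sketch of what the new entries resolve to; the import path is an assumption based on the usual DB-GPT layout:

from dbgpt.configs.model_config import (  # module path assumed
    EMBEDDING_MODEL_CONFIG,
    LLM_MODEL_CONFIG,
)

# Proxy models map a logical name to itself rather than to a local model path.
assert LLM_MODEL_CONFIG["ollama_proxyllm"] == "ollama_proxyllm"
assert EMBEDDING_MODEL_CONFIG["proxy_ollama"] == "proxy_ollama"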
@@ -50,6 +50,16 @@ class EmbeddingLoader:
             if proxy_param.proxy_backend:
                 openapi_param["model_name"] = proxy_param.proxy_backend
             return OpenAPIEmbeddings(**openapi_param)
+        elif model_name in ["proxy_ollama"]:
+            from dbgpt.rag.embedding import OllamaEmbeddings
+
+            proxy_param = cast(ProxyEmbeddingParameters, param)
+            ollama_param = {}
+            if proxy_param.proxy_server_url:
+                ollama_param["api_url"] = proxy_param.proxy_server_url
+            if proxy_param.proxy_backend:
+                ollama_param["model_name"] = proxy_param.proxy_backend
+            return OllamaEmbeddings(**ollama_param)
         else:
             from dbgpt.rag.embedding import HuggingFaceEmbeddings
 
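In effect this branch forwards the proxy parameters straight into the new embeddings class: proxy_server_url becomes api_url and proxy_backend becomes model_name. A minimal sketch of the equivalent direct construction (URL and model tag are illustrative assumptions):

from dbgpt.rag.embedding import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    api_url="http://localhost:11434",  # assumed: local Ollama server
    model_name="llama2",               # assumed: any model already pulled with `ollama pull`
)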
@@ -114,6 +114,23 @@ class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
         return tongyi_generate_stream
 
 
+class OllamaLLMModelAdapter(ProxyLLMModelAdapter):
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path == "ollama_proxyllm"
+
+    def get_llm_client_class(
+        self, params: ProxyModelParameters
+    ) -> Type[ProxyLLMClient]:
+        from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
+
+        return OllamaLLMClient
+
+    def get_generate_stream_function(self, model, model_path: str):
+        from dbgpt.model.proxy.llms.ollama import ollama_generate_stream
+
+        return ollama_generate_stream
+
+
 class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
     support_system_message = False
 

@@ -279,6 +296,7 @@ class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
 
 register_model_adapter(OpenAIProxyLLMModelAdapter)
 register_model_adapter(TongyiProxyLLMModelAdapter)
+register_model_adapter(OllamaLLMModelAdapter)
 register_model_adapter(ZhipuProxyLLMModelAdapter)
 register_model_adapter(WenxinProxyLLMModelAdapter)
 register_model_adapter(GeminiProxyLLMModelAdapter)
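The adapter is matched purely on the lowercased model name, so only "ollama_proxyllm" routes to the new client. A small illustrative check (the adapter's module path is assumed here):

from dbgpt.model.adapter.proxy_adapter import OllamaLLMModelAdapter  # path assumed

adapter = OllamaLLMModelAdapter()
assert adapter.do_match("ollama_proxyllm") is True
assert adapter.do_match("chatgpt_proxyllm") is False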
@@ -556,7 +556,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
 
 
 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
 }
 
 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@@ -11,6 +11,7 @@ def __lazy_import(name):
         "ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
         "YiLLMClient": "dbgpt.model.proxy.llms.yi",
         "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
+        "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
     }
 
     if name in module_path:

@@ -33,4 +34,5 @@ __all__ = [
     "SparkLLMClient",
     "YiLLMClient",
     "MoonshotLLMClient",
+    "OllamaLLMClient",
 ]
dbgpt/model/proxy/llms/ollama.py (new file, 101 lines)

@@ -0,0 +1,101 @@
+import logging
+from concurrent.futures import Executor
+from typing import Iterator, Optional
+
+from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
+from dbgpt.model.parameter import ProxyModelParameters
+from dbgpt.model.proxy.base import ProxyLLMClient
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
+logger = logging.getLogger(__name__)
+
+
+def ollama_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=4096
+):
+    client: OllamaLLMClient = model.proxy_llm_client
+    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
+    request = ModelRequest.build_request(
+        client.default_model,
+        messages=params["messages"],
+        temperature=params.get("temperature"),
+        context=context,
+        max_new_tokens=params.get("max_new_tokens"),
+    )
+    for r in client.sync_generate_stream(request):
+        yield r
+
+
+class OllamaLLMClient(ProxyLLMClient):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        host: Optional[str] = None,
+        model_alias: Optional[str] = "ollama_proxyllm",
+        context_length: Optional[int] = 4096,
+        executor: Optional[Executor] = None,
+    ):
+        if not model:
+            model = "llama2"
+        if not host:
+            host = "http://localhost:11434"
+        self._model = model
+        self._host = host
+
+        super().__init__(
+            model_names=[model, model_alias],
+            context_length=context_length,
+            executor=executor,
+        )
+
+    @classmethod
+    def new_client(
+        cls,
+        model_params: ProxyModelParameters,
+        default_executor: Optional[Executor] = None,
+    ) -> "OllamaLLMClient":
+        return cls(
+            model=model_params.proxyllm_backend,
+            host=model_params.proxy_server_url,
+            model_alias=model_params.model_name,
+            context_length=model_params.max_context_size,
+            executor=default_executor,
+        )
+
+    @property
+    def default_model(self) -> str:
+        return self._model
+
+    def sync_generate_stream(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> Iterator[ModelOutput]:
+        try:
+            import ollama
+            from ollama import Client
+        except ImportError as e:
+            raise ValueError(
+                "Could not import python package: ollama. "
+                "Please install ollama with the command `pip install ollama`"
+            ) from e
+        request = self.local_covert_message(request, message_converter)
+        messages = request.to_common_messages()
+
+        model = request.model or self._model
+        client = Client(self._host)
+        try:
+            stream = client.chat(
+                model=model,
+                messages=messages,
+                stream=True,
+            )
+            content = ""
+            for chunk in stream:
+                content = content + chunk["message"]["content"]
+                yield ModelOutput(text=content, error_code=0)
+        except ollama.ResponseError as e:
+            yield ModelOutput(
+                text=f"**Ollama Response Error, Please Check Error Info.**: {e}",
+                error_code=-1,
+            )
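The client can also be exercised on its own against a local Ollama server. A hedged sketch; the ModelMessage-based request construction is an assumption about the surrounding dbgpt.core API, and the model must already be pulled:

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest  # exports assumed
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient

# Assumes `ollama serve` is running locally and `ollama pull llama2` was done.
client = OllamaLLMClient(model="llama2", host="http://localhost:11434")

request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Say hello.")],
    temperature=0.7,
)

# sync_generate_stream yields cumulative text, so the last output holds the full reply.
output = None
for output in client.sync_generate_stream(request):
    pass
if output is not None:
    print(output.text)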
@@ -12,6 +12,7 @@ from .embeddings import ( # noqa: F401
     HuggingFaceInferenceAPIEmbeddings,
     HuggingFaceInstructEmbeddings,
     JinaEmbeddings,
+    OllamaEmbeddings,
     OpenAPIEmbeddings,
 )
 

@@ -23,6 +24,7 @@ __ALL__ = [
     "HuggingFaceInstructEmbeddings",
     "JinaEmbeddings",
     "OpenAPIEmbeddings",
+    "OllamaEmbeddings",
     "DefaultEmbeddingFactory",
     "EmbeddingFactory",
     "WrappedEmbeddingFactory",
@@ -736,3 +736,94 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
         """Asynchronous Embed query text."""
         embeddings = await self.aembed_documents([text])
         return embeddings[0]
+
+
+class OllamaEmbeddings(BaseModel, Embeddings):
+    """Ollama proxy embeddings.
+
+    This class is used to get embeddings for a list of texts using the Ollama API.
+    It requires a proxy server url `api_url` and a model name `model_name`.
+    The default model name is "llama2".
+    """
+
+    api_url: str = Field(
+        default="http://localhost:11434",
+        description="The URL of the embeddings API.",
+    )
+    model_name: str = Field(
+        default="llama2", description="The name of the model to use."
+    )
+
+    def __init__(self, **kwargs):
+        """Initialize the OllamaEmbeddings."""
+        super().__init__(**kwargs)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
+
+        Args:
+            texts (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+                corresponds to a single input text.
+        """
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using the Ollama embedding model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        try:
+            import ollama
+            from ollama import Client
+        except ImportError as e:
+            raise ValueError(
+                "Could not import python package: ollama. "
+                "Please install ollama with the command `pip install ollama`"
+            ) from e
+        try:
+            return (
+                Client(self.api_url).embeddings(model=self.model_name, prompt=text)
+            )["embedding"]
+        except ollama.ResponseError as e:
+            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronous Embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+                List[float] corresponds to a single input text.
+        """
+        embeddings = []
+        for text in texts:
+            embedding = await self.aembed_query(text)
+            embeddings.append(embedding)
+        return embeddings
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronous Embed query text."""
+        try:
+            import ollama
+            from ollama import AsyncClient
+        except ImportError:
+            raise ValueError(
+                "The ollama python package is not installed. "
+                "Please install it with `pip install ollama`"
+            )
+        try:
+            embedding = await AsyncClient(host=self.api_url).embeddings(
+                model=self.model_name, prompt=text
+            )
+            return embedding["embedding"]
+        except ollama.ResponseError as e:
+            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")
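A quick way to sanity-check the new embeddings class against a running Ollama instance, covering both the synchronous and asynchronous paths (URL and model tag are illustrative assumptions):

import asyncio

from dbgpt.rag.embedding import OllamaEmbeddings

emb = OllamaEmbeddings(api_url="http://localhost:11434", model_name="llama2")  # assumed setup

# Synchronous path: embed_documents calls embed_query once per text.
vectors = emb.embed_documents(["hello world", "DB-GPT with Ollama"])
print(len(vectors), len(vectors[0]))

# Asynchronous path goes through ollama.AsyncClient.
query_vector = asyncio.run(emb.aembed_query("hello world"))
print(len(query_vector))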
setup.py (1 line added)

@@ -658,6 +658,7 @@ def default_requires():
         "dashscope",
         "chardet",
         "sentencepiece",
+        "ollama",
     ]
     setup_spec.extras["default"] += setup_spec.extras["framework"]
     setup_spec.extras["default"] += setup_spec.extras["rag"]