feat(model): support ollama as an optional llm & embedding proxy (#1475)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Author: GITHUBear
Authored: 2024-04-28 18:36:45 +08:00
Committed by: GitHub
Parent: 0f8188b152
Commit: 744b3e4933

10 changed files with 231 additions and 1 deletion


@@ -50,6 +50,16 @@ class EmbeddingLoader:
            if proxy_param.proxy_backend:
                openapi_param["model_name"] = proxy_param.proxy_backend
            return OpenAPIEmbeddings(**openapi_param)
        elif model_name in ["proxy_ollama"]:
            from dbgpt.rag.embedding import OllamaEmbeddings

            proxy_param = cast(ProxyEmbeddingParameters, param)
            ollama_param = {}
            if proxy_param.proxy_server_url:
                ollama_param["api_url"] = proxy_param.proxy_server_url
            if proxy_param.proxy_backend:
                ollama_param["model_name"] = proxy_param.proxy_backend
            return OllamaEmbeddings(**ollama_param)
        else:
            from dbgpt.rag.embedding import HuggingFaceEmbeddings
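
For reference, a minimal usage sketch of the new embedding backend outside the loader. The api_url/model_name keyword arguments mirror the mapping built above from proxy_server_url/proxy_backend; the model name and the embed_documents call assume the standard dbgpt embeddings interface and an embedding-capable model already pulled into a local ollama server.

from dbgpt.rag.embedding import OllamaEmbeddings

# Sketch only: "llama2" is an example model, not a requirement of this change.
embeddings = OllamaEmbeddings(
    api_url="http://localhost:11434",
    model_name="llama2",
)
vectors = embeddings.embed_documents(["DB-GPT now supports ollama embeddings."])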


@@ -114,6 +114,23 @@ class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
        return tongyi_generate_stream


class OllamaLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "ollama_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.ollama import OllamaLLMClient

        return OllamaLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.ollama import ollama_generate_stream

        return ollama_generate_stream


class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
    support_system_message = False
@@ -279,6 +296,7 @@ class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
register_model_adapter(WenxinProxyLLMModelAdapter)
register_model_adapter(GeminiProxyLLMModelAdapter)
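
A quick sanity sketch of how the registered adapter is selected: do_match keys on the exact proxy model name. The import path is an assumption (the module that holds the adapter classes shown above).

from dbgpt.model.adapter.proxy_adapter import OllamaLLMModelAdapter

# The adapter only matches the dedicated ollama proxy model name.
adapter = OllamaLLMModelAdapter()
assert adapter.do_match("ollama_proxyllm")
assert not adapter.do_match("llama-2-7b")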


@@ -556,7 +556,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}


@@ -11,6 +11,7 @@ def __lazy_import(name):
"ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
"YiLLMClient": "dbgpt.model.proxy.llms.yi",
"MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
"OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
}
if name in module_path:
@@ -33,4 +34,5 @@ __all__ = [
"SparkLLMClient",
"YiLLMClient",
"MoonshotLLMClient",
"OllamaLLMClient",
]
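
With the lazy-import entry and the __all__ export in place, the client can be imported straight from the package root; a minimal sketch:

from dbgpt.model.proxy import OllamaLLMClient  # resolved lazily via __lazy_import

# Defaults come from the client itself: model "llama2" at http://localhost:11434.
client = OllamaLLMClient()
print(client.default_model)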


@@ -0,0 +1,101 @@
import logging
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)


def ollama_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    client: OllamaLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r


class OllamaLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        host: Optional[str] = None,
        model_alias: Optional[str] = "ollama_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        if not model:
            model = "llama2"
        if not host:
            host = "http://localhost:11434"
        self._model = model
        self._host = host

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "OllamaLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            host=model_params.proxy_server_url,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        try:
            import ollama
            from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e

        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages()
        model = request.model or self._model
        client = Client(self._host)
        try:
            stream = client.chat(
                model=model,
                messages=messages,
                stream=True,
            )
            # Accumulate streamed chunks and yield the full text seen so far.
            content = ""
            for chunk in stream:
                content = content + chunk["message"]["content"]
                yield ModelOutput(text=content, error_code=0)
        except ollama.ResponseError as e:
            yield ModelOutput(
                text=f"**Ollama Response Error, Please Check Error Info.**: {e}",
                error_code=-1,
            )
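
To close, a hedged end-to-end sketch of streaming a chat completion through the new client, assuming a local ollama server with the llama2 model already pulled; it mirrors how ollama_generate_stream above builds its request, using the standard dbgpt.core message types.

from dbgpt.core import (
    ModelMessage,
    ModelMessageRoleType,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.model.proxy import OllamaLLMClient

# Sketch only: model/host are the client defaults made explicit.
client = OllamaLLMClient(model="llama2", host="http://localhost:11434")
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello, who are you?")],
    context=ModelRequestContext(stream=True),
    temperature=0.7,
)
for output in client.sync_generate_stream(request):
    print(output.text)  # each yielded chunk carries the accumulated text so far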