DB-GPT/dbgpt/model/proxy/llms/ollama.py

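"""Ollama proxy LLM client for DB-GPT.

Provides :func:`ollama_generate_stream`, the stream-generation function for
:class:`ProxyModel`, and :class:`OllamaLLMClient`, a :class:`ProxyLLMClient`
implementation that forwards chat requests to an Ollama server
(``http://localhost:11434`` by default).
"""
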
import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)


def ollama_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
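    """Stream model output by delegating to the model's :class:`OllamaLLMClient`."""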
client: OllamaLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in client.sync_generate_stream(request):
yield r


class OllamaLLMClient(ProxyLLMClient):
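    """:class:`ProxyLLMClient` implementation backed by an Ollama server.

    Usage sketch (assumptions: an Ollama server is reachable at the default
    ``http://localhost:11434`` and the ``llama2`` model has been pulled;
    message construction mirrors :func:`ollama_generate_stream` above):

    .. code-block:: python

        from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext

        client = OllamaLLMClient()
        context = ModelRequestContext(stream=True)
        request = ModelRequest.build_request(
            client.default_model,
            messages=[ModelMessage(role="human", content="Hello")],
            context=context,
        )
        for output in client.sync_generate_stream(request):
            print(output.text)
    """
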
def __init__(
self,
model: Optional[str] = None,
host: Optional[str] = None,
model_alias: Optional[str] = "ollama_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
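        """Create an Ollama client.

        Falls back to the ``llama2`` model and the ``http://localhost:11434``
        host when ``model`` or ``host`` is not given.
        """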
if not model:
model = "llama2"
if not host:
host = "http://localhost:11434"
self._model = model
self._host = host
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)

    @classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "OllamaLLMClient":
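        """Build an :class:`OllamaLLMClient` from DB-GPT proxy model parameters."""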
return cls(
model=model_params.proxyllm_backend,
host=model_params.proxy_server_url,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)

    @property
def default_model(self) -> str:
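        """Return the model used when a request does not specify one."""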
return self._model

    def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
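        """Stream chat completions from Ollama as cumulative :class:`ModelOutput` chunks."""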
try:
import ollama
from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages()
model = request.model or self._model
client = Client(self._host)
try:
stream = client.chat(
model=model,
messages=messages,
stream=True,
)
            content = ""
            # Ollama streams incremental message deltas; accumulate them so each
            # yielded ModelOutput carries the full text generated so far.
            for chunk in stream:
                content += chunk["message"]["content"]
                yield ModelOutput(text=content, error_code=0)
        except ollama.ResponseError as e:
            # Yield (not return) so the error actually reaches the stream's consumer;
            # a return value inside a generator is discarded.
            yield ModelOutput(
                text=f"**Ollama Response Error, Please Check Error Info.**: {e}",
                error_code=-1,
            )