feat(model): support ollama as an optional llm & embedding proxy (#1475)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

Author: GITHUBear
Date: 2024-04-28 18:36:45 +08:00
Committed by: GitHub
Parent: 0f8188b152
Commit: 744b3e4933
10 changed files with 231 additions and 1 deletion


@@ -0,0 +1,101 @@
import logging
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)


def ollama_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    """Stream chat output from an Ollama server through the proxy model worker."""
    client: OllamaLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r


class OllamaLLMClient(ProxyLLMClient):
    """Proxy LLM client that forwards chat requests to an Ollama server."""

    def __init__(
        self,
        model: Optional[str] = None,
        host: Optional[str] = None,
        model_alias: Optional[str] = "ollama_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        if not model:
            model = "llama2"
        if not host:
            host = "http://localhost:11434"
        self._model = model
        self._host = host

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "OllamaLLMClient":
        """Create a client from standard proxy model parameters."""
        return cls(
            model=model_params.proxyllm_backend,
            host=model_params.proxy_server_url,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        """Stream chat completions from Ollama, yielding the accumulated text so far."""
        try:
            import ollama
            from ollama import Client
        except ImportError as e:
            raise ValueError(
                "Could not import python package: ollama. "
                "Please install it with `pip install ollama`."
            ) from e

        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages()
        model = request.model or self._model
        client = Client(self._host)
        try:
            stream = client.chat(
                model=model,
                messages=messages,
                stream=True,
            )
            content = ""
            for chunk in stream:
                content += chunk["message"]["content"]
                yield ModelOutput(text=content, error_code=0)
        except ollama.ResponseError as e:
            # Yield (rather than return) the error so callers of this generator
            # actually receive it.
            yield ModelOutput(
                text=f"**Ollama Response Error, please check the error info.**: {e}",
                error_code=-1,
            )