partners[ollama]: Support base_url for ChatOllama (#24719)

Add a class attribute `base_url` to ChatOllama so that users can specify a
different URL for the Ollama server to connect to.
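
A minimal usage sketch (the model name and host below are illustrative
placeholders; when `base_url` is left unset, the underlying client falls back
to its own default, typically the `OLLAMA_HOST` environment variable or
`http://localhost:11434`):

```python
from langchain_ollama import ChatOllama

# Point ChatOllama at an Ollama server on a non-default host/port.
llm = ChatOllama(
    model="llama3",  # assumes this model has been pulled on that server
    base_url="http://192.168.1.10:11434",  # hypothetical remote host
)
print(llm.invoke("Hello!").content)
```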

Fixes #24555
Jerron Lim 2024-07-29 02:25:58 +08:00 committed by GitHub
parent 8964f8a710
commit df37c0d086

@@ -17,7 +17,6 @@ from typing import (
 )
 from uuid import uuid4
 
-import ollama
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -40,7 +39,7 @@ from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
-from ollama import AsyncClient, Message, Options
+from ollama import AsyncClient, Client, Message, Options
 
 
 def _get_usage_metadata_from_generation_info(
@@ -316,6 +315,9 @@ class ChatOllama(BaseChatModel):
     keep_alive: Optional[Union[int, str]] = None
     """How long the model will stay loaded into memory."""
 
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -442,7 +444,7 @@ class ChatOllama(BaseChatModel):
             params["options"]["stop"] = stop
 
         if "tools" in kwargs:
-            yield await AsyncClient().chat(
+            yield await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -452,7 +454,7 @@
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient().chat(
+            async for part in await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -480,7 +482,7 @@
             params["options"]["stop"] = stop
 
         if "tools" in kwargs:
-            yield ollama.chat(
+            yield Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -490,7 +492,7 @@
                 tools=kwargs["tools"],
            )
         else:
-            yield from ollama.chat(
+            yield from Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
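
For reference, a short sketch of the client behavior the diff relies on (this
assumes the `ollama` Python package's `Client`, whose `host` parameter, when
`None`, falls back to the `OLLAMA_HOST` environment variable or
`http://localhost:11434`):

```python
from ollama import Client

# Talk to a specific Ollama server rather than the default local one.
client = Client(host="http://192.168.1.10:11434")  # hypothetical host
resp = client.chat(
    model="llama3",  # assumes the model is available on that server
    messages=[{"role": "user", "content": "Hi"}],
    stream=False,
)
print(resp["message"]["content"])
```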