partners[ollama]: Support base_url for ChatOllama (#24719)

Add a class attribute `base_url` to ChatOllama so that users can point it at an Ollama server other than the default local one.
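
A minimal usage sketch (the model name and remote host below are illustrative):

```python
from langchain_ollama import ChatOllama

# Point the chat model at an Ollama server other than the default
# http://localhost:11434; leaving base_url unset keeps the old behavior.
llm = ChatOllama(
    model="llama3",  # illustrative model name
    base_url="http://192.168.1.50:11434",  # illustrative remote host
)
print(llm.invoke("Hello!").content)
```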

Fixes #24555
Jerron Lim, 2024-07-29 02:25:58 +08:00 (committed by GitHub)
parent 8964f8a710
commit df37c0d086

@@ -17,7 +17,6 @@ from typing import (
 )
 from uuid import uuid4
 
-import ollama
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -40,7 +39,7 @@ from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
-from ollama import AsyncClient, Message, Options
+from ollama import AsyncClient, Client, Message, Options
 
 
 def _get_usage_metadata_from_generation_info(
@@ -316,6 +315,9 @@ class ChatOllama(BaseChatModel):
     keep_alive: Optional[Union[int, str]] = None
     """How long the model will stay loaded into memory."""
 
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -442,7 +444,7 @@
             params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield await AsyncClient().chat(
+            yield await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -452,7 +454,7 @@
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient().chat(
+            async for part in await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -480,7 +482,7 @@
             params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield ollama.chat(
+            yield Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -490,7 +492,7 @@
                 tools=kwargs["tools"],
             )
         else:
-            yield from ollama.chat(
+            yield from Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
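
For background, a short sketch of how the underlying ollama-python client treats the `host` argument; as I understand the client, `host=None` (i.e. `base_url` unset) falls back to the `OLLAMA_HOST` environment variable and then to the local default, so existing users are unaffected. The remote host below is illustrative:

```python
from ollama import Client

# host=None falls back to the OLLAMA_HOST env var, then to
# http://localhost:11434, so the default behavior is unchanged.
default_client = Client(host=None)

# With a host given, requests go to that server instead.
remote_client = Client(host="http://192.168.1.50:11434")  # illustrative
response = remote_client.chat(
    model="llama3",  # illustrative model name
    messages=[{"role": "user", "content": "Hi"}],
)
print(response["message"]["content"])
```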