diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index 775d8a51324..0b353f74ace 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -17,7 +17,6 @@ from typing import (
 )
 from uuid import uuid4
 
-import ollama
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -40,7 +39,7 @@ from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
-from ollama import AsyncClient, Message, Options
+from ollama import AsyncClient, Client, Message, Options
 
 
 def _get_usage_metadata_from_generation_info(
@@ -316,6 +315,9 @@ class ChatOllama(BaseChatModel):
     keep_alive: Optional[Union[int, str]] = None
     """How long the model will stay loaded into memory."""
 
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -442,7 +444,7 @@ class ChatOllama(BaseChatModel):
             params["options"]["stop"] = stop
 
         if "tools" in kwargs:
-            yield await AsyncClient().chat(
+            yield await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -452,7 +454,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient().chat(
+            async for part in await AsyncClient(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -480,7 +482,7 @@ class ChatOllama(BaseChatModel):
             params["options"]["stop"] = stop
 
         if "tools" in kwargs:
-            yield ollama.chat(
+            yield Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -490,7 +492,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )
         else:
-            yield from ollama.chat(
+            yield from Client(host=self.base_url).chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
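
For reviewers, a minimal usage sketch of the `base_url` parameter this diff introduces (the model name and host value below are hypothetical; any reachable Ollama server should work):

```python
from langchain_ollama import ChatOllama

# base_url points the underlying ollama Client/AsyncClient at a specific
# server via its `host` argument. When left as None, the ollama library
# falls back to its own host resolution (e.g. the OLLAMA_HOST environment
# variable, else the local default endpoint).
llm = ChatOllama(
    model="llama3",  # hypothetical model name
    base_url="http://192.168.1.50:11434",  # hypothetical remote Ollama host
)

print(llm.invoke("Hello!").content)
```

Before this change both code paths constructed `AsyncClient()` or called the module-level `ollama.chat(...)`, which always targeted the default host; routing every call through `Client(host=self.base_url)` / `AsyncClient(host=self.base_url)` makes the host configurable per model instance.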