From df37c0d086dcdfa1ed0af6c9fb31cdb32c441574 Mon Sep 17 00:00:00 2001 From: Jerron Lim Date: Mon, 29 Jul 2024 02:25:58 +0800 Subject: [PATCH] partners[ollama]: Support base_url for ChatOllama (#24719) Add a class attribute `base_url` for ChatOllama to allow users to choose a different URL to connect to. Fixes #24555 --- .../ollama/langchain_ollama/chat_models.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index 775d8a51324..0b353f74ace 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -17,7 +17,6 @@ from typing import ( ) from uuid import uuid4 -import ollama from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) @@ -40,7 +39,7 @@ from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool -from ollama import AsyncClient, Message, Options +from ollama import AsyncClient, Client, Message, Options def _get_usage_metadata_from_generation_info( @@ -316,6 +315,9 @@ class ChatOllama(BaseChatModel): keep_alive: Optional[Union[int, str]] = None """How long the model will stay loaded into memory.""" + base_url: Optional[str] = None + """Base url the model is hosted under.""" + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling Ollama.""" @@ -442,7 +444,7 @@ class ChatOllama(BaseChatModel): params["options"]["stop"] = stop if "tools" in kwargs: - yield await AsyncClient().chat( + yield await AsyncClient(host=self.base_url).chat( model=params["model"], messages=ollama_messages, stream=False, @@ -452,7 +454,7 @@ class ChatOllama(BaseChatModel): tools=kwargs["tools"], ) # type:ignore else: - async for part in await AsyncClient().chat( + async for part in await AsyncClient(host=self.base_url).chat( model=params["model"], messages=ollama_messages, stream=True, @@ -480,7 +482,7 @@ class ChatOllama(BaseChatModel): params["options"]["stop"] = stop if "tools" in kwargs: - yield ollama.chat( + yield Client(host=self.base_url).chat( model=params["model"], messages=ollama_messages, stream=False, @@ -490,7 +492,7 @@ class ChatOllama(BaseChatModel): tools=kwargs["tools"], ) else: - yield from ollama.chat( + yield from Client(host=self.base_url).chat( model=params["model"], messages=ollama_messages, stream=True,