diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index d6ece0d66a8..330a3df6c0a 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -341,13 +341,22 @@ class ChatOllama(BaseChatModel):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _chat_params(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -359,14 +368,31 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        tools = kwargs.get("tools")
+        default_stream = not bool(tools)
+
+        params = {
+            "messages": ollama_messages,
+            "stream": kwargs.pop("stream", default_stream),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+        if tools:
+            params["tools"] = tools
+
+        return params
+
     @model_validator(mode="after")
     def _set_clients(self) -> Self:
         """Set clients to use for ollama."""
@@ -464,34 +490,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+        chat_params = self._chat_params(messages, stop, **kwargs)
 
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
                 yield part
         else:
@@ -503,34 +504,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+        chat_params = self._chat_params(messages, stop, **kwargs)
 
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        if chat_params["stream"]:
             yield from self._client.chat(**chat_params)
         else:
             yield self._client.chat(**chat_params)
diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py
index 783d20104ef..264662c3a2a 100644
--- a/libs/partners/ollama/langchain_ollama/llms.py
+++ b/libs/partners/ollama/langchain_ollama/llms.py
@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _generate_params(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
                 "num_thread": self.num_thread,
                 "repeat_last_n": self.repeat_last_n,
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        params = {
+            "prompt": prompt,
+            "stream": kwargs.pop("stream", True),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+        return params
+
     @property
     def _llm_type(self) -> str:
         """Return type of LLM."""
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
         async for part in await self._async_client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
         ):  # type: ignore
             yield part
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
         yield from self._client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
        )
 
     async def _astream_with_aggregation(