From eaba6bf65029d4478433720d5bfe6227ce4ac349 Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Wed, 6 Aug 2025 17:45:31 -0400
Subject: [PATCH] remove `stop` as named param

---
 libs/core/langchain_core/v1/chat_models.py   |  9 +--
 .../langchain_ollama/v1/chat_models/base.py  | 56 ++++---------------
 2 files changed, 14 insertions(+), 51 deletions(-)

diff --git a/libs/core/langchain_core/v1/chat_models.py b/libs/core/langchain_core/v1/chat_models.py
index 53e130da50a..09a0ab6e1fb 100644
--- a/libs/core/langchain_core/v1/chat_models.py
+++ b/libs/core/langchain_core/v1/chat_models.py
@@ -710,16 +710,13 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
     def _get_invocation_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> dict:
         params = self.dump()
-        params["stop"] = stop
         return {**params, **kwargs}
 
     def _get_ls_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> LangSmithParams:
         """Get standard params for tracing."""
@@ -732,8 +729,6 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
             default_provider = default_provider.lower()
 
         ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
-        if stop:
-            ls_params["ls_stop"] = stop
 
         # model
         model = getattr(self, "model", None) or getattr(self, "model_name", None)
@@ -752,8 +747,8 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
         return ls_params
 
-    def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
-        params = self._get_invocation_params(stop=stop, **kwargs)
+    def _get_llm_string(self, **kwargs: Any) -> str:
+        params = self._get_invocation_params(**kwargs)
         params = {**params, **kwargs}
         return str(sorted(params.items()))
 
diff --git a/libs/partners/ollama/langchain_ollama/v1/chat_models/base.py b/libs/partners/ollama/langchain_ollama/v1/chat_models/base.py
index a68e46e40c6..1383970fd16 100644
--- a/libs/partners/ollama/langchain_ollama/v1/chat_models/base.py
+++ b/libs/partners/ollama/langchain_ollama/v1/chat_models/base.py
@@ -389,9 +389,6 @@ class ChatOllama(BaseChatModel):
 
     """
 
-    stop: Optional[list[str]] = None
-    """Sets the stop tokens to use."""
-
     tfs_z: Optional[float] = None
     """Tail free sampling is used to reduce the impact of less probable tokens
     from the output.
@@ -481,24 +478,20 @@ class ChatOllama(BaseChatModel):
             validate_model(self._client, self.model)
         return self
 
-    def _get_ls_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
+    def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
         """Get standard params for tracing."""
-        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = self._get_invocation_params(**kwargs)
         ls_params = LangSmithParams(
             ls_provider="ollama",
             ls_model_name=self.model,
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_stop := stop or params.get("stop", None) or self.stop:
+        if ls_stop := params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
         return ls_params
 
-    def _get_invocation_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> dict[str, Any]:
+    def _get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
         """Get parameters for model invocation."""
         params = {
             "model": self.model,
@@ -513,7 +506,6 @@ class ChatOllama(BaseChatModel):
             "repeat_penalty": self.repeat_penalty,
             "temperature": self.temperature,
             "seed": self.seed,
-            "stop": stop or self.stop,
             "tfs_z": self.tfs_z,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -531,7 +523,6 @@ class ChatOllama(BaseChatModel):
     def _chat_params(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         *,
         stream: bool = True,
         **kwargs: Any,
@@ -540,12 +531,6 @@ class ChatOllama(BaseChatModel):
         # Convert v1 messages to Ollama format
         ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
 
-        if self.stop is not None and stop is not None:
-            msg = "`stop` found in both the input and default params."
-            raise ValueError(msg)
-        if self.stop is not None:
-            stop = self.stop
-
         options_dict = kwargs.pop(
             "options",
             {
@@ -560,7 +545,6 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
@@ -586,12 +570,11 @@ class ChatOllama(BaseChatModel):
     def _generate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
         """Generate streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             for part in self._client.chat(**chat_params):
@@ -635,12 +618,11 @@ class ChatOllama(BaseChatModel):
     async def _agenerate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
         """Generate async streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
@@ -684,7 +666,6 @@ class ChatOllama(BaseChatModel):
     def _invoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -692,7 +673,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -701,19 +681,16 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
             return generate_from_stream(stream_iter)
 
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = self._client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -721,7 +698,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -730,20 +706,17 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._astream(messages, run_manager=run_manager, **kwargs)
             return await agenerate_from_stream(stream_iter)
 
         # Non-streaming case: direct API call
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = await self._async_client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
@@ -751,7 +724,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -759,14 +731,11 @@ class ChatOllama(BaseChatModel):
             AI message chunks in v1 format.
 
         """
-        yield from self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
+        yield from self._generate_stream(messages, run_manager=run_manager, **kwargs)
 
     async def _astream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
@@ -774,7 +743,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
            messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -783,7 +751,7 @@ class ChatOllama(BaseChatModel):
 
         """
         async for chunk in self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+            messages, run_manager=run_manager, **kwargs
        ):
            yield chunk
 
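
Note for reviewers: below is a minimal, standalone sketch of the calling convention this patch implies. The `BeforePatch`/`AfterPatch` classes and the hard-coded dicts are hypothetical stand-ins, not code from either file; only the `stop` handling is illustrated. With the named parameter gone, stop sequences travel through `**kwargs` and are read back from the merged params dict when tracing metadata is built, as in `_get_ls_params` above.

from typing import Any, Optional


class BeforePatch:
    """Old convention: `stop` is a dedicated keyword argument."""

    def get_invocation_params(
        self, stop: Optional[list[str]] = None, **kwargs: Any
    ) -> dict[str, Any]:
        # `stop` is threaded through explicitly and always present in params.
        params = {"model": "example-model", "stop": stop}
        return {**params, **kwargs}


class AfterPatch:
    """New convention: `stop`, when supplied, rides along in **kwargs."""

    def get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
        params = {"model": "example-model"}
        return {**params, **kwargs}

    def get_ls_params(self, **kwargs: Any) -> dict[str, Any]:
        params = self.get_invocation_params(**kwargs)
        ls_params = {"ls_provider": "ollama", "ls_model_type": "chat"}
        # Stop sequences are recovered from the merged params dict rather than
        # from a named argument, mirroring the `_get_ls_params` change above.
        if ls_stop := params.get("stop"):
            ls_params["ls_stop"] = ls_stop
        return ls_params


if __name__ == "__main__":
    old = BeforePatch().get_invocation_params(stop=["\n\n"])
    new = AfterPatch().get_ls_params(stop=["\n\n"])
    assert old["stop"] == ["\n\n"]
    assert new["ls_stop"] == ["\n\n"]

The net effect is that `stop` is treated like any other provider option flowing through `**kwargs`, rather than a privileged parameter threaded through every private helper signature.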