remove stop as named param

Mason Daugherty 2025-08-06 17:45:31 -04:00
parent 1c5d51965e
commit eaba6bf650
2 changed files with 14 additions and 51 deletions


@@ -710,16 +710,13 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
     def _get_invocation_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> dict:
         params = self.dump()
-        params["stop"] = stop
         return {**params, **kwargs}
 
     def _get_ls_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> LangSmithParams:
         """Get standard params for tracing."""
@@ -732,8 +729,6 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
         default_provider = default_provider.lower()
         ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
-        if stop:
-            ls_params["ls_stop"] = stop
 
         # model
         model = getattr(self, "model", None) or getattr(self, "model_name", None)
@@ -752,8 +747,8 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
         return ls_params
 
-    def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
-        params = self._get_invocation_params(stop=stop, **kwargs)
+    def _get_llm_string(self, **kwargs: Any) -> str:
+        params = self._get_invocation_params(**kwargs)
         params = {**params, **kwargs}
         return str(sorted(params.items()))
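Note: with the dedicated keyword removed, stop sequences only reach these helpers if a caller passes them through **kwargs. A minimal sketch of the resulting behavior, assuming `model` is an instance of some concrete chat model built on this base class (the variable name and values are illustrative, not part of the commit):

    # `stop` passed via kwargs is simply merged into the invocation params,
    # since _get_invocation_params now returns {**self.dump(), **kwargs}.
    params = model._get_invocation_params(stop=["Observation:"], temperature=0.2)
    assert params["stop"] == ["Observation:"]

    # The cache key from _get_llm_string likewise reflects stop only if it
    # arrives in kwargs.
    llm_string = model._get_llm_string(stop=["Observation:"])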


@@ -389,9 +389,6 @@ class ChatOllama(BaseChatModel):
     """
 
-    stop: Optional[list[str]] = None
-    """Sets the stop tokens to use."""
-
     tfs_z: Optional[float] = None
     """Tail free sampling is used to reduce the impact of less probable tokens from the output.
@@ -481,24 +478,20 @@ class ChatOllama(BaseChatModel):
             validate_model(self._client, self.model)
         return self
 
-    def _get_ls_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
+    def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
         """Get standard params for tracing."""
-        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = self._get_invocation_params(**kwargs)
         ls_params = LangSmithParams(
             ls_provider="ollama",
             ls_model_name=self.model,
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_stop := stop or params.get("stop", None) or self.stop:
+        if ls_stop := params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
         return ls_params
 
-    def _get_invocation_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> dict[str, Any]:
+    def _get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
         """Get parameters for model invocation."""
         params = {
             "model": self.model,
@@ -513,7 +506,6 @@ class ChatOllama(BaseChatModel):
             "repeat_penalty": self.repeat_penalty,
             "temperature": self.temperature,
             "seed": self.seed,
-            "stop": stop or self.stop,
             "tfs_z": self.tfs_z,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -531,7 +523,6 @@ class ChatOllama(BaseChatModel):
     def _chat_params(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         *,
         stream: bool = True,
         **kwargs: Any,
@@ -540,12 +531,6 @@ class ChatOllama(BaseChatModel):
         # Convert v1 messages to Ollama format
         ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
 
-        if self.stop is not None and stop is not None:
-            msg = "`stop` found in both the input and default params."
-            raise ValueError(msg)
-        if self.stop is not None:
-            stop = self.stop
-
         options_dict = kwargs.pop(
             "options",
             {
@@ -560,7 +545,6 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
@@ -586,12 +570,11 @@ class ChatOllama(BaseChatModel):
     def _generate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
         """Generate streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             for part in self._client.chat(**chat_params):
@@ -635,12 +618,11 @@ class ChatOllama(BaseChatModel):
     async def _agenerate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
         """Generate async streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
@@ -684,7 +666,6 @@ class ChatOllama(BaseChatModel):
     def _invoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -692,7 +673,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -701,19 +681,16 @@ class ChatOllama(BaseChatModel):
         """
         if self.streaming:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
             return generate_from_stream(stream_iter)
 
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = self._client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -721,7 +698,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -730,20 +706,17 @@ class ChatOllama(BaseChatModel):
         """
         if self.streaming:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._astream(messages, run_manager=run_manager, **kwargs)
             return await agenerate_from_stream(stream_iter)
 
         # Non-streaming case: direct API call
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = await self._async_client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
@@ -751,7 +724,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -759,14 +731,11 @@ class ChatOllama(BaseChatModel):
             AI message chunks in v1 format.
         """
-        yield from self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
+        yield from self._generate_stream(messages, run_manager=run_manager, **kwargs)
 
     async def _astream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
@@ -774,7 +743,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -783,7 +751,7 @@ class ChatOllama(BaseChatModel):
         """
         async for chunk in self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+            messages, run_manager=run_manager, **kwargs
         ):
             yield chunk
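
Note: with both the `stop` field and the `stop` parameter gone from ChatOllama, stop sequences are no longer merged into the request automatically; they reach the Ollama client only if they appear in the options that `_chat_params` builds. A rough usage sketch under that assumption (the model name and stop value are illustrative; it also assumes invocation kwargs are forwarded to `_chat_params` unchanged, and the import path may differ for the v1 implementation shown in this diff):

    from langchain_ollama import ChatOllama

    llm = ChatOllama(model="llama3.1")

    # The `options` kwarg is popped by _chat_params and forwarded to the
    # Ollama client, so stop sequences would now travel inside the options
    # dict rather than via a dedicated `stop` parameter or model field.
    msg = llm.invoke("Count to ten", options={"stop": ["five"]})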