remove stop as named param

This commit is contained in:
Mason Daugherty 2025-08-06 17:45:31 -04:00
parent 1c5d51965e
commit eaba6bf650
No known key found for this signature in database
2 changed files with 14 additions and 51 deletions

View File

@ -710,16 +710,13 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
def _get_invocation_params(
self,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dump()
params["stop"] = stop
return {**params, **kwargs}
def _get_ls_params(
self,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@ -732,8 +729,6 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
default_provider = default_provider.lower()
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
if stop:
ls_params["ls_stop"] = stop
# model
model = getattr(self, "model", None) or getattr(self, "model_name", None)
@ -752,8 +747,8 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
return ls_params
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
params = self._get_invocation_params(stop=stop, **kwargs)
def _get_llm_string(self, **kwargs: Any) -> str:
params = self._get_invocation_params(**kwargs)
params = {**params, **kwargs}
return str(sorted(params.items()))

View File

@ -389,9 +389,6 @@ class ChatOllama(BaseChatModel):
"""
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
"""Tail free sampling is used to reduce the impact of less probable tokens from the output.
@ -481,24 +478,20 @@ class ChatOllama(BaseChatModel):
validate_model(self._client, self.model)
return self
def _get_ls_params(
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
params = self._get_invocation_params(**kwargs)
ls_params = LangSmithParams(
ls_provider="ollama",
ls_model_name=self.model,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_stop := stop or params.get("stop", None) or self.stop:
if ls_stop := params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
def _get_invocation_params(
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> dict[str, Any]:
def _get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
"""Get parameters for model invocation."""
params = {
"model": self.model,
@ -513,7 +506,6 @@ class ChatOllama(BaseChatModel):
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"seed": self.seed,
"stop": stop or self.stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
@ -531,7 +523,6 @@ class ChatOllama(BaseChatModel):
def _chat_params(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
*,
stream: bool = True,
**kwargs: Any,
@ -540,12 +531,6 @@ class ChatOllama(BaseChatModel):
# Convert v1 messages to Ollama format
ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
if self.stop is not None and stop is not None:
msg = "`stop` found in both the input and default params."
raise ValueError(msg)
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
"options",
{
@ -560,7 +545,6 @@ class ChatOllama(BaseChatModel):
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"seed": self.seed,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
@ -586,12 +570,11 @@ class ChatOllama(BaseChatModel):
def _generate_stream(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[AIMessageChunk]:
"""Generate streaming response with native v1 chunks."""
chat_params = self._chat_params(messages, stop, **kwargs)
chat_params = self._chat_params(messages, **kwargs)
if chat_params["stream"]:
for part in self._client.chat(**chat_params):
@ -635,12 +618,11 @@ class ChatOllama(BaseChatModel):
async def _agenerate_stream(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[AIMessageChunk]:
"""Generate async streaming response with native v1 chunks."""
chat_params = self._chat_params(messages, stop, **kwargs)
chat_params = self._chat_params(messages, **kwargs)
if chat_params["stream"]:
async for part in await self._async_client.chat(**chat_params):
@ -684,7 +666,6 @@ class ChatOllama(BaseChatModel):
def _invoke(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AIMessage:
@ -692,7 +673,6 @@ class ChatOllama(BaseChatModel):
Args:
messages: List of v1 format messages.
stop: List of stop sequences.
run_manager: Callback manager for the run.
kwargs: Additional parameters.
@ -701,19 +681,16 @@ class ChatOllama(BaseChatModel):
"""
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
return generate_from_stream(stream_iter)
chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
chat_params = self._chat_params(messages, stream=False, **kwargs)
response = self._client.chat(**chat_params)
return _convert_to_v1_from_ollama_format(response)
async def _ainvoke(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AIMessage:
@ -721,7 +698,6 @@ class ChatOllama(BaseChatModel):
Args:
messages: List of v1 format messages.
stop: List of stop sequences.
run_manager: Async callback manager for the run.
kwargs: Additional parameters.
@ -730,20 +706,17 @@ class ChatOllama(BaseChatModel):
"""
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
stream_iter = self._astream(messages, run_manager=run_manager, **kwargs)
return await agenerate_from_stream(stream_iter)
# Non-streaming case: direct API call
chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
chat_params = self._chat_params(messages, stream=False, **kwargs)
response = await self._async_client.chat(**chat_params)
return _convert_to_v1_from_ollama_format(response)
def _stream(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[AIMessageChunk]:
@ -751,7 +724,6 @@ class ChatOllama(BaseChatModel):
Args:
messages: List of v1 format messages.
stop: List of stop sequences.
run_manager: Callback manager for the run.
kwargs: Additional parameters.
@ -759,14 +731,11 @@ class ChatOllama(BaseChatModel):
AI message chunks in v1 format.
"""
yield from self._generate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield from self._generate_stream(messages, run_manager=run_manager, **kwargs)
async def _astream(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[AIMessageChunk]:
@ -774,7 +743,6 @@ class ChatOllama(BaseChatModel):
Args:
messages: List of v1 format messages.
stop: List of stop sequences.
run_manager: Async callback manager for the run.
kwargs: Additional parameters.
@ -783,7 +751,7 @@ class ChatOllama(BaseChatModel):
"""
async for chunk in self._agenerate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages, run_manager=run_manager, **kwargs
):
yield chunk