Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-14 07:07:34 +00:00)
remove stop as named param

commit eaba6bf650 (parent 1c5d51965e)
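In short, the private chat-model helpers below drop `stop` as a dedicated keyword parameter; judging from the hunks, a stop value now has to travel through `**kwargs` (and, for ChatOllama, its `options` mapping) instead. A minimal, hypothetical sketch of the signature change only — `Before`/`After` are illustrative stand-ins, not classes from this commit:

from typing import Any, Optional

class Before:
    def _get_invocation_params(
        self, stop: Optional[list[str]] = None, **kwargs: Any
    ) -> dict:
        # Old shape: `stop` is an explicit keyword argument.
        return {"stop": stop, **kwargs}

class After:
    def _get_invocation_params(self, **kwargs: Any) -> dict:
        # New shape: `stop`, if supplied at all, arrives inside **kwargs.
        return dict(kwargs)

print(Before()._get_invocation_params(stop=["\n"], temperature=0.1))
print(After()._get_invocation_params(stop=["\n"], temperature=0.1))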
@@ -710,16 +710,13 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
     def _get_invocation_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> dict:
         params = self.dump()
-        params["stop"] = stop
         return {**params, **kwargs}
 
     def _get_ls_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> LangSmithParams:
         """Get standard params for tracing."""
@@ -732,8 +729,6 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
         default_provider = default_provider.lower()
 
         ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
-        if stop:
-            ls_params["ls_stop"] = stop
 
         # model
         model = getattr(self, "model", None) or getattr(self, "model_name", None)
@@ -752,8 +747,8 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
         return ls_params
 
-    def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
-        params = self._get_invocation_params(stop=stop, **kwargs)
+    def _get_llm_string(self, **kwargs: Any) -> str:
+        params = self._get_invocation_params(**kwargs)
         params = {**params, **kwargs}
         return str(sorted(params.items()))
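For context, `_get_llm_string` builds the cache key by merging the invocation params with the call kwargs and stringifying the sorted items, so after this change `stop` only shows up in the key when a caller passes it via `**kwargs`. A standalone illustration of that key construction (the dict literal stands in for `self._get_invocation_params(**kwargs)`):

from typing import Any

def llm_string(params: dict[str, Any], **kwargs: Any) -> str:
    # Mirrors the merge-and-sort shown in the hunk above.
    merged = {**params, **kwargs}
    return str(sorted(merged.items()))

print(llm_string({"model": "llama3", "temperature": 0.2}))
print(llm_string({"model": "llama3", "temperature": 0.2}, stop=["\n\n"]))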
@@ -389,9 +389,6 @@ class ChatOllama(BaseChatModel):
 
     """
 
-    stop: Optional[list[str]] = None
-    """Sets the stop tokens to use."""
-
     tfs_z: Optional[float] = None
     """Tail free sampling is used to reduce the impact of less probable tokens from the output.
 
@@ -481,24 +478,20 @@ class ChatOllama(BaseChatModel):
             validate_model(self._client, self.model)
         return self
 
-    def _get_ls_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
+    def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
         """Get standard params for tracing."""
-        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = self._get_invocation_params(**kwargs)
         ls_params = LangSmithParams(
             ls_provider="ollama",
             ls_model_name=self.model,
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_stop := stop or params.get("stop", None) or self.stop:
+        if ls_stop := params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
         return ls_params
 
-    def _get_invocation_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> dict[str, Any]:
+    def _get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
         """Get parameters for model invocation."""
         params = {
             "model": self.model,
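With the explicit argument and the `self.stop` field gone, ChatOllama's `_get_ls_params` now reports `ls_stop` only when a stop value is present in the invocation params, i.e. when the caller passed one through `**kwargs`. A hedged, standalone sketch of that lookup (the `params` dicts stand in for `self._get_invocation_params(**kwargs)`):

from typing import Any

def ls_stop_fields(params: dict[str, Any]) -> dict[str, Any]:
    # Mirrors the new walrus lookup shown in the hunk above.
    ls_params: dict[str, Any] = {"ls_provider": "ollama", "ls_model_type": "chat"}
    if ls_stop := params.get("stop", None):
        ls_params["ls_stop"] = ls_stop
    return ls_params

print(ls_stop_fields({"model": "llama3"}))                    # no ls_stop recorded
print(ls_stop_fields({"model": "llama3", "stop": ["###"]}))   # ls_stop recorded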
@@ -513,7 +506,6 @@ class ChatOllama(BaseChatModel):
             "repeat_penalty": self.repeat_penalty,
             "temperature": self.temperature,
             "seed": self.seed,
-            "stop": stop or self.stop,
             "tfs_z": self.tfs_z,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -531,7 +523,6 @@ class ChatOllama(BaseChatModel):
     def _chat_params(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         *,
         stream: bool = True,
         **kwargs: Any,
@@ -540,12 +531,6 @@ class ChatOllama(BaseChatModel):
         # Convert v1 messages to Ollama format
         ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
 
-        if self.stop is not None and stop is not None:
-            msg = "`stop` found in both the input and default params."
-            raise ValueError(msg)
-        if self.stop is not None:
-            stop = self.stop
-
         options_dict = kwargs.pop(
             "options",
             {
@@ -560,7 +545,6 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
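Since the conflict check and the default `"stop"` entry are both removed, `_chat_params` now builds its default option dict without any stop tokens; an explicit `options` kwarg, which the `kwargs.pop("options", {...})` above consumes, replaces those defaults wholesale. A standalone sketch of that pop-with-default behaviour (names here are illustrative, not from the commit):

from typing import Any

def build_options(defaults: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
    # Mirrors options_dict = kwargs.pop("options", {...defaults...}) above.
    return kwargs.pop("options", defaults)

defaults = {"temperature": 0.8, "top_k": 40}  # stands in for the self.* defaults
print(build_options(defaults))                                                 # defaults used
print(build_options(defaults, options={"stop": ["###"], "temperature": 0.0}))  # caller override wins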
@@ -586,12 +570,11 @@ class ChatOllama(BaseChatModel):
     def _generate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
         """Generate streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             for part in self._client.chat(**chat_params):
@@ -635,12 +618,11 @@ class ChatOllama(BaseChatModel):
     async def _agenerate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
         """Generate async streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
@@ -684,7 +666,6 @@ class ChatOllama(BaseChatModel):
     def _invoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -692,7 +673,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -701,19 +681,16 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
             return generate_from_stream(stream_iter)
 
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = self._client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -721,7 +698,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -730,20 +706,17 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._astream(messages, run_manager=run_manager, **kwargs)
             return await agenerate_from_stream(stream_iter)
 
         # Non-streaming case: direct API call
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = await self._async_client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
@@ -751,7 +724,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -759,14 +731,11 @@ class ChatOllama(BaseChatModel):
             AI message chunks in v1 format.
 
         """
-        yield from self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
+        yield from self._generate_stream(messages, run_manager=run_manager, **kwargs)
 
     async def _astream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
@@ -774,7 +743,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -783,7 +751,7 @@ class ChatOllama(BaseChatModel):
 
         """
         async for chunk in self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+            messages, run_manager=run_manager, **kwargs
         ):
             yield chunk
 
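Finally, a hedged sketch of a quick check one could run after a change like this: the reworked helpers should no longer declare a dedicated `stop` parameter. The `Example` class below is a stand-in, not code from the commit:

import inspect

def has_stop_param(fn) -> bool:
    # True when the callable declares a parameter literally named "stop".
    return "stop" in inspect.signature(fn).parameters

class Example:  # stand-in for the edited methods
    def _get_invocation_params(self, **kwargs): ...
    def _get_ls_params(self, **kwargs): ...

assert not has_stop_param(Example._get_invocation_params)
assert not has_stop_param(Example._get_ls_params)
print("ok: no dedicated `stop` parameter")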