Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-16 16:11:02 +00:00)

remove stop as named param

commit eaba6bf650 (parent 1c5d51965e)
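The pattern repeated throughout the hunks below: `stop` is dropped as a dedicated keyword parameter on the internal helpers of BaseChatModel and ChatOllama, and any per-call value now arrives, if at all, through `**kwargs`. A minimal sketch of the signature shift, paraphrased from the first hunk (toy classes, not the real implementations):

from typing import Any, Optional


class Before:
    # old shape: `stop` is an explicit, defaulted keyword parameter
    def _get_invocation_params(
        self, stop: Optional[list[str]] = None, **kwargs: Any
    ) -> dict:
        ...


class After:
    # new shape: no special-casing; a caller-supplied `stop` is just another kwarg
    def _get_invocation_params(self, **kwargs: Any) -> dict:
        ...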
@@ -710,16 +710,13 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
     def _get_invocation_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> dict:
         params = self.dump()
-        params["stop"] = stop
         return {**params, **kwargs}
 
     def _get_ls_params(
         self,
-        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> LangSmithParams:
         """Get standard params for tracing."""
@@ -732,8 +729,6 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
         default_provider = default_provider.lower()
 
         ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
-        if stop:
-            ls_params["ls_stop"] = stop
 
         # model
         model = getattr(self, "model", None) or getattr(self, "model_name", None)
@@ -752,8 +747,8 @@ class BaseChatModel(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
 
         return ls_params
 
-    def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
-        params = self._get_invocation_params(stop=stop, **kwargs)
+    def _get_llm_string(self, **kwargs: Any) -> str:
+        params = self._get_invocation_params(**kwargs)
         params = {**params, **kwargs}
         return str(sorted(params.items()))
 
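A consequence of the BaseChatModel hunks above worth spelling out: a `stop` value that a caller still forwards is not lost, it simply arrives through `**kwargs` and is merged like any other invocation parameter, so it also still appears in the string `_get_llm_string` builds. A minimal runnable sketch with a toy class, not the real BaseChatModel:

from typing import Any


class _ToyModel:
    def dump(self) -> dict:
        # stand-in for the model's serialized default params
        return {"model": "toy", "temperature": 0.7}

    def _get_invocation_params(self, **kwargs: Any) -> dict:
        # new shape: no explicit `stop` handling, kwargs are merged wholesale
        params = self.dump()
        return {**params, **kwargs}

    def _get_llm_string(self, **kwargs: Any) -> str:
        params = self._get_invocation_params(**kwargs)
        params = {**params, **kwargs}
        return str(sorted(params.items()))


m = _ToyModel()
print(m._get_invocation_params(stop=["\n\n"]))
# {'model': 'toy', 'temperature': 0.7, 'stop': ['\n\n']}
print(m._get_llm_string(stop=["\n\n"]))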
@@ -389,9 +389,6 @@ class ChatOllama(BaseChatModel):
 
     """
 
-    stop: Optional[list[str]] = None
-    """Sets the stop tokens to use."""
-
     tfs_z: Optional[float] = None
     """Tail free sampling is used to reduce the impact of less probable tokens from the output.
 
@@ -481,24 +478,20 @@ class ChatOllama(BaseChatModel):
             validate_model(self._client, self.model)
         return self
 
-    def _get_ls_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
+    def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
         """Get standard params for tracing."""
-        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = self._get_invocation_params(**kwargs)
         ls_params = LangSmithParams(
             ls_provider="ollama",
             ls_model_name=self.model,
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_stop := stop or params.get("stop", None) or self.stop:
+        if ls_stop := params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
         return ls_params
 
-    def _get_invocation_params(
-        self, stop: Optional[list[str]] = None, **kwargs: Any
-    ) -> dict[str, Any]:
+    def _get_invocation_params(self, **kwargs: Any) -> dict[str, Any]:
        """Get parameters for model invocation."""
         params = {
             "model": self.model,
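With both the `stop` field and the named argument gone, ChatOllama's tracing params now pick up stop tokens only from the merged invocation params, i.e. only when the caller forwarded `stop=...` as a keyword argument. A small sketch of that lookup (standalone toy function, not the real method):

# Toy stand-in for the new `_get_ls_params` lookup: `ls_stop` is recorded only
# when "stop" is present in the merged invocation params.
def ls_params_like(invocation_params: dict) -> dict:
    ls_params = {"ls_provider": "ollama", "ls_model_type": "chat"}
    if ls_stop := invocation_params.get("stop", None):
        ls_params["ls_stop"] = ls_stop
    return ls_params


print(ls_params_like({"model": "llama3", "stop": ["Observation:"]}))
# {'ls_provider': 'ollama', 'ls_model_type': 'chat', 'ls_stop': ['Observation:']}
print(ls_params_like({"model": "llama3"}))
# {'ls_provider': 'ollama', 'ls_model_type': 'chat'}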
@@ -513,7 +506,6 @@ class ChatOllama(BaseChatModel):
             "repeat_penalty": self.repeat_penalty,
             "temperature": self.temperature,
             "seed": self.seed,
-            "stop": stop or self.stop,
             "tfs_z": self.tfs_z,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -531,7 +523,6 @@ class ChatOllama(BaseChatModel):
     def _chat_params(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         *,
         stream: bool = True,
         **kwargs: Any,
@@ -540,12 +531,6 @@ class ChatOllama(BaseChatModel):
         # Convert v1 messages to Ollama format
         ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
 
-        if self.stop is not None and stop is not None:
-            msg = "`stop` found in both the input and default params."
-            raise ValueError(msg)
-        if self.stop is not None:
-            stop = self.stop
-
         options_dict = kwargs.pop(
             "options",
             {
@@ -560,7 +545,6 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
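Two behavior notes fall out of the `_chat_params` hunks above: the old ValueError for `stop` supplied both at construction and at call time can no longer trigger (neither the field nor the parameter exists any more), and the Options defaults no longer carry a "stop" entry. Note also the `kwargs.pop("options", {...defaults...})` pattern: a caller-supplied `options` dict replaces the defaults wholesale rather than merging with them. A toy illustration of that pop semantics (hypothetical defaults, not the real option set; how the rest of _chat_params uses options_dict is outside these hunks):

from typing import Any

_DEFAULTS = {"temperature": 0.8, "num_predict": 128}  # hypothetical defaults


def chat_params_like(**kwargs: Any) -> dict:
    # caller-supplied `options` wins outright; otherwise the defaults are used
    options_dict = kwargs.pop("options", _DEFAULTS)
    return {"options": options_dict, **kwargs}


print(chat_params_like())
# {'options': {'temperature': 0.8, 'num_predict': 128}}
print(chat_params_like(options={"stop": ["</s>"], "temperature": 0.1}))
# {'options': {'stop': ['</s>'], 'temperature': 0.1}}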
@@ -586,12 +570,11 @@ class ChatOllama(BaseChatModel):
     def _generate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
         """Generate streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             for part in self._client.chat(**chat_params):
@@ -635,12 +618,11 @@ class ChatOllama(BaseChatModel):
     async def _agenerate_stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
         """Generate async streaming response with native v1 chunks."""
-        chat_params = self._chat_params(messages, stop, **kwargs)
+        chat_params = self._chat_params(messages, **kwargs)
 
         if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
@@ -684,7 +666,6 @@ class ChatOllama(BaseChatModel):
     def _invoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -692,7 +673,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -701,19 +681,16 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
             return generate_from_stream(stream_iter)
 
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = self._client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AIMessage:
@@ -721,7 +698,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -730,20 +706,17 @@ class ChatOllama(BaseChatModel):
 
         """
         if self.streaming:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
+            stream_iter = self._astream(messages, run_manager=run_manager, **kwargs)
             return await agenerate_from_stream(stream_iter)
 
         # Non-streaming case: direct API call
-        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        chat_params = self._chat_params(messages, stream=False, **kwargs)
         response = await self._async_client.chat(**chat_params)
         return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[AIMessageChunk]:
@@ -751,7 +724,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Callback manager for the run.
             kwargs: Additional parameters.
 
@@ -759,14 +731,11 @@ class ChatOllama(BaseChatModel):
             AI message chunks in v1 format.
 
         """
-        yield from self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
+        yield from self._generate_stream(messages, run_manager=run_manager, **kwargs)
 
     async def _astream(
         self,
         messages: list[MessageV1],
-        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AIMessageChunk]:
@@ -774,7 +743,6 @@ class ChatOllama(BaseChatModel):
 
         Args:
             messages: List of v1 format messages.
-            stop: List of stop sequences.
             run_manager: Async callback manager for the run.
             kwargs: Additional parameters.
 
@@ -783,7 +751,7 @@ class ChatOllama(BaseChatModel):
 
         """
         async for chunk in self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+            messages, run_manager=run_manager, **kwargs
         ):
             yield chunk
 
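End-to-end, the ChatOllama entry points now forward only the messages, the run manager, and whatever arrives in `**kwargs`. A toy sketch of that threading pattern (not the real class; whether and where ChatOllama reinstates per-call stop handling is outside the hunks shown above):

from typing import Any


class _ToyChat:
    # stand-in signatures mirroring the post-change shape: no `stop` parameter,
    # per-call settings ride along in **kwargs at every hop
    def _chat_params(
        self, messages: list[str], *, stream: bool = True, **kwargs: Any
    ) -> dict:
        return {"messages": messages, "stream": stream, **kwargs}

    def _invoke(self, messages: list[str], **kwargs: Any) -> dict:
        return self._chat_params(messages, stream=False, **kwargs)


print(_ToyChat()._invoke(["hello"], stop=["\n"]))
# {'messages': ['hello'], 'stream': False, 'stop': ['\n']}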