From 3999761201ee791e489942fc1e582bb80d312886 Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 6 Jun 2024 12:11:52 -0400 Subject: [PATCH] multiple: add `stop` attribute (#22573) --- libs/partners/ai21/langchain_ai21/chat_models.py | 4 ++++ .../anthropic/langchain_anthropic/chat_models.py | 6 +++++- .../fireworks/langchain_fireworks/chat_models.py | 5 +++-- libs/partners/groq/langchain_groq/chat_models.py | 5 +++-- .../integration_tests/chat_models.py | 11 +++++++++++ 5 files changed, 26 insertions(+), 5 deletions(-) diff --git a/libs/partners/ai21/langchain_ai21/chat_models.py b/libs/partners/ai21/langchain_ai21/chat_models.py index 8dbf894dcca..cb3a28d27bd 100644 --- a/libs/partners/ai21/langchain_ai21/chat_models.py +++ b/libs/partners/ai21/langchain_ai21/chat_models.py @@ -35,6 +35,8 @@ class ChatAI21(BaseChatModel, AI21Base): You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types""" num_results: int = 1 """The number of responses to generate for a given prompt.""" + stop: Optional[List[str]] = None + """Default stop sequences.""" max_tokens: int = 16 """The maximum number of tokens to generate for each response.""" @@ -97,6 +99,8 @@ class ChatAI21(BaseChatModel, AI21Base): "top_k_return": self.top_k_return, "n": self.n, } + if self.stop: + base_params["stop_sequences"] = self.stop if self.count_penalty is not None: base_params["count_penalty"] = self.count_penalty.to_dict() diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 91a6e31a2f6..65eec5bef8a 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -492,6 +492,9 @@ class ChatAnthropic(BaseChatModel): max_retries: int = 2 """Number of retries allowed for requests sent to the Anthropic Completion API.""" + stop: Optional[List[str]] = Field(None, alias="stop_sequences") + """Default stop sequences.""" + anthropic_api_url: Optional[str] = Field(None, alias="base_url") """Base URL for API requests. Only specify if using a proxy or service emulator. @@ -611,6 +614,7 @@ class ChatAnthropic(BaseChatModel): ) -> Dict: # get system prompt if any system, formatted_messages = _format_messages(messages) + stop_sequences = stop or self.stop rtn = { "model": self.model, "max_tokens": self.max_tokens, @@ -618,7 +622,7 @@ class ChatAnthropic(BaseChatModel): "temperature": self.temperature, "top_k": self.top_k, "top_p": self.top_p, - "stop_sequences": stop, + "stop_sequences": stop_sequences, "system": system, **self.model_kwargs, **kwargs, diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 52dea080894..c3af3933db2 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -300,6 +300,8 @@ class ChatFireworks(BaseChatModel): """Number of chat completions to generate for each prompt.""" max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" + stop: Optional[List[str]] = Field(None, alias="stop_sequences") + """Default stop sequences.""" class Config: """Configuration for this pydantic object.""" @@ -354,6 +356,7 @@ class ChatFireworks(BaseChatModel): "stream": self.streaming, "n": self.n, "temperature": self.temperature, + "stop": self.stop, **self.model_kwargs, } if self.max_tokens is not None: @@ -443,8 +446,6 @@ class ChatFireworks(BaseChatModel): ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: params = self._default_params if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 69234ce93d6..611ed6cb44b 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -123,6 +123,8 @@ class ChatGroq(BaseChatModel): """Number of chat completions to generate for each prompt.""" max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" + stop: Optional[List[str]] = Field(None, alias="stop_sequences") + """Default stop sequences.""" default_headers: Union[Mapping[str, str], None] = None default_query: Union[Mapping[str, object], None] = None # Configure a custom httpx client. See the @@ -428,6 +430,7 @@ class ChatGroq(BaseChatModel): "stream": self.streaming, "n": self.n, "temperature": self.temperature, + "stop": self.stop, **self.model_kwargs, } if self.max_tokens is not None: @@ -461,8 +464,6 @@ class ChatGroq(BaseChatModel): ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: params = self._default_params if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 5f669efda16..9eabc6239e8 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -144,6 +144,17 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.usage_metadata["output_tokens"], int) assert isinstance(result.usage_metadata["total_tokens"], int) + def test_stop_sequence( + self, chat_model_class: Type[BaseChatModel], chat_model_params: dict + ) -> None: + model = chat_model_class(**chat_model_params) + result = model.invoke("hi", stop=["you"]) + assert isinstance(result, AIMessage) + + model = chat_model_class(**chat_model_params, stop=["you"]) + result = model.invoke("hi") + assert isinstance(result, AIMessage) + def test_tool_message_histories_string_content( self, chat_model_class: Type[BaseChatModel],