From 251afda5494924df343dc52ce244cb28eba8a320 Mon Sep 17 00:00:00 2001
From: Mateusz Szewczyk <139469471+MateuszOssGit@users.noreply.github.com>
Date: Mon, 15 Jan 2024 20:44:57 +0100
Subject: [PATCH] community[patch]: fix stop (stop_sequences) param on WatsonxLLM (#15541)

- **Description:** Fix the `stop` (`stop_sequences`) param handling in the IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai) LLM provider (`WatsonxLLM`)
- **Dependencies:** [ibm-watsonx-ai](https://pypi.org/project/ibm-watsonx-ai/)
---
 .../langchain_community/llms/watsonxllm.py    | 22 +++++++++----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/libs/community/langchain_community/llms/watsonxllm.py b/libs/community/langchain_community/llms/watsonxllm.py
index 380628f5ef3..d60c7284600 100644
--- a/libs/community/langchain_community/llms/watsonxllm.py
+++ b/libs/community/langchain_community/llms/watsonxllm.py
@@ -249,6 +249,12 @@ class WatsonxLLM(BaseLLM):
             "input_token_count": input_token_count,
         }
 
+    def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        params: Dict[str, Any] = {**self.params} if self.params else None
+        if stop is not None:
+            params = (params or {}) | {"stop_sequences": stop}
+        return params
+
     def _create_llm_result(self, response: List[dict]) -> LLMResult:
         """Create the LLMResult from the choices and prompts."""
         generations = []
@@ -334,11 +340,7 @@ class WatsonxLLM(BaseLLM):
 
                 response = watsonx_llm.generate(["What is a molecule"])
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
@@ -360,7 +362,7 @@ class WatsonxLLM(BaseLLM):
                 return LLMResult(generations=[[generation]], llm_output=llm_output)
             return LLMResult(generations=[[generation]])
         else:
-            response = self.watsonx_model.generate(prompt=prompts, params=self.params)
+            response = self.watsonx_model.generate(prompt=prompts, params=params)
             return self._create_llm_result(response)
 
     def _stream(
@@ -384,13 +386,9 @@ class WatsonxLLM(BaseLLM):
                 for chunk in response:
                     print(chunk, end='')
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, params=self.params
+            prompt=prompt, raw_response=True, params=params
         ):
             chunk = self._stream_response_to_generation_chunk(stream_resp)
             yield chunk
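
Usage note (not part of the patch): a minimal sketch of the fixed behaviour. With this change, a `stop` list passed at call time is merged into a per-call params dict (`{**self.params, "stop_sequences": stop}`) instead of mutating `WatsonxLLM.params`. The model id, endpoint URL, and credential values below are placeholders and require `ibm-watsonx-ai` plus valid credentials to actually run.

```python
from langchain_community.llms import WatsonxLLM

# Placeholder model id, endpoint, and credentials -- substitute real values.
watsonx_llm = WatsonxLLM(
    model_id="google/flan-ul2",
    url="https://us-south.ml.cloud.ibm.com",
    apikey="YOUR_API_KEY",
    project_id="YOUR_PROJECT_ID",
    params={"max_new_tokens": 50},
)

# `stop` is folded into a per-call params dict via _get_chat_params,
# i.e. {"max_new_tokens": 50, "stop_sequences": ["\n\n"]} for this call only.
result = watsonx_llm.generate(["What is a molecule?"], stop=["\n\n"])
print(result.generations[0][0].text)

# The instance-level params are no longer mutated by passing `stop`,
# so later calls without `stop` do not inherit old stop sequences.
assert "stop_sequences" not in (watsonx_llm.params or {})
```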