community[patch]: fix stop (stop_sequences) param on WatsonxLLM (#15541)

- **Description:** Fix the `stop` (`stop_sequences`) param on `WatsonxLLM` in the IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai) LLM provider; see the usage sketch below.
- **Dependencies:** [ibm-watsonx-ai](https://pypi.org/project/ibm-watsonx-ai/)
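
A rough illustration of the intended behavior (not part of this commit; the model id, URL, and credentials below are placeholders): after the fix, `stop` is forwarded to watsonx.ai as `stop_sequences` for the current call only, instead of being written into the shared `params` dict.

```python
# Hypothetical usage sketch -- not part of this commit. model_id, url, apikey,
# and project_id are placeholder values.
from langchain_community.llms import WatsonxLLM

watsonx_llm = WatsonxLLM(
    model_id="google/flan-ul2",
    url="https://us-south.ml.cloud.ibm.com",
    apikey="***",
    project_id="***",
)

# `stop` is passed to watsonx.ai as `stop_sequences` for this call only;
# it no longer mutates the instance-level `self.params` dict.
result = watsonx_llm.generate(["What is a molecule"], stop=["\n\n"])
```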
Author: Mateusz Szewczyk, 2024-01-15 20:44:57 +01:00 (committed via GitHub)
Commit: 251afda549 (parent 7220124368)

@@ -249,6 +249,12 @@ class WatsonxLLM(BaseLLM):
             "input_token_count": input_token_count,
         }
 
+    def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        params: Dict[str, Any] = {**self.params} if self.params else None
+        if stop is not None:
+            params = (params or {}) | {"stop_sequences": stop}
+        return params
+
     def _create_llm_result(self, response: List[dict]) -> LLMResult:
         """Create the LLMResult from the choices and prompts."""
         generations = []
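
For context on the new helper: it builds a per-call params dict instead of mutating `self.params`, so stop sequences from one call cannot leak into later calls. A standalone sketch of the same merge logic (plain dicts only; the `|` dict-union operator needs Python 3.9+):

```python
# Standalone illustration of the merge in _get_chat_params; `self_params`
# stands in for self.params and holds purely hypothetical values.
self_params = {"max_new_tokens": 100}
stop = ["\n\n"]

params = {**self_params} if self_params else None
if stop is not None:
    params = (params or {}) | {"stop_sequences": stop}

print(params)       # {'max_new_tokens': 100, 'stop_sequences': ['\n\n']}
print(self_params)  # {'max_new_tokens': 100} -- unchanged, no state leaks between calls
```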
@@ -334,11 +340,7 @@ class WatsonxLLM(BaseLLM):
 
                 response = watsonx_llm.generate(["What is a molecule"])
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
@@ -360,7 +362,7 @@ class WatsonxLLM(BaseLLM):
                     return LLMResult(generations=[[generation]], llm_output=llm_output)
                 return LLMResult(generations=[[generation]])
         else:
-            response = self.watsonx_model.generate(prompt=prompts, params=self.params)
+            response = self.watsonx_model.generate(prompt=prompts, params=params)
             return self._create_llm_result(response)
 
     def _stream(
@@ -384,13 +386,9 @@ class WatsonxLLM(BaseLLM):
                 for chunk in response:
                     print(chunk, end='')
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, params=self.params
+            prompt=prompt, raw_response=True, params=params
         ):
             chunk = self._stream_response_to_generation_chunk(stream_resp)
             yield chunk
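
A similarly hedged sketch of the streaming path, reusing the hypothetical `watsonx_llm` instance from the earlier example; the per-call `stop` value reaches `generate_text_stream` through the same helper:

```python
# Hypothetical streaming sketch -- assumes the watsonx_llm instance above.
# The stop sequence applies only to this call and is not stored on the model.
for chunk in watsonx_llm.stream("Describe your favorite breakfast", stop=["\n"]):
    print(chunk, end="")
```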