community[patch]: fix stop (stop_sequences) param on WatsonxLLM (#15541)

- **Description:** Fix for the IBM
[watsonx.ai](https://www.ibm.com/products/watsonx-ai) LLM provider: the `stop`
(`stop_sequences`) param on `WatsonxLLM` is now applied per call instead of
being written into the shared model params (usage sketch below).
- **Dependencies:**
[ibm-watsonx-ai](https://pypi.org/project/ibm-watsonx-ai/)
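
A minimal usage sketch of the fixed param handling; the model id, endpoint,
credentials, and decoding params below are placeholders for illustration, not
part of this change:

```python
from langchain_community.llms import WatsonxLLM

# Placeholder credentials and model id -- illustrative only.
watsonx_llm = WatsonxLLM(
    model_id="google/flan-ul2",
    url="https://us-south.ml.cloud.ibm.com",
    apikey="***",
    project_id="***",
    params={"max_new_tokens": 100},
)

# `stop` is merged into a per-call params dict as `stop_sequences`;
# `watsonx_llm.params` itself is left untouched between calls.
response = watsonx_llm.generate(["What is a molecule"], stop=["\n\n"])
```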
Mateusz Szewczyk 2024-01-15 20:44:57 +01:00 committed by GitHub
parent 7220124368
commit 251afda549


@@ -249,6 +249,12 @@ class WatsonxLLM(BaseLLM):
             "input_token_count": input_token_count,
         }
 
+    def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        params: Dict[str, Any] = {**self.params} if self.params else None
+        if stop is not None:
+            params = (params or {}) | {"stop_sequences": stop}
+        return params
+
     def _create_llm_result(self, response: List[dict]) -> LLMResult:
         """Create the LLMResult from the choices and prompts."""
         generations = []
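
For context, the new `_get_chat_params` helper shallow-copies `self.params` and
merges the stop words in with the Python 3.9+ dict-union operator (`|`), so the
instance-level params are never mutated. A standalone sketch of that merge
(`base` and `stop` are made-up values, not part of the commit):

```python
# Standalone re-creation of the merge done by _get_chat_params.
base = {"max_new_tokens": 100}       # stands in for self.params
stop = ["\n\n"]                      # stands in for the stop argument

params = {**base} if base else None  # shallow copy, or None when params are unset
if stop is not None:
    params = (params or {}) | {"stop_sequences": stop}  # dict union, Python 3.9+

print(params)  # {'max_new_tokens': 100, 'stop_sequences': ['\n\n']}
print(base)    # {'max_new_tokens': 100} -- the original dict is unchanged
```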
@@ -334,11 +340,7 @@ class WatsonxLLM(BaseLLM):
                 response = watsonx_llm.generate(["What is a molecule"])
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
@@ -360,7 +362,7 @@ class WatsonxLLM(BaseLLM):
                 return LLMResult(generations=[[generation]], llm_output=llm_output)
             return LLMResult(generations=[[generation]])
         else:
-            response = self.watsonx_model.generate(prompt=prompts, params=self.params)
+            response = self.watsonx_model.generate(prompt=prompts, params=params)
             return self._create_llm_result(response)
 
     def _stream(
@@ -384,13 +386,9 @@ class WatsonxLLM(BaseLLM):
                 for chunk in response:
                     print(chunk, end='')
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, params=self.params
+            prompt=prompt, raw_response=True, params=params
         ):
             chunk = self._stream_response_to_generation_chunk(stream_resp)
             yield chunk
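
The reason for routing both `_generate` and `_stream` through `_get_chat_params`
is that the old code updated `self.params` in place, so a `stop` list passed to
one call leaked into every later call on the same instance. A minimal sketch of
that difference using plain dicts (the helper functions and values here are
illustrative, not part of the commit):

```python
# Old behavior: the stop list is written into the shared params dict
# and leaks into every later call.
old_params = {"max_new_tokens": 100}

def old_call(stop=None):
    if stop:
        old_params.update({"stop_sequences": stop})
    return dict(old_params)

old_call(stop=["\n"])
print(old_call())  # {'max_new_tokens': 100, 'stop_sequences': ['\n']} -- leaked

# New behavior (what _get_chat_params does): copy, then merge per call.
new_params = {"max_new_tokens": 100}

def new_call(stop=None):
    params = {**new_params}
    if stop is not None:
        params = params | {"stop_sequences": stop}
    return params

print(new_call(stop=["\n"]))  # stop_sequences present for this call only
print(new_call())             # {'max_new_tokens': 100} -- no leak
```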