Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-10 13:27:36 +00:00
community[patch]: fix stop (stop_sequences) param on WatsonxLLM (#15541)
- **Description:** Fix for the IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai) LLM provider: honor the stop (`stop_sequences`) param on WatsonxLLM.
- **Dependencies:** [ibm-watsonx-ai](https://pypi.org/project/ibm-watsonx-ai/)
This commit is contained in:
parent 7220124368
commit 251afda549
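A minimal usage sketch (not part of this commit) of what the fix enables; the credentials and model_id below are placeholders:

```python
from langchain_community.llms import WatsonxLLM

# Hypothetical credentials and model_id, for illustration only.
watsonx_llm = WatsonxLLM(
    model_id="google/flan-ul2",
    url="https://us-south.ml.cloud.ibm.com",
    apikey="*****",
    project_id="*****",
)

# With this fix, `stop` is forwarded to watsonx.ai as `stop_sequences`
# for this call only, instead of being written back into watsonx_llm.params.
response = watsonx_llm.generate(["What is a molecule"], stop=["\n\n"])
```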
@@ -249,6 +249,12 @@ class WatsonxLLM(BaseLLM):
             "input_token_count": input_token_count,
         }
 
+    def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        params: Dict[str, Any] = {**self.params} if self.params else None
+        if stop is not None:
+            params = (params or {}) | {"stop_sequences": stop}
+        return params
+
     def _create_llm_result(self, response: List[dict]) -> LLMResult:
         """Create the LLMResult from the choices and prompts."""
         generations = []
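The new helper builds the per-call params with dict union (`|`, available since Python 3.9), so the instance-level `self.params` is copied rather than mutated. A standalone sketch of the same pattern; the function and variable names below are illustrative, not part of the commit:

```python
from typing import Any, Dict, List, Optional


def merge_stop(
    params: Optional[Dict[str, Any]], stop: Optional[List[str]]
) -> Dict[str, Any]:
    # Copy the instance-level params (if any), then overlay the per-call stop.
    merged: Dict[str, Any] = {**params} if params else {}
    if stop is not None:
        merged = merged | {"stop_sequences": stop}  # dict union, Python 3.9+
    return merged


instance_params = {"max_new_tokens": 100}
per_call = merge_stop(instance_params, ["\n\n"])
print(per_call)         # {'max_new_tokens': 100, 'stop_sequences': ['\n\n']}
print(instance_params)  # unchanged: {'max_new_tokens': 100}
```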
@@ -334,11 +340,7 @@ class WatsonxLLM(BaseLLM):
 
                 response = watsonx_llm.generate(["What is a molecule"])
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
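For context on why the removed block needed to go: updating `self.params` in place made a per-call `stop` persist on the instance, so later calls made without `stop` still sent the old `stop_sequences`. A tiny illustration of that leak (names are illustrative):

```python
# Old pattern: the per-call stop list is written into the shared instance params.
instance_params = {"max_new_tokens": 50}


def apply_stop_old(stop):
    # Mirrors the removed `self.params.update({"stop_sequences": stop})` branch.
    instance_params.update({"stop_sequences": stop})


apply_stop_old(["###"])  # first call passes a stop sequence
print(instance_params)   # {'max_new_tokens': 50, 'stop_sequences': ['###']}
# A later call made *without* stop would still send stop_sequences=['###'].
```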
@@ -360,7 +362,7 @@ class WatsonxLLM(BaseLLM):
                 return LLMResult(generations=[[generation]], llm_output=llm_output)
             return LLMResult(generations=[[generation]])
         else:
-            response = self.watsonx_model.generate(prompt=prompts, params=self.params)
+            response = self.watsonx_model.generate(prompt=prompts, params=params)
             return self._create_llm_result(response)
 
     def _stream(
@@ -384,13 +386,9 @@ class WatsonxLLM(BaseLLM):
                 for chunk in response:
                     print(chunk, end='')
         """
-        if stop:
-            if self.params:
-                self.params.update({"stop_sequences": stop})
-            else:
-                self.params = {"stop_sequences": stop}
+        params = self._get_chat_params(stop=stop)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, params=self.params
+            prompt=prompt, raw_response=True, params=params
        ):
             chunk = self._stream_response_to_generation_chunk(stream_resp)
             yield chunk
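The same per-call handling applies to streaming. A minimal streaming sketch (not part of this commit), reusing the hypothetical `watsonx_llm` instance from the sketch above; the stop value is illustrative:

```python
# The stop sequence applies to this call only; it is not persisted on
# watsonx_llm.params, so later streams are unaffected.
for chunk in watsonx_llm.stream("Describe what a molecule is.", stop=["\n\n"]):
    print(chunk, end="")
```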