diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index 101f1eb78d1..293d463add3 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -140,6 +140,13 @@ class HuggingFaceTextGenInference(LLM):
             "seed": self.seed,
         }
 
+    def _invocation_params(
+        self, runtime_stop: Optional[List[str]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        params = {**self._default_params, **kwargs}
+        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
+        return params
+
     def _call(
         self,
         prompt: str,
@@ -147,20 +154,11 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
-            res = self.client.generate(
-                prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
-            )
+            res = self.client.generate(prompt, **invocation_params)
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -172,16 +170,11 @@ class HuggingFaceTextGenInference(LLM):
                 text_callback = partial(
                     run_manager.on_llm_new_token, verbose=self.verbose
                 )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            for res in self.client.generate_stream(prompt, **params):
+            for res in self.client.generate_stream(prompt, **invocation_params):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
@@ -200,20 +193,14 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
             res = await self.async_client.generate(
                 prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
+                **invocation_params,
             )
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -225,16 +212,13 @@ class HuggingFaceTextGenInference(LLM):
                 text_callback = partial(
                     run_manager.on_llm_new_token, verbose=self.verbose
                 )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            async for res in self.async_client.generate_stream(prompt, **params):
+            async for res in self.async_client.generate_stream(
+                prompt, **invocation_params
+            ):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
diff --git a/tests/integration_tests/llms/test_huggingface_text_gen_inference.py b/tests/integration_tests/llms/test_huggingface_text_gen_inference.py
new file mode 100644
index 00000000000..46e63383a4b
--- /dev/null
+++ b/tests/integration_tests/llms/test_huggingface_text_gen_inference.py
@@ -0,0 +1,19 @@
+from langchain import HuggingFaceTextGenInference
+
+
+def test_invocation_params_stop_sequences() -> None:
+    llm = HuggingFaceTextGenInference()
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = None
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
+    assert llm._default_params["stop_sequences"] == []
+
+    llm = HuggingFaceTextGenInference(stop_sequences=["."])
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
+    assert llm._default_params["stop_sequences"] == ["."]
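
A minimal usage sketch of the merge behavior exercised above (an illustration, not part of the patch; it assumes the `text_generation` client package is installed, since the constructor builds a client, but `_invocation_params` itself makes no network call):

    from langchain import HuggingFaceTextGenInference

    llm = HuggingFaceTextGenInference(stop_sequences=["."])

    # Per-call stop sequences are appended to the configured defaults, and
    # extra kwargs override the corresponding default client parameters.
    params = llm._invocation_params(runtime_stop=["stop"], temperature=0.5)
    assert params["stop_sequences"] == [".", "stop"]
    assert params["temperature"] == 0.5

    # Unlike the old `stop += self.stop_sequences` path, the instance defaults
    # are left unmutated, which is what the new test asserts.
    assert llm._default_params["stop_sequences"] == ["."]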