diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index 012dd29a5c7..101f1eb78d1 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -36,6 +36,8 @@ class HuggingFaceTextGenInference(LLM):
         - _call: Generates text based on a given prompt and stop sequences.
         - _acall: Async generates text based on a given prompt and stop sequences.
         - _llm_type: Returns the type of LLM.
+        - _default_params: Returns the default parameters for calling text generation
+          inference API.
     """
 
     """
@@ -123,6 +125,21 @@ class HuggingFaceTextGenInference(LLM):
         """Return type of llm."""
         return "huggingface_textgen_inference"
 
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling text generation inference API."""
+        return {
+            # stop_sequences is omitted here; call sites pass it explicitly.
+            "max_new_tokens": self.max_new_tokens,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "typical_p": self.typical_p,
+            "temperature": self.temperature,
+            "repetition_penalty": self.repetition_penalty,
+            "truncate": self.truncate,
+            "seed": self.seed,
+        }
+
     def _call(
         self,
         prompt: str,
@@ -138,15 +155,8 @@ class HuggingFaceTextGenInference(LLM):
         if not self.stream:
             res = self.client.generate(
                 prompt,
+                **self._default_params,
                 stop_sequences=stop,
-                max_new_tokens=self.max_new_tokens,
-                top_k=self.top_k,
-                top_p=self.top_p,
-                typical_p=self.typical_p,
-                temperature=self.temperature,
-                repetition_penalty=self.repetition_penalty,
-                truncate=self.truncate,
-                seed=self.seed,
                 **kwargs,
             )
             # remove stop sequences from the end of the generated text
@@ -163,15 +173,9 @@ class HuggingFaceTextGenInference(LLM):
                 run_manager.on_llm_new_token, verbose=self.verbose
             )
             params = {
+                **self._default_params,
                 "stop_sequences": stop,
-                "max_new_tokens": self.max_new_tokens,
-                "top_k": self.top_k,
-                "top_p": self.top_p,
-                "typical_p": self.typical_p,
-                "temperature": self.temperature,
-                "repetition_penalty": self.repetition_penalty,
-                "truncate": self.truncate,
-                "seed": self.seed,
+                **kwargs,
             }
             text = ""
             for res in self.client.generate_stream(prompt, **params):
@@ -204,15 +208,8 @@ class HuggingFaceTextGenInference(LLM):
         if not self.stream:
             res = await self.async_client.generate(
                 prompt,
+                **self._default_params,
                 stop_sequences=stop,
-                max_new_tokens=self.max_new_tokens,
-                top_k=self.top_k,
-                top_p=self.top_p,
-                typical_p=self.typical_p,
-                temperature=self.temperature,
-                repetition_penalty=self.repetition_penalty,
-                truncate=self.truncate,
-                seed=self.seed,
                 **kwargs,
            )
             # remove stop sequences from the end of the generated text
@@ -229,17 +226,8 @@ class HuggingFaceTextGenInference(LLM):
                 run_manager.on_llm_new_token, verbose=self.verbose
            )
             params = {
-                **{
-                    "stop_sequences": stop,
-                    "max_new_tokens": self.max_new_tokens,
-                    "top_k": self.top_k,
-                    "top_p": self.top_p,
-                    "typical_p": self.typical_p,
-                    "temperature": self.temperature,
-                    "repetition_penalty": self.repetition_penalty,
-                    "truncate": self.truncate,
-                    "seed": self.seed,
-                },
+                **self._default_params,
+                "stop_sequences": stop,
                 **kwargs,
             }
             text = ""
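
A note on the merge semantics this refactor leans on, as a minimal self-contained sketch (the `generate` function below is a hypothetical stand-in, not the `text_generation` client): in a dict literal, later entries override earlier ones, so `"stop_sequences": stop` and `**kwargs` placed after `**self._default_params` take precedence over the defaults. In a function call, however, unpacking a dict and also passing one of its keys as an explicit keyword raises a `TypeError`, which is why `stop_sequences` is kept out of `_default_params` while call sites pass `stop_sequences=stop` themselves.

```python
# Dict-literal merge: later entries win, so per-call values placed
# after **defaults override the defaults.
defaults = {"temperature": 0.8, "max_new_tokens": 512}  # illustrative values
params = {**defaults, "temperature": 0.2, "stop_sequences": ["\n\n"]}
assert params["temperature"] == 0.2  # the per-call override took effect

# Keyword unpacking is stricter: supplying the same key both via **
# and as an explicit keyword raises a TypeError at the call site.
def generate(prompt: str, **kwargs: object) -> str:
    """Hypothetical stand-in for a generate-style API."""
    return prompt

try:
    generate("hi", **{"stop_sequences": []}, stop_sequences=["\n\n"])
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'stop_sequences'
```

Putting `**kwargs` last in each merged dict also gives per-call keyword arguments the final say over both the instance defaults and the runtime stop list.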