From a7c9bd30d47de9e681aff008652cc2ff818102cc Mon Sep 17 00:00:00 2001
From: Massimiliano Pronesti
Date: Fri, 1 Sep 2023 22:16:27 +0200
Subject: [PATCH] feat(llms): add missing params to huggingface text-generation
 (#9724)

This small PR aims at supporting the following missing parameters in the
`HuggingFaceTextGenInference` LLM:
- `return_full_text` - sometimes useful for completion tasks
- `do_sample` - quite handy to control the randomness of the model.
- `watermark`

@hwchase17 @baskaryan

---------

Co-authored-by: Bagatur
---
 .../langchain/llms/huggingface_text_gen_inference.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py
index 8e23d935f85..284890579b8 100644
--- a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py
+++ b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py
@@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM):
     typical_p: Optional[float] = 0.95
     temperature: float = 0.8
     repetition_penalty: Optional[float] = None
+    return_full_text: bool = False
     truncate: Optional[int] = None
     stop_sequences: List[str] = Field(default_factory=list)
     seed: Optional[int] = None
@@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM):
     timeout: int = 120
     server_kwargs: Dict[str, Any] = Field(default_factory=dict)
     streaming: bool = False
+    do_sample: bool = False
+    watermark: bool = False
     client: Any
     async_client: Any
 
@@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM):
             "typical_p": self.typical_p,
             "temperature": self.temperature,
             "repetition_penalty": self.repetition_penalty,
+            "return_full_text": self.return_full_text,
             "truncate": self.truncate,
             "stop_sequences": self.stop_sequences,
             "seed": self.seed,
+            "do_sample": self.do_sample,
+            "watermark": self.watermark,
         }
 
     def _invocation_params(