mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-11-04 02:03:32 +00:00 
			
		
		
		
	feat(llms): add missing params to huggingface text-generation (#9724)
This small PR aims at supporting the following missing parameters in the `HuggingfaceTextGen` LLM: - `return_full_text` - sometimes useful for completion tasks - `do_sample` - quite handy to control the randomness of the model. - `watermark` @hwchase17 @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							491089754d
						
					
				
				
					commit
					a7c9bd30d4
				
			@@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM):
 | 
				
			|||||||
    typical_p: Optional[float] = 0.95
 | 
					    typical_p: Optional[float] = 0.95
 | 
				
			||||||
    temperature: float = 0.8
 | 
					    temperature: float = 0.8
 | 
				
			||||||
    repetition_penalty: Optional[float] = None
 | 
					    repetition_penalty: Optional[float] = None
 | 
				
			||||||
 | 
					    return_full_text: bool = False
 | 
				
			||||||
    truncate: Optional[int] = None
 | 
					    truncate: Optional[int] = None
 | 
				
			||||||
    stop_sequences: List[str] = Field(default_factory=list)
 | 
					    stop_sequences: List[str] = Field(default_factory=list)
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
@@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM):
 | 
				
			|||||||
    timeout: int = 120
 | 
					    timeout: int = 120
 | 
				
			||||||
    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
 | 
					    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
 | 
				
			||||||
    streaming: bool = False
 | 
					    streaming: bool = False
 | 
				
			||||||
 | 
					    do_sample: bool = False
 | 
				
			||||||
 | 
					    watermark: bool = False
 | 
				
			||||||
    client: Any
 | 
					    client: Any
 | 
				
			||||||
    async_client: Any
 | 
					    async_client: Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM):
 | 
				
			|||||||
            "typical_p": self.typical_p,
 | 
					            "typical_p": self.typical_p,
 | 
				
			||||||
            "temperature": self.temperature,
 | 
					            "temperature": self.temperature,
 | 
				
			||||||
            "repetition_penalty": self.repetition_penalty,
 | 
					            "repetition_penalty": self.repetition_penalty,
 | 
				
			||||||
 | 
					            "return_full_text": self.return_full_text,
 | 
				
			||||||
            "truncate": self.truncate,
 | 
					            "truncate": self.truncate,
 | 
				
			||||||
            "stop_sequences": self.stop_sequences,
 | 
					            "stop_sequences": self.stop_sequences,
 | 
				
			||||||
            "seed": self.seed,
 | 
					            "seed": self.seed,
 | 
				
			||||||
 | 
					            "do_sample": self.do_sample,
 | 
				
			||||||
 | 
					            "watermark": self.watermark,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _invocation_params(
 | 
					    def _invocation_params(
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user