mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-22 15:38:06 +00:00
feat(llms): add missing params to huggingface text-generation (#9724)
This small PR aims at supporting the following missing parameters in the `HuggingfaceTextGen` LLM: - `return_full_text` - sometimes useful for completion tasks - `do_sample` - quite handy to control the randomness of the model. - `watermark` @hwchase17 @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
491089754d
commit
a7c9bd30d4
@ -80,6 +80,7 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
typical_p: Optional[float] = 0.95
|
typical_p: Optional[float] = 0.95
|
||||||
temperature: float = 0.8
|
temperature: float = 0.8
|
||||||
repetition_penalty: Optional[float] = None
|
repetition_penalty: Optional[float] = None
|
||||||
|
return_full_text: bool = False
|
||||||
truncate: Optional[int] = None
|
truncate: Optional[int] = None
|
||||||
stop_sequences: List[str] = Field(default_factory=list)
|
stop_sequences: List[str] = Field(default_factory=list)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
@ -87,6 +88,8 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
timeout: int = 120
|
timeout: int = 120
|
||||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
do_sample: bool = False
|
||||||
|
watermark: bool = False
|
||||||
client: Any
|
client: Any
|
||||||
async_client: Any
|
async_client: Any
|
||||||
|
|
||||||
@ -134,9 +137,12 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
"typical_p": self.typical_p,
|
"typical_p": self.typical_p,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"repetition_penalty": self.repetition_penalty,
|
"repetition_penalty": self.repetition_penalty,
|
||||||
|
"return_full_text": self.return_full_text,
|
||||||
"truncate": self.truncate,
|
"truncate": self.truncate,
|
||||||
"stop_sequences": self.stop_sequences,
|
"stop_sequences": self.stop_sequences,
|
||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
|
"do_sample": self.do_sample,
|
||||||
|
"watermark": self.watermark,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _invocation_params(
|
def _invocation_params(
|
||||||
|
Loading…
Reference in New Issue
Block a user