mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
chore(hf-text-gen): extract default params for reuse (#7929)
This PR extracts the common code (default generation params) of `HuggingFaceTextGenInference` into a `_default_params` property so it can be reused. Co-authored-by: Junlin Zhou <jlzhou@zjuici.com>
This commit is contained in:
parent
54e02e4392
commit
812a1643db
@ -36,6 +36,8 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
- _call: Generates text based on a given prompt and stop sequences.
|
- _call: Generates text based on a given prompt and stop sequences.
|
||||||
- _acall: Async generates text based on a given prompt and stop sequences.
|
- _acall: Async generates text based on a given prompt and stop sequences.
|
||||||
- _llm_type: Returns the type of LLM.
|
- _llm_type: Returns the type of LLM.
|
||||||
|
- _default_params: Returns the default parameters for calling text generation
|
||||||
|
inference API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -123,6 +125,21 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "huggingface_textgen_inference"
|
return "huggingface_textgen_inference"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling text generation inference API."""
|
||||||
|
return {
|
||||||
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"typical_p": self.typical_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"repetition_penalty": self.repetition_penalty,
|
||||||
|
"truncate": self.truncate,
|
||||||
|
"stop_sequences": self.stop_sequences,
|
||||||
|
"seed": self.seed,
|
||||||
|
}
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -138,15 +155,8 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
if not self.stream:
|
if not self.stream:
|
||||||
res = self.client.generate(
|
res = self.client.generate(
|
||||||
prompt,
|
prompt,
|
||||||
|
**self._default_params,
|
||||||
stop_sequences=stop,
|
stop_sequences=stop,
|
||||||
max_new_tokens=self.max_new_tokens,
|
|
||||||
top_k=self.top_k,
|
|
||||||
top_p=self.top_p,
|
|
||||||
typical_p=self.typical_p,
|
|
||||||
temperature=self.temperature,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
|
||||||
truncate=self.truncate,
|
|
||||||
seed=self.seed,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# remove stop sequences from the end of the generated text
|
# remove stop sequences from the end of the generated text
|
||||||
@ -163,15 +173,9 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
run_manager.on_llm_new_token, verbose=self.verbose
|
run_manager.on_llm_new_token, verbose=self.verbose
|
||||||
)
|
)
|
||||||
params = {
|
params = {
|
||||||
|
**self._default_params,
|
||||||
"stop_sequences": stop,
|
"stop_sequences": stop,
|
||||||
"max_new_tokens": self.max_new_tokens,
|
**kwargs,
|
||||||
"top_k": self.top_k,
|
|
||||||
"top_p": self.top_p,
|
|
||||||
"typical_p": self.typical_p,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"repetition_penalty": self.repetition_penalty,
|
|
||||||
"truncate": self.truncate,
|
|
||||||
"seed": self.seed,
|
|
||||||
}
|
}
|
||||||
text = ""
|
text = ""
|
||||||
for res in self.client.generate_stream(prompt, **params):
|
for res in self.client.generate_stream(prompt, **params):
|
||||||
@ -204,15 +208,8 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
if not self.stream:
|
if not self.stream:
|
||||||
res = await self.async_client.generate(
|
res = await self.async_client.generate(
|
||||||
prompt,
|
prompt,
|
||||||
|
**self._default_params,
|
||||||
stop_sequences=stop,
|
stop_sequences=stop,
|
||||||
max_new_tokens=self.max_new_tokens,
|
|
||||||
top_k=self.top_k,
|
|
||||||
top_p=self.top_p,
|
|
||||||
typical_p=self.typical_p,
|
|
||||||
temperature=self.temperature,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
|
||||||
truncate=self.truncate,
|
|
||||||
seed=self.seed,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# remove stop sequences from the end of the generated text
|
# remove stop sequences from the end of the generated text
|
||||||
@ -229,17 +226,8 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
run_manager.on_llm_new_token, verbose=self.verbose
|
run_manager.on_llm_new_token, verbose=self.verbose
|
||||||
)
|
)
|
||||||
params = {
|
params = {
|
||||||
**{
|
**self._default_params,
|
||||||
"stop_sequences": stop,
|
"stop_sequences": stop,
|
||||||
"max_new_tokens": self.max_new_tokens,
|
|
||||||
"top_k": self.top_k,
|
|
||||||
"top_p": self.top_p,
|
|
||||||
"typical_p": self.typical_p,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"repetition_penalty": self.repetition_penalty,
|
|
||||||
"truncate": self.truncate,
|
|
||||||
"seed": self.seed,
|
|
||||||
},
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
text = ""
|
text = ""
|
||||||
|
Loading…
Reference in New Issue
Block a user