chore(hf-text-gen): extract default params for reusing (#7929)

This PR extracts the common code (the default generation params) of
`HuggingFaceTextGenInference` into a single `_default_params` property, so the
sync, async, and streaming call paths no longer duplicate the same keyword list.
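
For context, a minimal sketch of the pattern this change applies. The names here (`TextGenSketch`, `build_params`, and the trimmed parameter set) are illustrative stand-ins, not the actual LangChain class:

from typing import Any, Dict, List, Optional

class TextGenSketch:
    # A few representative defaults; the real class carries more fields.
    max_new_tokens: int = 512
    temperature: float = 0.8
    seed: Optional[int] = None

    @property
    def _default_params(self) -> Dict[str, Any]:
        # Single source of truth for the defaults shared by every call path.
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "seed": self.seed,
        }

    def build_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        # Later keys override earlier ones in a dict literal, so the merge
        # order is: defaults < explicit stop sequences < caller kwargs.
        return {**self._default_params, "stop_sequences": stop, **kwargs}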

Co-authored-by: Junlin Zhou <jlzhou@zjuici.com>
Commit: 812a1643db (parent: 54e02e4392)
Author: Junlin Zhou
Date: 2023-07-20 21:49:12 +08:00 (committed by GitHub)


@@ -36,6 +36,8 @@ class HuggingFaceTextGenInference(LLM):
     - _call: Generates text based on a given prompt and stop sequences.
     - _acall: Async generates text based on a given prompt and stop sequences.
     - _llm_type: Returns the type of LLM.
+    - _default_params: Returns the default parameters for calling text generation
+      inference API.
     """

     """
@@ -123,6 +125,21 @@ class HuggingFaceTextGenInference(LLM):
         """Return type of llm."""
         return "huggingface_textgen_inference"

+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling text generation inference API."""
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "typical_p": self.typical_p,
+            "temperature": self.temperature,
+            "repetition_penalty": self.repetition_penalty,
+            "truncate": self.truncate,
+            "stop_sequences": self.stop_sequences,
+            "seed": self.seed,
+        }
+
     def _call(
         self,
         prompt: str,
@@ -138,15 +155,8 @@ class HuggingFaceTextGenInference(LLM):
         if not self.stream:
             res = self.client.generate(
                 prompt,
+                **self._default_params,
                 stop_sequences=stop,
-                max_new_tokens=self.max_new_tokens,
-                top_k=self.top_k,
-                top_p=self.top_p,
-                typical_p=self.typical_p,
-                temperature=self.temperature,
-                repetition_penalty=self.repetition_penalty,
-                truncate=self.truncate,
-                seed=self.seed,
                 **kwargs,
             )
             # remove stop sequences from the end of the generated text
@@ -163,15 +173,9 @@
                 run_manager.on_llm_new_token, verbose=self.verbose
             )
             params = {
+                **self._default_params,
                 "stop_sequences": stop,
-                "max_new_tokens": self.max_new_tokens,
-                "top_k": self.top_k,
-                "top_p": self.top_p,
-                "typical_p": self.typical_p,
-                "temperature": self.temperature,
-                "repetition_penalty": self.repetition_penalty,
-                "truncate": self.truncate,
-                "seed": self.seed,
+                **kwargs,
             }
             text = ""
             for res in self.client.generate_stream(prompt, **params):
@@ -204,15 +208,8 @@ class HuggingFaceTextGenInference(LLM):
         if not self.stream:
             res = await self.async_client.generate(
                 prompt,
+                **self._default_params,
                 stop_sequences=stop,
-                max_new_tokens=self.max_new_tokens,
-                top_k=self.top_k,
-                top_p=self.top_p,
-                typical_p=self.typical_p,
-                temperature=self.temperature,
-                repetition_penalty=self.repetition_penalty,
-                truncate=self.truncate,
-                seed=self.seed,
                 **kwargs,
             )
             # remove stop sequences from the end of the generated text
@@ -229,17 +226,8 @@ class HuggingFaceTextGenInference(LLM):
                 run_manager.on_llm_new_token, verbose=self.verbose
             )
             params = {
-                **{
-                    "stop_sequences": stop,
-                    "max_new_tokens": self.max_new_tokens,
-                    "top_k": self.top_k,
-                    "top_p": self.top_p,
-                    "typical_p": self.typical_p,
-                    "temperature": self.temperature,
-                    "repetition_penalty": self.repetition_penalty,
-                    "truncate": self.truncate,
-                    "seed": self.seed,
-                },
+                **self._default_params,
+                "stop_sequences": stop,
                 **kwargs,
             }
             text = ""