Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-28 09:28:48 +00:00
HuggingFaceTextGenInference bug fix: Multiple values for keyword argument (#8044)
Fixed the bug causing `TypeError: generate() got multiple values for keyword argument 'stop_sequences'`:

```python
res = await self.async_client.generate(
    prompt,
    **self._default_params,
    stop_sequences=stop,
    **kwargs,
)
```

The call above throws because `stop_sequences` is also a key in `self._default_params`, so the same keyword argument reaches `generate()` twice.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent ed6a5532ac
commit ebc5ff2948
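For context, the failure is ordinary Python keyword-argument semantics rather than anything specific to the text-generation client. A minimal self-contained sketch (the `generate` function here is a simplified stand-in, not the real client method):

```python
def generate(prompt, stop_sequences=None, **kwargs):
    # Stand-in with a simplified signature; the real method lives on the
    # text_generation Client / AsyncClient that the LLM class wraps.
    return prompt, stop_sequences


default_params = {"stop_sequences": [], "temperature": 0.8}

try:
    # 'stop_sequences' arrives twice: once via **default_params and
    # once as an explicit keyword argument.
    generate("hi", **default_params, stop_sequences=["\n"])
except TypeError as err:
    print(err)  # generate() got multiple values for keyword argument 'stop_sequences'
```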
```diff
@@ -140,6 +140,13 @@ class HuggingFaceTextGenInference(LLM):
             "seed": self.seed,
         }
 
+    def _invocation_params(
+        self, runtime_stop: Optional[List[str]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        params = {**self._default_params, **kwargs}
+        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
+        return params
+
     def _call(
         self,
         prompt: str,
@@ -147,20 +154,11 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
-            res = self.client.generate(
-                prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
-            )
+            res = self.client.generate(prompt, **invocation_params)
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -172,16 +170,11 @@ class HuggingFaceTextGenInference(LLM):
             text_callback = partial(
                 run_manager.on_llm_new_token, verbose=self.verbose
             )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            for res in self.client.generate_stream(prompt, **params):
+            for res in self.client.generate_stream(prompt, **invocation_params):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
@@ -200,20 +193,14 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
             res = await self.async_client.generate(
                 prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
+                **invocation_params,
             )
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -225,16 +212,13 @@ class HuggingFaceTextGenInference(LLM):
             text_callback = partial(
                 run_manager.on_llm_new_token, verbose=self.verbose
             )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            async for res in self.async_client.generate_stream(prompt, **params):
+            async for res in self.async_client.generate_stream(
+                prompt, **invocation_params
+            ):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
```
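The fix builds a single dict before the call, so every keyword reaches the client exactly once: in a dict merge, later keys override earlier ones instead of colliding. A minimal sketch of the same pattern, independent of langchain (names mirror the diff but are otherwise illustrative):

```python
from typing import Any, Dict, List, Optional


def invocation_params(
    default_params: Dict[str, Any],
    runtime_stop: Optional[List[str]],
    **kwargs: Any,
) -> Dict[str, Any]:
    # Later keys win in a dict merge, so kwargs can override defaults
    # without ever producing a duplicate keyword argument.
    params = {**default_params, **kwargs}
    # Concatenating into a fresh list appends the runtime stop sequences
    # without mutating default_params["stop_sequences"].
    params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
    return params


defaults = {"stop_sequences": ["."], "temperature": 0.8}
params = invocation_params(defaults, ["stop"])
assert params["stop_sequences"] == [".", "stop"]
assert defaults["stop_sequences"] == ["."]  # defaults untouched
```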
```diff
@@ -0,0 +1,19 @@
+from langchain import HuggingFaceTextGenInference
+
+
+def test_invocation_params_stop_sequences() -> None:
+    llm = HuggingFaceTextGenInference()
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = None
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
+    assert llm._default_params["stop_sequences"] == []
+
+    llm = HuggingFaceTextGenInference(stop_sequences=["."])
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
+    assert llm._default_params["stop_sequences"] == ["."]
```
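With the fix in place, stop values passed at call time merge with the configured ones instead of colliding. A usage sketch, assuming a running text-generation-inference server (the endpoint URL here is illustrative):

```python
from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8080/",  # hypothetical local TGI endpoint
    stop_sequences=["."],
)

# 'stop' is appended to the configured stop_sequences; before the fix,
# this call path raised the duplicate-keyword TypeError.
print(llm("Tell me a joke", stop=["\n"]))
```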