community[patch]: Fix generation_config not setting properly for DeepSparse (#15036)
- **Description:** Tiny but important bugfix: use a more stable interface for specifying `generation_config` parameters for the DeepSparse LLM.
Commit: 501cc8311d (parent: 2460f977c5)
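In short, the patch stops forwarding the user's `generation_config` dict under a single `generation_config=` keyword and instead unpacks it into individual keyword arguments of the pipeline call. A minimal sketch of the two call styles, assuming the `deepsparse` text-generation `Pipeline` API; the model stub is a placeholder, not a value from this commit:

```python
from deepsparse import Pipeline

# Placeholder SparseZoo stub; substitute a real text-generation model.
pipeline = Pipeline.create(task="text_generation", model_path="zoo:some/text_generation/stub")
prompt = "Tell me a joke."
generation_config = {"max_new_tokens": 128}

# Before: the whole dict was forwarded under one keyword.
# output = pipeline(sequences=prompt, generation_config=generation_config)

# After: each setting is unpacked as a plain keyword argument,
# the more stable interface across deepsparse releases.
output = pipeline(sequences=prompt, **generation_config)
print(output.generations[0].text)
```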
@@ -63,7 +63,7 @@ class DeepSparse(LLM):
         except ImportError:
             raise ImportError(
                 "Could not import `deepsparse` package. "
-                "Please install it with `pip install deepsparse`"
+                "Please install it with `pip install deepsparse[llm]`"
             )
 
         model_config = values["model_config"] or {}
@@ -103,9 +103,7 @@ class DeepSparse(LLM):
             text = combined_output
         else:
             text = (
-                self.pipeline(
-                    sequences=prompt, generation_config=self.generation_config
-                )
+                self.pipeline(sequences=prompt, **self.generation_config)
                 .generations[0]
                 .text
             )
@@ -143,9 +141,7 @@ class DeepSparse(LLM):
             text = combined_output
         else:
             text = (
-                self.pipeline(
-                    sequences=prompt, generation_config=self.generation_config
-                )
+                self.pipeline(sequences=prompt, **self.generation_config)
                 .generations[0]
                 .text
             )
@@ -184,7 +180,7 @@ class DeepSparse(LLM):
                 print(chunk, end='', flush=True)
         """
         inference = self.pipeline(
-            sequences=prompt, generation_config=self.generation_config, streaming=True
+            sequences=prompt, streaming=True, **self.generation_config
         )
         for token in inference:
             chunk = GenerationChunk(text=token.generations[0].text)
@@ -222,7 +218,7 @@ class DeepSparse(LLM):
                 print(chunk, end='', flush=True)
         """
         inference = self.pipeline(
-            sequences=prompt, generation_config=self.generation_config, streaming=True
+            sequences=prompt, streaming=True, **self.generation_config
        )
         for token in inference:
             chunk = GenerationChunk(text=token.generations[0].text)
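For reference, a short sketch of how this looks from the LangChain side after the patch; the model stub and generation parameter are placeholders, not values from this commit:

```python
from langchain_community.llms import DeepSparse

# Placeholder SparseZoo stub; substitute a real text-generation model.
llm = DeepSparse(
    model="zoo:some/text_generation/stub",
    generation_config={"max_new_tokens": 128},
)

# Each entry in generation_config is now unpacked into the pipeline call
# (e.g. max_new_tokens=128) rather than passed as generation_config=...
print(llm.invoke("Tell me a joke."))
```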