community[patch]: Fix generation_config not setting properly for DeepSparse (#15036)

- **Description:** Tiny but important bugfix to use a more stable
interface for specifying `generation_config` parameters for the DeepSparse LLM.
This commit is contained in:
Michael Goin 2023-12-21 22:39:22 -08:00 committed by GitHub
parent 2460f977c5
commit 501cc8311d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -63,7 +63,7 @@ class DeepSparse(LLM):
except ImportError:
raise ImportError(
"Could not import `deepsparse` package. "
-"Please install it with `pip install deepsparse`"
+"Please install it with `pip install deepsparse[llm]`"
)
model_config = values["model_config"] or {}
@@ -103,9 +103,7 @@ class DeepSparse(LLM):
text = combined_output
else:
text = (
-self.pipeline(
-sequences=prompt, generation_config=self.generation_config
-)
+self.pipeline(sequences=prompt, **self.generation_config)
.generations[0]
.text
)
@@ -143,9 +141,7 @@ class DeepSparse(LLM):
text = combined_output
else:
text = (
-self.pipeline(
-sequences=prompt, generation_config=self.generation_config
-)
+self.pipeline(sequences=prompt, **self.generation_config)
.generations[0]
.text
)
@@ -184,7 +180,7 @@ class DeepSparse(LLM):
print(chunk, end='', flush=True)
"""
inference = self.pipeline(
-sequences=prompt, generation_config=self.generation_config, streaming=True
+sequences=prompt, streaming=True, **self.generation_config
)
for token in inference:
chunk = GenerationChunk(text=token.generations[0].text)
@@ -222,7 +218,7 @@ class DeepSparse(LLM):
print(chunk, end='', flush=True)
"""
inference = self.pipeline(
-sequences=prompt, generation_config=self.generation_config, streaming=True
+sequences=prompt, streaming=True, **self.generation_config
)
for token in inference:
chunk = GenerationChunk(text=token.generations[0].text)