Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-01 20:49:17 +00:00
Added pipeline args to HuggingFacePipeline.from_model_id (#5268)
The current `HuggingFacePipeline.from_model_id` does not allow passing pipeline arguments through to the underlying transformers pipeline. This PR makes it possible to set important pipeline parameters such as `max_new_tokens`. Prior to this PR, you had to create the pipeline manually through Hugging Face transformers and then hand it to LangChain. Instead of this:

```py
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
)
hf = HuggingFacePipeline(pipeline=pipe)
```

you can now write this:

```py
hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
)
```

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
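To round out the example, here is a minimal sketch of invoking the resulting LLM; the prompt string and the `print` call are illustrative and not part of the PR:

```py
from langchain.llms import HuggingFacePipeline

# Generation settings flow straight through to the transformers pipeline.
hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
)

# LLM objects are callable with a prompt; the completion is capped at
# 10 new tokens by pipeline_kwargs.
print(hf("Once upon a time"))
```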
This commit is contained in:
parent
f01dfe858d
commit
2ef5579eae
```diff
@@ -28,7 +28,9 @@ class HuggingFacePipeline(LLM):
 
             from langchain.llms import HuggingFacePipeline
             hf = HuggingFacePipeline.from_model_id(
-                model_id="gpt2", task="text-generation"
+                model_id="gpt2",
+                task="text-generation",
+                pipeline_kwargs={"max_new_tokens": 10},
             )
     Example passing pipeline in directly:
         .. code-block:: python
```
```diff
@@ -49,7 +51,9 @@ class HuggingFacePipeline(LLM):
     model_id: str = DEFAULT_MODEL_ID
     """Model name to use."""
     model_kwargs: Optional[dict] = None
-    """Key word arguments to pass to the model."""
+    """Key word arguments passed to the model."""
+    pipeline_kwargs: Optional[dict] = None
+    """Key word arguments passed to the pipeline."""
 
     class Config:
         """Configuration for this pydantic object."""
```
```diff
@@ -63,6 +67,7 @@ class HuggingFacePipeline(LLM):
         task: str,
         device: int = -1,
         model_kwargs: Optional[dict] = None,
+        pipeline_kwargs: Optional[dict] = None,
         **kwargs: Any,
     ) -> LLM:
         """Construct the pipeline object from model_id and task."""
```
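With the extra parameter in place, `from_model_id` separates arguments meant for model loading (`model_kwargs`) from arguments meant for the pipeline itself (`pipeline_kwargs`). A sketch of a call exercising both; the `cache_dir` value is illustrative, not from the PR:

```py
from langchain.llms import HuggingFacePipeline

hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    device=-1,                                # CPU, the default
    model_kwargs={"cache_dir": "./hf-models"},  # goes to from_pretrained
    pipeline_kwargs={"max_new_tokens": 10},     # goes to the pipeline call
)
```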
```diff
@@ -119,12 +124,14 @@ class HuggingFacePipeline(LLM):
             _model_kwargs = {
                 k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
             }
+        _pipeline_kwargs = pipeline_kwargs or {}
         pipeline = hf_pipeline(
             task=task,
             model=model,
             tokenizer=tokenizer,
             device=device,
             model_kwargs=_model_kwargs,
+            **_pipeline_kwargs,
         )
         if pipeline.task not in VALID_TASKS:
             raise ValueError(
```
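The `**_pipeline_kwargs` expansion means the extra arguments land directly on `transformers.pipeline`. Roughly equivalent manual code, as a sketch (the causal-LM loading path shown here is just one of the task branches):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline as hf_pipeline

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# pipeline_kwargs={"max_new_tokens": 10} expands to max_new_tokens=10 here.
pipe = hf_pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    max_new_tokens=10,
)
```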
```diff
@@ -135,6 +142,7 @@ class HuggingFacePipeline(LLM):
             pipeline=pipeline,
             model_id=model_id,
             model_kwargs=_model_kwargs,
+            pipeline_kwargs=_pipeline_kwargs,
             **kwargs,
         )
 
```
```diff
@@ -142,8 +150,9 @@ class HuggingFacePipeline(LLM):
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         return {
-            **{"model_id": self.model_id},
-            **{"model_kwargs": self.model_kwargs},
+            "model_id": self.model_id,
+            "model_kwargs": self.model_kwargs,
+            "pipeline_kwargs": self.pipeline_kwargs,
         }
 
     @property
```
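Since `pipeline_kwargs` is now part of `_identifying_params`, two instances built with different generation settings no longer look identical, which matters for caching and logging. A quick sketch of checking this; the exact dict contents depend on defaults:

```py
hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
)

# Should now report pipeline_kwargs alongside model_id and model_kwargs, e.g.
# {'model_id': 'gpt2', 'model_kwargs': {}, 'pipeline_kwargs': {'max_new_tokens': 10}}
print(hf._identifying_params)
```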