mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 06:33:20 +00:00
Added pipeline args to HuggingFacePipeline.from_model_id
(#5268)
The current `HuggingFacePipeline.from_model_id` does not allow passing of pipeline arguments to the transformer pipeline. This PR enables adding important pipeline parameters like setting `max_new_tokens` for example. Prior to this PR it would be necessary to manually create the pipeline through huggingface transformers then hand it to langchain. For example instead of this ```py model_id = "gpt2" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id) pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10 ) hf = HuggingFacePipeline(pipeline=pipe) ``` You can write this ```py hf = HuggingFacePipeline.from_model_id( model_id="gpt2", task="text-generation", pipeline_kwargs={"max_new_tokens": 10} ) ``` Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
f01dfe858d
commit
2ef5579eae
@ -28,7 +28,9 @@ class HuggingFacePipeline(LLM):
|
|||||||
|
|
||||||
from langchain.llms import HuggingFacePipeline
|
from langchain.llms import HuggingFacePipeline
|
||||||
hf = HuggingFacePipeline.from_model_id(
|
hf = HuggingFacePipeline.from_model_id(
|
||||||
model_id="gpt2", task="text-generation"
|
model_id="gpt2",
|
||||||
|
task="text-generation",
|
||||||
|
pipeline_kwargs={"max_new_tokens": 10},
|
||||||
)
|
)
|
||||||
Example passing pipeline in directly:
|
Example passing pipeline in directly:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -49,7 +51,9 @@ class HuggingFacePipeline(LLM):
|
|||||||
model_id: str = DEFAULT_MODEL_ID
|
model_id: str = DEFAULT_MODEL_ID
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
model_kwargs: Optional[dict] = None
|
model_kwargs: Optional[dict] = None
|
||||||
"""Key word arguments to pass to the model."""
|
"""Key word arguments passed to the model."""
|
||||||
|
pipeline_kwargs: Optional[dict] = None
|
||||||
|
"""Key word arguments passed to the pipeline."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -63,6 +67,7 @@ class HuggingFacePipeline(LLM):
|
|||||||
task: str,
|
task: str,
|
||||||
device: int = -1,
|
device: int = -1,
|
||||||
model_kwargs: Optional[dict] = None,
|
model_kwargs: Optional[dict] = None,
|
||||||
|
pipeline_kwargs: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLM:
|
) -> LLM:
|
||||||
"""Construct the pipeline object from model_id and task."""
|
"""Construct the pipeline object from model_id and task."""
|
||||||
@ -119,12 +124,14 @@ class HuggingFacePipeline(LLM):
|
|||||||
_model_kwargs = {
|
_model_kwargs = {
|
||||||
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
|
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
|
||||||
}
|
}
|
||||||
|
_pipeline_kwargs = pipeline_kwargs or {}
|
||||||
pipeline = hf_pipeline(
|
pipeline = hf_pipeline(
|
||||||
task=task,
|
task=task,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
device=device,
|
device=device,
|
||||||
model_kwargs=_model_kwargs,
|
model_kwargs=_model_kwargs,
|
||||||
|
**_pipeline_kwargs,
|
||||||
)
|
)
|
||||||
if pipeline.task not in VALID_TASKS:
|
if pipeline.task not in VALID_TASKS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -135,6 +142,7 @@ class HuggingFacePipeline(LLM):
|
|||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model_kwargs=_model_kwargs,
|
model_kwargs=_model_kwargs,
|
||||||
|
pipeline_kwargs=_pipeline_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -142,8 +150,9 @@ class HuggingFacePipeline(LLM):
|
|||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {
|
return {
|
||||||
**{"model_id": self.model_id},
|
"model_id": self.model_id,
|
||||||
**{"model_kwargs": self.model_kwargs},
|
"model_kwargs": self.model_kwargs,
|
||||||
|
"pipeline_kwargs": self.pipeline_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user