community: add runtime kwargs to HuggingFacePipeline (#17005)
This PR makes it possible to change the behaviour of the HuggingFace pipeline between calls. For example, before this PR there was no way to change the maximum generation length between invocations of the chain. This is desirable in cases such as when the maximum output size should scale with the size of a dynamic prompt. Usage example:

```python
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
hf = HuggingFacePipeline(pipeline=pipe)

hf("Say foo:", pipeline_kwargs={"max_new_tokens": 42})
```

Co-authored-by: Bagatur <baskaryan@gmail.com>
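As an additional illustration of the dynamic-prompt-size use case mentioned above, here is a minimal sketch that budgets `max_new_tokens` relative to the prompt length. The `generate_scaled` helper and the `scale` factor are illustrative, not part of this PR; only the `pipeline_kwargs` argument comes from the change itself.

```python
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
hf = HuggingFacePipeline(pipeline=pipe)


def generate_scaled(prompt: str, scale: int = 2) -> str:
    # Illustrative helper: size the output budget relative to the prompt length.
    prompt_tokens = len(tokenizer.encode(prompt))
    return hf(prompt, pipeline_kwargs={"max_new_tokens": prompt_tokens * scale})


print(generate_scaled("Say foo:"))
print(generate_scaled("Write a longer answer about foo, bar, and baz:"))
```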
This commit is contained in:

Parent: a32798abd7
Commit: 641efcf41c
```diff
@@ -195,12 +195,13 @@ class HuggingFacePipeline(BaseLLM):
     ) -> LLMResult:
         # List to hold all results
         text_generations: List[str] = []
+        pipeline_kwargs = kwargs.get("pipeline_kwargs", {})

         for i in range(0, len(prompts), self.batch_size):
             batch_prompts = prompts[i : i + self.batch_size]

             # Process batch of prompts
-            responses = self.pipeline(batch_prompts)
+            responses = self.pipeline(batch_prompts, **pipeline_kwargs)

             # Process each response in the batch
             for j, response in enumerate(responses):
```
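The hunk above adds the plumbing: per-call `pipeline_kwargs` are read from `**kwargs` and forwarded to the underlying `transformers` pipeline for every batch, so each invocation can override generation settings. Below is a standalone sketch of the same forwarding pattern, independent of LangChain; the `generate` helper and the `batch_size` default are illustrative.

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")


def generate(prompts, batch_size=4, **kwargs):
    # Per-call kwargs are pulled out of **kwargs, mirroring the hunk above.
    pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
    texts = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i : i + batch_size]
        # Runtime kwargs (e.g. max_new_tokens) are applied to every batch.
        responses = pipe(batch, **pipeline_kwargs)
        for response in responses:
            # The text-generation task returns a list of dicts per prompt.
            texts.append(response[0]["generated_text"])
    return texts


print(generate(["Say foo:"], pipeline_kwargs={"max_new_tokens": 8}))
```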
```diff
@@ -69,3 +69,14 @@ def test_init_with_pipeline() -> None:
     llm = HuggingFacePipeline(pipeline=pipe)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_huggingface_pipeline_runtime_kwargs() -> None:
+    """Test pipelines specifying runtime kwargs."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="gpt2",
+        task="text-generation",
+    )
+    prompt = "Say foo:"
+    output = llm(prompt, pipeline_kwargs={"max_new_tokens": 2})
+    assert len(output) < 10
```
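The new test calls the LLM with a tight `max_new_tokens` budget and checks that the output stays short. Below is a hedged sketch of the same behaviour across two calls on a single instance; output lengths vary with the model, so it only prints them rather than asserting.

```python
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

# One instance, two invocations with different generation budgets.
llm = HuggingFacePipeline.from_model_id(model_id="gpt2", task="text-generation")

short = llm("Say foo:", pipeline_kwargs={"max_new_tokens": 2})
longer = llm("Say foo:", pipeline_kwargs={"max_new_tokens": 64})

# A tighter budget typically yields a shorter completion; exact lengths vary.
print(len(short), len(longer))
```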