Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-03 03:38:06 +00:00)
Harrison/version 0040 (#366)
parent 50257fce59
commit a7084ad6e4
@@ -61,24 +61,36 @@ class HuggingFacePipeline(LLM, BaseModel):
     ) -> LLM:
         """Construct the pipeline object from model_id and task."""
         try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoModelForSeq2SeqLM,
+                AutoTokenizer,
+            )
             from transformers import pipeline as hf_pipeline
 
             tokenizer = AutoTokenizer.from_pretrained(model_id)
-            model = AutoModelForCausalLM.from_pretrained(model_id)
+            if task == "text-generation":
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+            elif task == "text2text-generation":
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+            else:
+                raise ValueError(
+                    f"Got invalid task {task}, "
+                    f"currently only {VALID_TASKS} are supported"
+                )
+            _model_kwargs = model_kwargs or {}
             pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, **model_kwargs
+                task=task, model=model, tokenizer=tokenizer, **_model_kwargs
             )
             if pipeline.task not in VALID_TASKS:
                 raise ValueError(
                     f"Got invalid task {pipeline.task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
 
             return cls(
                 pipeline=pipeline,
                 model_id=model_id,
-                model_kwargs=model_kwargs,
+                model_kwargs=_model_kwargs,
                 **kwargs,
             )
         except ImportError:
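Worth noting: the switch from **model_kwargs to **_model_kwargs above is a bug fix, not a rename. model_kwargs can arrive as None (hence the new "_model_kwargs = model_kwargs or {}" line), and splatting None with ** raises a TypeError. A self-contained sketch of the pattern, with illustrative names:

from typing import Optional

def build(task: str, model_kwargs: Optional[dict] = None) -> dict:
    # Pre-fix behaviour: dict(task=task, **model_kwargs) fails with
    # "TypeError: ... argument after ** must be a mapping" when model_kwargs is None.
    _model_kwargs = model_kwargs or {}  # the fix: fall back to an empty dict
    return dict(task=task, **_model_kwargs)

print(build("text-generation"))  # {'task': 'text-generation'}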
@@ -100,7 +112,7 @@ class HuggingFacePipeline(LLM, BaseModel):
         return "huggingface_pipeline"
 
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        response = self.pipeline(text_inputs=prompt)
+        response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":
             # Text generation return includes the starter text.
             text = response[0]["generated_text"][len(prompt) :]
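With both hunks applied, a HuggingFacePipeline can wrap encoder-decoder checkpoints as well as causal ones. A minimal usage sketch, assuming the import path langchain.llms.huggingface_pipeline; both model IDs are small public checkpoints fetched from the Hugging Face Hub on first use:

from langchain.llms.huggingface_pipeline import HuggingFacePipeline

# Decoder-only model: task="text-generation" routes to AutoModelForCausalLM.
gpt2 = HuggingFacePipeline.from_model_id(model_id="gpt2", task="text-generation")

# Encoder-decoder model: task="text2text-generation" routes to the new
# AutoModelForSeq2SeqLM branch, which was previously unsupported.
flan = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-small", task="text2text-generation"
)

print(gpt2("Say foo:"))
print(flan("Say foo:"))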
@@ -163,6 +163,9 @@ class OpenAI(LLM, BaseModel):
     def stream(self, prompt: str) -> Generator:
         """Call OpenAI with streaming flag and return the resulting generator.
 
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
         Args:
             prompt: The prompts to pass into the model.
 
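The docstring addition above flags stream as beta. A minimal consumption sketch, assuming OPENAI_API_KEY is set in the environment and that each yielded chunk follows the classic OpenAI completion-stream shape with the text under choices[0]["text"]; per the BETA note, this interface may change:

from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
for chunk in llm.stream("Tell me a joke."):
    # Assumed chunk shape; adjust if the abstraction changes.
    print(chunk["choices"][0]["text"], end="", flush=True)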
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain"
-version = "0.0.39"
+version = "0.0.40"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@@ -18,6 +18,15 @@ def test_huggingface_pipeline_text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_huggingface_pipeline_text2text_generation() -> None:
+    """Test valid call to HuggingFace text2text generation model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="google/flan-t5-small", task="text2text-generation"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(