Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-04 00:00:34 +00:00)
Compare commits
3 Commits
bagatur/fe ... bagatur/ch
| Author | SHA1 | Date |
|---|---|---|
|  | a86f4942e8 |  |
|  | fe1230d0b3 |  |
|  | efac57d6ab |  |
@@ -0,0 +1,159 @@
from typing import Any, Dict, List, Optional, Sequence

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens


def _messages_to_dict(messages: Sequence[BaseMessage]) -> Dict[str, Any]:
    """Convert BaseMessage sequence to conversational HF inference API input.

    Assumes the last message in the sequence is the current input.
    Assumes any message that's not an AIMessage is a past user input.
    """
    past_user_inputs = []
    generated_responses = []
    for msg in messages[:-1]:
        if isinstance(msg, AIMessage):
            generated_responses.append(msg.content)
        else:
            past_user_inputs.append(msg.content)
    return {
        "past_user_inputs": past_user_inputs,
        "generated_responses": generated_responses,
        "text": messages[-1].content,
    }
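
# Illustrative mapping (assumed example values, not part of the diff): given
# [HumanMessage(content="Hi"), AIMessage(content="Hello!"),
#  HumanMessage(content="Tell me a joke")],
# _messages_to_dict would return
# {"past_user_inputs": ["Hi"], "generated_responses": ["Hello!"],
#  "text": "Tell me a joke"}.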


class ChatHuggingFaceEndpoint(BaseChatModel):
    """HuggingFace Endpoint chat models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `conversational` task for now.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatHuggingFaceEndpoint
            from langchain_core.messages import HumanMessage

            endpoint_url = (
                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
            )
            chat = ChatHuggingFaceEndpoint(
                endpoint_url=endpoint_url,
                huggingfacehub_api_token="my-api-key"
            )
            chat.invoke([HumanMessage(content="Write a plot for a Christmas movie")])
    """

    endpoint_url: str
    """Endpoint URL to use."""

    model_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    task: str = "conversational"

    huggingfacehub_api_token: Optional[str] = None
    """HuggingFace Hub API token.

    If not specified will be read from environment variable 'HUGGINGFACEHUB_API_TOKEN'.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        values["huggingfacehub_api_token"] = huggingfacehub_api_token
        return values

    def validate_api_token(cls, huggingfacehub_api_token: str) -> None:
        try:
            from huggingface_hub.hf_api import HfApi
        except ImportError:
            raise ImportError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
        try:
            HfApi(
                endpoint="https://huggingface.co",  # Can be a Private Hub endpoint.
                token=huggingfacehub_api_token,
            ).whoami()
        except Exception as e:
            raise ValueError(
                "Could not authenticate with huggingface_hub. "
                "Please check your API token."
            ) from e

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint_url": self.endpoint_url,
            "task": self.task,
            "model_kwargs": self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to HuggingFace Hub's inference endpoint.

        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The ChatResult containing the chat response generated by the model.

        """
        params = {**self.model_kwargs, **kwargs}
        inputs = _messages_to_dict(messages)
        payload = {"inputs": inputs, "parameters": params}
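        # Illustrative request body (assumed example values), showing how the
        # converted messages and parameters are combined before the POST below:
        # {"inputs": {"past_user_inputs": ["Hi"],
        #             "generated_responses": ["Hello!"],
        #             "text": "Tell me a joke"},
        #  "parameters": {"temperature": 0.7}}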

        headers = {
            "Authorization": f"Bearer {self.huggingfacehub_api_token}",
            "Content-Type": "application/json",
        }

        response = requests.post(self.endpoint_url, headers=headers, json=payload)
        if "error" in response.json():
            raise ValueError(
                f"Error raised by inference API: {response.json()['error']}"
            )
        generated_text = response.json()["generated_text"]

        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            generated_text = enforce_stop_tokens(generated_text, stop)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=generated_text))]
        )

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat_huggingface_endpoint"
@@ -0,0 +1,145 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens


def _messages_to_dict(messages: Sequence[BaseMessage]) -> Dict[str, Any]:
    """Convert BaseMessage sequence to conversational HF inference API input.

    Assumes the last message in the sequence is the current input.
    Assumes any message that's not an AIMessage is a past user input.
    """
    past_user_inputs = []
    generated_responses = []
    for msg in messages[:-1]:
        if isinstance(msg, AIMessage):
            generated_responses.append(msg.content)
        else:
            past_user_inputs.append(msg.content)
    return {
        "past_user_inputs": past_user_inputs,
        "generated_responses": generated_responses,
        "text": messages[-1].content,
    }


class ChatHuggingFaceHub(BaseChatModel):
    """HuggingFaceHub chat models.

    Uses the ``huggingface_hub.inference_api.InferenceApi`` client.
    Install with `pip install huggingface_hub`.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `conversational` task for now.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatHuggingFaceHub
            from langchain_core.messages import HumanMessage

            chat = ChatHuggingFaceHub(
                repo_id="repo/id",
                huggingfacehub_api_token="my-api-key"
            )
            chat.invoke([HumanMessage(content="Write a plot for a Christmas movie")])
    """

    client: Any = Field(exclude=True)  #: :meta private:

    repo_id: str
    """Model name to use."""

    task: str = "conversational"

    model_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    huggingfacehub_api_token: Optional[str] = None
    """HuggingFace Hub API token.

    If not specified will be read from environment variable 'HUGGINGFACEHUB_API_TOKEN'.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        try:
            from huggingface_hub.inference_api import InferenceApi
        except ImportError:
            raise ImportError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
        values["client"] = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "repo_id": self.repo_id,
            "task": self.task,
            "model_kwargs": self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to HuggingFace Hub's conversational inference API.

        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The ChatResult containing the chat response generated by the model.

        """
        params = {**self.model_kwargs, **kwargs}
        inputs = _messages_to_dict(messages)
        response = self.client(inputs=inputs, params=params)
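        # Assumed response shape for the `conversational` task: a dict such as
        # {"generated_text": "...",
        #  "conversation": {"past_user_inputs": [...], "generated_responses": [...]}},
        # which is why the reply is read from "generated_text" below.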
        if "error" in response:
            raise ValueError(f"Error raised by inference API: {response['error']}")
        generated_text = response["generated_text"]

        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            generated_text = enforce_stop_tokens(generated_text, stop)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=generated_text))]
        )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "chat_huggingface_hub"
@@ -0,0 +1,196 @@
from __future__ import annotations

import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field

from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.llms.utils import enforce_stop_tokens

if TYPE_CHECKING:
    from transformers import Conversation

logger = logging.getLogger(__name__)


def _messages_to_conversation(messages: Sequence[BaseMessage]) -> Conversation:
    """Convert messages to transformers Conversation.

    Uses OpenAI message role conventions: AIMessage has role "assistant",
    HumanMessage has role "user".
    """
    from transformers import Conversation

    return Conversation(messages=[convert_message_to_dict(msg) for msg in messages])
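
# Illustrative conversion (assumed example values, not part of the diff):
# _messages_to_conversation([HumanMessage(content="Hi"), AIMessage(content="Hello!")])
# yields a Conversation whose messages are
# [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}].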


class ChatHuggingFacePipeline(BaseChatModel):
    """HuggingFace Pipeline chat model API.

    To use, you should have the ``transformers`` python package installed.

    Only supports `conversational` task for now.

    Example using from_model_id:
        .. code-block:: python

            from langchain_community.chat_models import ChatHuggingFacePipeline

            chat = ChatHuggingFacePipeline.from_model_id(
                model_id="",
                pipeline_kwargs={"max_new_tokens": 10},
            )

    Example passing pipeline in directly:
        .. code-block:: python

            from langchain_community.chat_models import ChatHuggingFacePipeline
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

            model_id = "gpt2"
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
            pipe = pipeline(
                "conversational", model=model, tokenizer=tokenizer, max_new_tokens=10
            )
            chat = ChatHuggingFacePipeline(pipeline=pipe)
    """

    pipeline: Any = Field(exclude=True)  #: :meta private:
    model_id: str
    """Model name to use."""
    model_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments passed to the model."""
    pipeline_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments passed to the pipeline."""
    messages_to_conversation: Callable = _messages_to_conversation

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        *,
        task: str = "conversational",
        device: Optional[int] = -1,
        device_map: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
        pipeline_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> ChatHuggingFacePipeline:
        """Construct the pipeline object from model_id and task."""
        try:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            from transformers import pipeline as hf_pipeline
        except ImportError as e:
            raise ImportError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from e

        _model_kwargs = model_kwargs or {}
        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
        except ImportError as e:
            raise ImportError(
                f"Could not load the {task} model due to missing dependencies."
            ) from e

        if tokenizer.pad_token is None:
            tokenizer.pad_token_id = model.config.eos_token_id
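            # Reusing the EOS token id as the padding token keeps generation
            # working for models that ship without a dedicated pad token.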

        if (
            getattr(model, "is_loaded_in_4bit", False)
            or getattr(model, "is_loaded_in_8bit", False)
        ) and device is not None:
            logger.warning(
                f"Setting the `device` argument to None from {device} to avoid "
                "the error caused by attempting to move the model that was already "
                "loaded on the GPU using the Accelerate module to the same or "
                "another device."
            )
            device = None

        if device is not None and importlib.util.find_spec("torch") is not None:
            import torch

            cuda_device_count = torch.cuda.device_count()
            if device < -1 or (device >= cuda_device_count):
                raise ValueError(
                    f"Got device=={device}, "
                    f"device is required to be within [-1, {cuda_device_count})"
                )
            if device_map is not None and device < 0:
                device = None
            if device is not None and device < 0 and cuda_device_count > 0:
                logger.warning(
                    "Device has %d GPUs available. "
                    "Provide device={deviceId} to `from_model_id` to use available "
                    "GPUs for execution. deviceId is -1 (default) for CPU and "
                    "can be a positive integer associated with CUDA device id.",
                    cuda_device_count,
                )
        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
            }
        _pipeline_kwargs = pipeline_kwargs or {}
        pipeline = hf_pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=device,
            device_map=device_map,
            model_kwargs=_model_kwargs,
            **_pipeline_kwargs,
        )
        return cls(
            pipeline=pipeline,
            model_id=model_id,
            model_kwargs=_model_kwargs,
            pipeline_kwargs=_pipeline_kwargs,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "model_kwargs": self.model_kwargs,
            "pipeline_kwargs": self.pipeline_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        conversation = self.messages_to_conversation(messages)
        conversation = self.pipeline(conversation)
        generated_text = conversation.messages[-1]["content"]
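        # The conversational pipeline appends the model's reply as the final
        # "assistant" message, which is why the last entry is read here.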
        if stop:
            # Enforce stop tokens
            generated_text = enforce_stop_tokens(generated_text, stop)

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=generated_text))]
        )

    @property
    def _llm_type(self) -> str:
        return "chat_huggingface_pipeline"