Compare commits

...

3 Commits

Author  SHA1  Message  Date
Bagatur  a86f4942e8  wip  2023-12-19 16:18:02 -05:00
Bagatur  fe1230d0b3  Merge branch 'master' into bagatur/chat_hf  2023-12-19 12:08:21 -05:00
Bagatur  efac57d6ab  wip  2023-12-17 16:31:48 -05:00
3 changed files with 500 additions and 0 deletions

View File

@@ -0,0 +1,159 @@
from typing import Any, Dict, List, Optional, Sequence
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
def _messages_to_dict(messages: Sequence[BaseMessage]) -> Dict[str, Any]:
"""Convert BaseMessage sequence to conversational HF inference API input.
Assumes the last message in the sequence is the current input.
Assumes any message that's not an AIMessage is a past user input.
"""
past_user_inputs = []
generated_responses = []
for msg in messages[:-1]:
if isinstance(msg, AIMessage):
generated_responses.append(msg.content)
else:
past_user_inputs.append(msg.content)
return {
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
"text": messages[-1].content,
}
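# Illustrative example (hypothetical messages, not part of the module's API):
# for [HumanMessage("Hi"), AIMessage("Hello!"), HumanMessage("Tell me a joke")]
# the helper above would return
# {
#     "past_user_inputs": ["Hi"],
#     "generated_responses": ["Hello!"],
#     "text": "Tell me a joke",
# }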
class ChatHuggingFaceEndpoint(BaseChatModel):
"""HuggingFace Endpoint chat models.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `conversational` task for now.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatHuggingFaceEndpoint
from langchain_core.messages import HumanMessage
endpoint_url = (
"https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
)
chat = ChatHuggingFaceEndpoint(
endpoint_url=endpoint_url,
huggingfacehub_api_token="my-api-key"
)
chat.invoke([HumanMessage(content="Write a plot for a Christmas movie")])
"""
endpoint_url: str
"""Endpoint URL to use."""
model_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
task: str = "conversational"
huggingfacehub_api_token: Optional[str] = None
"""HuggingFace Hub API token.
If not specified will be read from environment variable 'HUGGINGFACEHUB_API_TOKEN'.
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values["huggingfacehub_api_token"] = huggingfacehub_api_token
return values
@classmethod
def validate_api_token(cls, huggingfacehub_api_token: str) -> None:
"""Validate that the given API token can authenticate with the HuggingFace Hub."""
try:
from huggingface_hub.hf_api import HfApi
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
try:
HfApi(
endpoint="https://huggingface.co", # Can be a Private Hub endpoint.
token=huggingfacehub_api_token,
).whoami()
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"endpoint_url": self.endpoint_url,
"task": self.task,
"model_kwargs": self.model_kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The ChatResult containing the chat response generated by the model.
"""
params = {**self.model_kwargs, **kwargs}
inputs = _messages_to_dict(messages)
payload = {"inputs": inputs, "parameters": params}
headers = {
"Authorization": f"Bearer {self.huggingfacehub_api_token}",
"Content-Type": "application/json",
}
response = requests.post(self.endpoint_url, headers=headers, json=payload)
response_json = response.json()
if "error" in response_json:
raise ValueError(
f"Error raised by inference API: {response_json['error']}"
)
# The conversational task response is expected to carry the model's latest
# reply in the `generated_text` field.
generated_text = response_json["generated_text"]
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
generated_text = enforce_stop_tokens(generated_text, stop)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=generated_text))]
)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat_huggingface_endpoint"

View File

@@ -0,0 +1,145 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
def _messages_to_dict(messages: Sequence[BaseMessage]) -> Dict[str, Any]:
"""Convert BaseMessage sequence to conversational HF inference API input.
Assumes the last message in the sequence is the current input.
Assumes any message that's not an AIMessage is a past user input.
"""
past_user_inputs = []
generated_responses = []
for msg in messages[:-1]:
if isinstance(msg, AIMessage):
generated_responses.append(msg.content)
else:
past_user_inputs.append(msg.content)
return {
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
"text": messages[-1].content,
}
class ChatHuggingFaceHub(BaseChatModel):
"""HuggingFaceHub chat models.
Uses the ``huggingface_hub.inference_api.InferenceApi`` client.
Install with ``pip install huggingface_hub``.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `conversational` task for now.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatHuggingFaceHub
from langchain_core.messages import HumanMessage
chat = ChatHuggingFaceHub(
repo_id="repo/id",
huggingfacehub_api_token="my-api-key",
)
chat.invoke([HumanMessage(content="Write a plot for a Christmas movie")])
"""
client: Any = Field(exclude=True) #: :meta private:
repo_id: str
"""Model name to use."""
task: str = "conversational"
model_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
huggingfacehub_api_token: Optional[str] = None
"""HuggingFace Hub API token.
If not specified will be read from environment variable 'HUGGINGFACEHUB_API_TOKEN'.
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
values["client"] = InferenceApi(
repo_id=values["repo_id"],
token=huggingfacehub_api_token,
task=values.get("task"),
)
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"repo_id": self.repo_id,
"task": self.task,
"model_kwargs": self.model_kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
messages: The messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The ChatResult containing the chat response generated by the model.
"""
params = {**self.model_kwargs, **kwargs}
inputs = _messages_to_dict(messages)
response = self.client(inputs=inputs, params=params)
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
generated_text = response["generated_text"]
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
generated_text = enforce_stop_tokens(generated_text, stop)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=generated_text))]
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "chat_huggingface_hub"

View File

@@ -0,0 +1,196 @@
from __future__ import annotations
import importlib.util
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Extra, Field
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.llms.utils import enforce_stop_tokens
if TYPE_CHECKING:
from transformers import Conversation
logger = logging.getLogger(__name__)
def _messages_to_conversation(messages: Sequence[BaseMessage]) -> Conversation:
"""Convert messages to transformers Conversation.
Uses OpenAI message role conventions: AIMessage has role "assistant",
HumanMessage has role "user".
"""
from transformers import Conversation
return Conversation(messages=[convert_message_to_dict(msg) for msg in messages])
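# Illustrative example (hypothetical messages): [HumanMessage("Hi"), AIMessage("Hello!")]
# becomes a transformers Conversation holding
# [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
# since convert_message_to_dict maps HumanMessage to "user" and AIMessage to "assistant".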
class ChatHuggingFacePipeline(BaseChatModel):
"""HuggingFace Pipeline chat model API.
To use, you should have the ``transformers`` python package installed.
Only supports `conversational` task for now.
Example using from_model_id:
.. code-block:: python
from langchain_community.chat_models import ChatHuggingFacePipeline
chat = ChatHuggingFacePipeline.from_model_id(
model_id="",
pipeline_kwargs={"max_new_tokens": 10},
)
Example passing pipeline in directly:
.. code-block:: python
from langchain_community.chat_models import ChatHuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline(
"conversational", model=model, tokenizer=tokenizer, max_new_tokens=10
)
chat = ChatHuggingFacePipeline(pipeline=pipe)
"""
pipeline: Any = Field(exclude=True) #: :meta private:
model_id: str
"""Model name to use."""
model_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments passed to the model."""
pipeline_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments passed to the pipeline."""
messages_to_conversation: Callable = _messages_to_conversation
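"""Function used to convert a sequence of messages to a transformers Conversation."""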
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@classmethod
def from_model_id(
cls,
model_id: str,
*,
task: str = "conversational",
device: Optional[int] = -1,
device_map: Optional[str] = None,
model_kwargs: Optional[dict] = None,
pipeline_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> ChatHuggingFacePipeline:
"""Construct the pipeline object from model_id and task."""
try:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import pipeline as hf_pipeline
except ImportError as e:
raise ImportError(
"Could not import transformers python package. "
"Please install it with `pip install transformers`."
) from e
_model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
try:
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
except ImportError as e:
raise ImportError(
f"Could not load the {task} model due to missing dependencies."
) from e
if tokenizer.pad_token is None:
tokenizer.pad_token_id = model.config.eos_token_id
if (
getattr(model, "is_loaded_in_4bit", False)
or getattr(model, "is_loaded_in_8bit", False)
) and device is not None:
logger.warning(
f"Setting the `device` argument to None from {device} to avoid an "
"error: the model was loaded with the Accelerate module (4-bit or "
"8-bit) and is already placed on a device, so it cannot be moved "
"to the same or another device."
)
device = None
if device is not None and importlib.util.find_spec("torch") is not None:
import torch
cuda_device_count = torch.cuda.device_count()
if device < -1 or (device >= cuda_device_count):
raise ValueError(
f"Got device=={device}, "
f"device is required to be within [-1, {cuda_device_count})"
)
if device_map is not None and device < 0:
device = None
if device is not None and device < 0 and cuda_device_count > 0:
logger.warning(
"Device has %d GPUs available. "
"Provide device={deviceId} to `from_model_id` to use available"
"GPUs for execution. deviceId is -1 (default) for CPU and "
"can be a positive integer associated with CUDA device id.",
cuda_device_count,
)
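# `trust_remote_code` has already been applied when loading the tokenizer and
# model above; it is dropped here so it is not forwarded to the pipeline's
# model_kwargs a second time.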
if "trust_remote_code" in _model_kwargs:
_model_kwargs = {
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
}
_pipeline_kwargs = pipeline_kwargs or {}
pipeline = hf_pipeline(
task=task,
model=model,
tokenizer=tokenizer,
device=device,
device_map=device_map,
model_kwargs=_model_kwargs,
**_pipeline_kwargs,
)
return cls(
pipeline=pipeline,
model_id=model_id,
model_kwargs=_model_kwargs,
pipeline_kwargs=_pipeline_kwargs,
**kwargs,
)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model_id": self.model_id,
"model_kwargs": self.model_kwargs,
"pipeline_kwargs": self.pipeline_kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
conversation = self.messages_to_conversation(messages)
conversation = self.pipeline(conversation)
generated_text = conversation.messages[-1]["content"]
if stop:
# Enforce stop tokens
generated_text = enforce_stop_tokens(generated_text, stop)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=generated_text))]
)
@property
def _llm_type(self) -> str:
return "chat_huggingface_pipeline"