mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 06:13:36 +00:00
huggingface: init package (#21097)
First Pr for the langchain_huggingface partner Package - Moved some of the hugging face related class from `community` to the new `partner package` Still needed : - Documentation - Tests - Support for the new apply_chat_template in `ChatHuggingFace` - Confirm choice of class to support for embeddings witht he sentence-transformer team. cc : @efriis --------- Co-authored-by: Cyril Kondratenko <kkn1993@gmail.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
9fce03e7db
commit
afd85b60fc
@ -9,9 +9,10 @@
|
||||
"This notebook shows how to get started using `Hugging Face` LLM's as chat models.\n",
|
||||
"\n",
|
||||
"In particular, we will:\n",
|
||||
"1. Utilize the [HuggingFaceTextGenInference](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_text_gen_inference.py), [HuggingFaceEndpoint](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_endpoint.py), or [HuggingFaceHub](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_hub.py) integrations to instantiate an `LLM`.\n",
|
||||
"2. Utilize the `ChatHuggingFace` class to enable any of these LLMs to interface with LangChain's [Chat Messages](/docs/concepts#chat-models) abstraction.\n",
|
||||
"3. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n",
|
||||
"1. Utilize the [HuggingFaceEndpoint](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_endpoint.py) integrations to instantiate an `LLM`.\n",
|
||||
"2. Utilize the `ChatHuggingFace` class to enable any of these LLMs to interface with LangChain's [Chat Messages](/docs/concepts/#message-types) abstraction.\n",
|
||||
"3. Explore tool calling with the `ChatHuggingFace`.\n",
|
||||
"4. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"> Note: To get started, you'll need to have a [Hugging Face Access Token](https://huggingface.co/docs/hub/security-tokens) saved as an environment variable: `HUGGINGFACEHUB_API_TOKEN`."
|
||||
@ -21,15 +22,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade --quiet text-generation transformers google-search-results numexpr langchainhub sentencepiece jinja2"
|
||||
]
|
||||
@ -38,44 +31,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Instantiate an LLM\n",
|
||||
"\n",
|
||||
"There are three LLM options to choose from."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `HuggingFaceTextGenInference`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain_community.llms import HuggingFaceTextGenInference\n",
|
||||
"\n",
|
||||
"ENDPOINT_URL = \"<YOUR_ENDPOINT_URL_HERE>\"\n",
|
||||
"HF_TOKEN = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceTextGenInference(\n",
|
||||
" inference_server_url=ENDPOINT_URL,\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=50,\n",
|
||||
" temperature=0.1,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
" server_kwargs={\n",
|
||||
" \"headers\": {\n",
|
||||
" \"Authorization\": f\"Bearer {HF_TOKEN}\",\n",
|
||||
" \"Content-Type\": \"application/json\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
")"
|
||||
"## 1. Instantiate an LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -87,58 +43,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceEndpoint\n",
|
||||
"from langchain_huggingface.llms import HuggingFaceEndpoint\n",
|
||||
"\n",
|
||||
"ENDPOINT_URL = \"<YOUR_ENDPOINT_URL_HERE>\"\n",
|
||||
"llm = HuggingFaceEndpoint(\n",
|
||||
" endpoint_url=ENDPOINT_URL,\n",
|
||||
" repo_id=\"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
|
||||
" task=\"text-generation\",\n",
|
||||
" model_kwargs={\n",
|
||||
" \"max_new_tokens\": 512,\n",
|
||||
" \"top_k\": 50,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"repetition_penalty\": 1.03,\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `HuggingFaceHub`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/jacoblee/langchain/langchain/libs/langchain/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:127: FutureWarning: '__init__' (from 'huggingface_hub.inference_api') is deprecated and will be removed from version '1.0'. `InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out this guide to learn how to convert your script to use it: https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client.\n",
|
||||
" warnings.warn(warning_message, FutureWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceHub\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
|
||||
" task=\"text-generation\",\n",
|
||||
" model_kwargs={\n",
|
||||
" \"max_new_tokens\": 512,\n",
|
||||
" \"top_k\": 30,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"repetition_penalty\": 1.03,\n",
|
||||
" },\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" do_sample=False,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -153,37 +69,30 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Instantiate the chat model and some messages to pass."
|
||||
"Instantiate the chat model and some messages to pass. \n",
|
||||
"\n",
|
||||
"**Note**: you need to pass the `model_id` explicitly if you are using self-hosted `text-generation-inference`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING! repo_id is not default parameter.\n",
|
||||
" repo_id was transferred to model_kwargs.\n",
|
||||
" Please confirm that repo_id is what you intended.\n",
|
||||
"WARNING! task is not default parameter.\n",
|
||||
" task was transferred to model_kwargs.\n",
|
||||
" Please confirm that task is what you intended.\n",
|
||||
"WARNING! huggingfacehub_api_token is not default parameter.\n",
|
||||
" huggingfacehub_api_token was transferred to model_kwargs.\n",
|
||||
" Please confirm that huggingfacehub_api_token is what you intended.\n",
|
||||
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.schema import (\n",
|
||||
"from langchain_core.messages import (\n",
|
||||
" HumanMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
")\n",
|
||||
"from langchain_community.chat_models.huggingface import ChatHuggingFace\n",
|
||||
"from langchain_huggingface.chat_models import ChatHuggingFace\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" SystemMessage(content=\"You're a helpful assistant\"),\n",
|
||||
@ -199,21 +108,21 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Inspect which model and corresponding chat template is being used."
|
||||
"Check the `model_id`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'HuggingFaceH4/zephyr-7b-beta'"
|
||||
"'meta-llama/Meta-Llama-3-70B-Instruct'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -231,16 +140,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"<|system|>\\nYou're a helpful assistant</s>\\n<|user|>\\nWhat happens when an unstoppable force meets an immovable object?</s>\\n<|assistant|>\\n\""
|
||||
"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou're a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nWhat happens when an unstoppable force meets an immovable object?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\""
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -258,14 +167,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"According to a popular philosophical paradox, when an unstoppable force meets an immovable object, it is impossible to determine which one will prevail because both are defined as being completely unyielding and unmovable. The paradox suggests that the very concepts of \"unstoppable force\" and \"immovable object\" are inherently contradictory, and therefore, it is illogical to imagine a scenario where they would meet and interact. However, in practical terms, it is highly unlikely for such a scenario to occur in the real world, as the concepts of \"unstoppable force\" and \"immovable object\" are often used metaphorically to describe hypothetical situations or abstract concepts, rather than physical objects or forces.\n"
|
||||
"One of the classic thought experiments in physics!\n",
|
||||
"\n",
|
||||
"The concept of an unstoppable force meeting an immovable object is a paradox that has puzzled philosophers and physicists for centuries. It's a mind-bending scenario that challenges our understanding of the fundamental laws of physics.\n",
|
||||
"\n",
|
||||
"In essence, an unstoppable force is something that cannot be halted or slowed down, while an immovable object is something that cannot be moved or displaced. If we assume that both entities exist in the same universe, we run into a logical contradiction.\n",
|
||||
"\n",
|
||||
"Here\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -278,7 +193,71 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Take it for a spin as an agent!\n",
|
||||
"## 3. Explore the tool calling with `ChatHuggingFace`\n",
|
||||
"\n",
|
||||
"`text-generation-inference` supports tool with open source LLMs starting from v2.0.1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a basic tool (`Calculator`):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Calculator(BaseModel):\n",
|
||||
" \"\"\"Multiply two integers together.\"\"\"\n",
|
||||
"\n",
|
||||
" a: int = Field(..., description=\"First integer\")\n",
|
||||
" b: int = Field(..., description=\"Second integer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Bind the tool to the `chat_model` and give it a try:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Calculator(a=3, b=12)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n",
|
||||
"\n",
|
||||
"llm_with_multiply = chat_model.bind_tools([Calculator], tool_choice=\"auto\")\n",
|
||||
"parser = PydanticToolsParser(tools=[Calculator])\n",
|
||||
"tool_chain = llm_with_multiply | parser\n",
|
||||
"tool_chain.invoke(\"How much is 3 multiplied by 12?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Take it for a spin as an agent!\n",
|
||||
"\n",
|
||||
"Here we'll test out `Zephyr-7B-beta` as a zero-shot `ReAct` Agent. The example below is taken from [here](https://python.langchain.com/v0.1/docs/modules/agents/agent_types/react/#using-chat-models).\n",
|
||||
"\n",
|
||||
@ -287,7 +266,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -310,7 +289,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -342,7 +321,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -20,7 +20,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceEndpoint"
|
||||
"from langchain_huggingface.llms import HuggingFaceEndpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -83,7 +83,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceEndpoint"
|
||||
"from langchain_huggingface.llms import HuggingFaceEndpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -193,7 +193,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain_community.llms import HuggingFaceEndpoint\n",
|
||||
"from langchain_huggingface.llms import HuggingFaceEndpoint\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceEndpoint(\n",
|
||||
" endpoint_url=f\"{your_endpoint_url}\",\n",
|
||||
|
@ -55,7 +55,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline\n",
|
||||
"from langchain_huggingface.llms import HuggingFacePipeline\n",
|
||||
"\n",
|
||||
"hf = HuggingFacePipeline.from_model_id(\n",
|
||||
" model_id=\"gpt2\",\n",
|
||||
@ -79,7 +79,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline\n",
|
||||
"from langchain_huggingface.llms import HuggingFacePipeline\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
|
||||
"\n",
|
||||
"model_id = \"gpt2\"\n",
|
||||
|
@ -26,7 +26,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.embeddings import HuggingFaceEmbeddings"
|
||||
"from langchain_huggingface.embeddings import HuggingFaceEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -175,7 +175,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.embeddings import HuggingFaceHubEmbeddings"
|
||||
"from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -185,7 +185,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = HuggingFaceHubEmbeddings()"
|
||||
"embeddings = HuggingFaceEndpointEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -59,7 +59,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.embeddings import HuggingFaceHubEmbeddings"
|
||||
"from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -71,7 +71,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")"
|
||||
"embeddings = HuggingFaceEndpointEmbeddings(model=\"http://localhost:8080\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -34,6 +36,13 @@ from langchain_community.llms.huggingface_text_gen_inference import (
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.37",
|
||||
removal="0.3",
|
||||
alternative_import=(
|
||||
"from langchain_huggingface.chat_models.huggingface import ChatHuggingFace"
|
||||
),
|
||||
)
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""
|
||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -21,6 +22,11 @@ VALID_TASKS = (
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.37",
|
||||
removal="0.3",
|
||||
alternative_import="from langchain_huggingface.llms import HuggingFaceEndpoint",
|
||||
)
|
||||
class HuggingFaceEndpoint(LLM):
|
||||
"""
|
||||
HuggingFace Endpoint.
|
||||
|
@ -4,6 +4,7 @@ import importlib.util
|
||||
import logging
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
@ -22,6 +23,11 @@ DEFAULT_BATCH_SIZE = 4
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.37",
|
||||
removal="0.3",
|
||||
alternative_import="from rom langchain_huggingface.llms import HuggingFacePipeline",
|
||||
)
|
||||
class HuggingFacePipeline(BaseLLM):
|
||||
"""HuggingFace Pipeline API.
|
||||
|
||||
|
1
libs/partners/huggingface/.gitignore
vendored
Normal file
1
libs/partners/huggingface/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__
|
21
libs/partners/huggingface/LICENSE
Normal file
21
libs/partners/huggingface/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
58
libs/partners/huggingface/Makefile
Normal file
58
libs/partners/huggingface/Makefile
Normal file
@ -0,0 +1,58 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
integration_test integration_tests: TEST_FILE=tests/integration_tests/
|
||||
|
||||
test tests integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/huggingface --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_huggingface
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_huggingface -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
10
libs/partners/huggingface/README.md
Normal file
10
libs/partners/huggingface/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
# langchain-huggingface
|
||||
|
||||
This package contains the LangChain integrations for huggingface related classes.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the LangChain partner package
|
||||
```bash
|
||||
pip install langchain-huggingface
|
||||
```
|
17
libs/partners/huggingface/langchain_huggingface/__init__.py
Normal file
17
libs/partners/huggingface/langchain_huggingface/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from langchain_huggingface.chat_models import ChatHuggingFace
|
||||
from langchain_huggingface.embeddings import (
|
||||
HuggingFaceEmbeddings,
|
||||
HuggingFaceEndpointEmbeddings,
|
||||
)
|
||||
from langchain_huggingface.llms import (
|
||||
HuggingFaceEndpoint,
|
||||
HuggingFacePipeline,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatHuggingFace",
|
||||
"HuggingFaceEndpointEmbeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
"HuggingFaceEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
]
|
@ -0,0 +1,15 @@
|
||||
from langchain_huggingface.chat_models.huggingface import (
|
||||
TGI_MESSAGE,
|
||||
TGI_RESPONSE,
|
||||
ChatHuggingFace,
|
||||
_convert_message_to_chat_message,
|
||||
_convert_TGI_message_to_LC_message,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatHuggingFace",
|
||||
"_convert_message_to_chat_message",
|
||||
"_convert_TGI_message_to_LC_message",
|
||||
"TGI_MESSAGE",
|
||||
"TGI_RESPONSE",
|
||||
]
|
@ -0,0 +1,350 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain_community.llms.huggingface_text_gen_inference import (
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TGI_RESPONSE:
|
||||
choices: List[Any]
|
||||
usage: Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TGI_MESSAGE:
|
||||
role: str
|
||||
content: str
|
||||
tool_calls: List[Dict]
|
||||
|
||||
|
||||
def _convert_message_to_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
tool_calls = [
|
||||
{
|
||||
"function": {
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"],
|
||||
}
|
||||
}
|
||||
for tc in message.additional_kwargs["tool_calls"]
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
return dict(role="system", content=message.content)
|
||||
elif isinstance(message, ToolMessage):
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def _convert_TGI_message_to_LC_message(
|
||||
_message: TGI_MESSAGE,
|
||||
) -> BaseMessage:
|
||||
role = _message.role
|
||||
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
|
||||
content = cast(str, _message.content)
|
||||
if content is None:
|
||||
content = ""
|
||||
additional_kwargs: Dict = {}
|
||||
if tool_calls := _message.tool_calls:
|
||||
if "arguments" in tool_calls[0]["function"]:
|
||||
functions_string = str(tool_calls[0]["function"].pop("arguments"))
|
||||
corrected_functions = functions_string.replace("'", '"')
|
||||
tool_calls[0]["function"]["arguments"] = corrected_functions
|
||||
additional_kwargs["tool_calls"] = tool_calls
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""
|
||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||
|
||||
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
|
||||
and `HuggingFaceHub` LLMs.
|
||||
|
||||
Upon instantiating this class, the model_id is resolved from the url
|
||||
provided to the LLM, and the appropriate tokenizer is loaded from
|
||||
the HuggingFace Hub.
|
||||
|
||||
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
|
||||
"""
|
||||
|
||||
llm: Any
|
||||
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
|
||||
HuggingFaceHub."""
|
||||
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
||||
tokenizer: Any = None
|
||||
model_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
from transformers import AutoTokenizer # type: ignore[import]
|
||||
|
||||
self._resolve_model_id()
|
||||
|
||||
self.tokenizer = (
|
||||
AutoTokenizer.from_pretrained(self.model_id)
|
||||
if self.tokenizer is None
|
||||
else self.tokenizer
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
if not isinstance(
|
||||
values["llm"],
|
||||
(HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint),
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _create_chat_result(self, response: TGI_RESPONSE) -> ChatResult:
|
||||
generations = []
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
gen = ChatGeneration(
|
||||
message=_convert_TGI_message_to_LC_message(response.choices[0].message),
|
||||
generation_info={"finish_reason": finish_reason},
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.usage
|
||||
model_object = self.llm.inference_server_url
|
||||
llm_output = {"token_usage": token_usage, "model": model_object}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
elif isinstance(self.llm, HuggingFaceEndpoint):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
else:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = self.llm._generate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
else:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = await self.llm._agenerate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
raise ValueError("At least one HumanMessage must be provided!")
|
||||
|
||||
if not isinstance(messages[-1], HumanMessage):
|
||||
raise ValueError("Last message must be a HumanMessage!")
|
||||
|
||||
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages_dicts, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
def _to_chatml_format(self, message: BaseMessage) -> dict:
|
||||
"""Convert LangChain message to ChatML format."""
|
||||
|
||||
if isinstance(message, SystemMessage):
|
||||
role = "system"
|
||||
elif isinstance(message, AIMessage):
|
||||
role = "assistant"
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {type(message)}")
|
||||
|
||||
return {"role": role, "content": message.content}
|
||||
|
||||
@staticmethod
|
||||
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
|
||||
chat_generations = []
|
||||
|
||||
for g in llm_result.generations[0]:
|
||||
chat_generation = ChatGeneration(
|
||||
message=AIMessage(content=g.text), generation_info=g.generation_info
|
||||
)
|
||||
chat_generations.append(chat_generation)
|
||||
|
||||
return ChatResult(
|
||||
generations=chat_generations, llm_output=llm_result.llm_output
|
||||
)
|
||||
|
||||
def _resolve_model_id(self) -> None:
|
||||
"""Resolve the model_id from the LLM's inference_server_url"""
|
||||
|
||||
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
||||
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
if isinstance(self.llm, HuggingFaceHub) or (
|
||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
else:
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
||||
for endpoint in available_endpoints:
|
||||
if endpoint.url == endpoint_url:
|
||||
self.model_id = endpoint.repository
|
||||
|
||||
if not self.model_id:
|
||||
raise ValueError(
|
||||
"Failed to resolve model_id:"
|
||||
f"Could not find model id for inference server: {endpoint_url}"
|
||||
"Make sure that your Hugging Face token has access to the endpoint."
|
||||
)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
"auto" to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None and tool_choice:
|
||||
if len(formatted_tools) != 1:
|
||||
raise ValueError(
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
f"tool. Received {len(formatted_tools)} tools."
|
||||
)
|
||||
if isinstance(tool_choice, str):
|
||||
if tool_choice not in ("auto", "none"):
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
elif isinstance(tool_choice, bool):
|
||||
tool_choice = formatted_tools[0]
|
||||
elif isinstance(tool_choice, dict):
|
||||
if (
|
||||
formatted_tools[0]["function"]["name"]
|
||||
!= tool_choice["function"]["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
||||
f"Received: {tool_choice}"
|
||||
)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> List[Dict[Any, Any]]:
|
||||
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface-chat-wrapper"
|
@ -0,0 +1,9 @@
|
||||
from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain_huggingface.embeddings.huggingface_endpoint import (
|
||||
HuggingFaceEndpointEmbeddings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceEmbeddings",
|
||||
"HuggingFaceEndpointEmbeddings",
|
||||
]
|
@ -0,0 +1,102 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
|
||||
class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""HuggingFace sentence_transformers embedding models.
|
||||
|
||||
To use, you should have the ``sentence_transformers`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
encode_kwargs = {'normalize_embeddings': False}
|
||||
hf = HuggingFaceEmbeddings(
|
||||
model_name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = DEFAULT_MODEL_NAME
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
|
||||
`prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method of the Sentence
|
||||
Transformer model, such as `prompt_name`, `prompt`, `batch_size`, `precision`,
|
||||
`normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
multi_process: bool = False
|
||||
"""Run encode() on multiple GPUs."""
|
||||
show_progress: bool = False
|
||||
"""Whether to show a progress bar."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import sentence_transformers # type: ignore[import]
|
||||
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
) from exc
|
||||
|
||||
self.client = sentence_transformers.SentenceTransformer(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
import sentence_transformers # type: ignore[import]
|
||||
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
if self.multi_process:
|
||||
pool = self.client.start_multi_process_pool()
|
||||
embeddings = self.client.encode_multi_process(texts, pool)
|
||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||
else:
|
||||
embeddings = self.client.encode(
|
||||
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
||||
)
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
@ -0,0 +1,151 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||
VALID_TASKS = ("feature-extraction",)
|
||||
|
||||
|
||||
class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
"""HuggingFaceHub embedding models.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEndpointEmbeddings
|
||||
model = "sentence-transformers/all-mpnet-base-v2"
|
||||
hf = HuggingFaceEndpointEmbeddings(
|
||||
model=model,
|
||||
task="feature-extraction",
|
||||
huggingfacehub_api_token="my-api-key",
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
async_client: Any #: :meta private:
|
||||
model: Optional[str] = None
|
||||
"""Model name to use."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Huggingfacehub repository id, for backward compatibility."""
|
||||
task: Optional[str] = "feature-extraction"
|
||||
"""Task to call the model with."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
|
||||
"HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
|
||||
try:
|
||||
from huggingface_hub import ( # type: ignore[import]
|
||||
AsyncInferenceClient,
|
||||
InferenceClient,
|
||||
)
|
||||
|
||||
if values["model"]:
|
||||
values["repo_id"] = values["model"]
|
||||
elif values["repo_id"]:
|
||||
values["model"] = values["repo_id"]
|
||||
else:
|
||||
values["model"] = DEFAULT_MODEL
|
||||
values["repo_id"] = DEFAULT_MODEL
|
||||
|
||||
client = InferenceClient(
|
||||
model=values["model"],
|
||||
token=huggingfacehub_api_token,
|
||||
)
|
||||
|
||||
async_client = AsyncInferenceClient(
|
||||
model=values["model"],
|
||||
token=huggingfacehub_api_token,
|
||||
)
|
||||
|
||||
if values["task"] not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
f"Got invalid task {values['task']}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
values["client"] = client
|
||||
values["async_client"] = async_client
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please install it with `pip install huggingface_hub`."
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
responses = self.client.post(
|
||||
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
|
||||
)
|
||||
return json.loads(responses.decode())
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
responses = await self.async_client.post(
|
||||
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
|
||||
)
|
||||
return json.loads(responses.decode())
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
response = (await self.aembed_documents([text]))[0]
|
||||
return response
|
@ -0,0 +1,7 @@
|
||||
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
]
|
@ -0,0 +1,372 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_TASKS = (
|
||||
"text2text-generation",
|
||||
"text-generation",
|
||||
"summarization",
|
||||
"conversational",
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceEndpoint(LLM):
|
||||
"""
|
||||
HuggingFace Endpoint.
|
||||
|
||||
To use this class, you should have installed the ``huggingface_hub`` package, and
|
||||
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
|
||||
or given as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Basic Example (no streaming)
|
||||
llm = HuggingFaceEndpoint(
|
||||
endpoint_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
print(llm.invoke("What is Deep Learning?"))
|
||||
|
||||
# Streaming response example
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
callbacks = [StreamingStdOutCallbackHandler()]
|
||||
llm = HuggingFaceEndpoint(
|
||||
endpoint_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
callbacks=callbacks,
|
||||
streaming=True,
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
print(llm.invoke("What is Deep Learning?"))
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Endpoint URL to use."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Repo to use."""
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
max_new_tokens: int = 512
|
||||
"""Maximum number of generated tokens"""
|
||||
top_k: Optional[int] = None
|
||||
"""The number of highest probability vocabulary tokens to keep for
|
||||
top-k-filtering."""
|
||||
top_p: Optional[float] = 0.95
|
||||
"""If set to < 1, only the smallest set of most probable tokens with probabilities
|
||||
that add up to `top_p` or higher are kept for generation."""
|
||||
typical_p: Optional[float] = 0.95
|
||||
"""Typical Decoding mass. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information."""
|
||||
temperature: Optional[float] = 0.8
|
||||
"""The value used to module the logits distribution."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
"""The parameter for repetition penalty. 1.0 means no penalty.
|
||||
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
|
||||
return_full_text: bool = False
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
"""Random sampling seed"""
|
||||
inference_server_url: str = ""
|
||||
"""text-generation-inference instance base url"""
|
||||
timeout: int = 120
|
||||
"""Timeout in seconds"""
|
||||
streaming: bool = False
|
||||
"""Whether to generate a stream of tokens asynchronously"""
|
||||
do_sample: bool = False
|
||||
"""Activate logits sampling"""
|
||||
watermark: bool = False
|
||||
"""Watermarking with [A Watermark for Large Language Models]
|
||||
(https://arxiv.org/abs/2301.10226)"""
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any text-generation-inference server parameters not explicitly specified"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `call` not explicitly specified"""
|
||||
model: str
|
||||
client: Any
|
||||
async_client: Any
|
||||
task: Optional[str] = None
|
||||
"""Task to call the model with.
|
||||
Should be a task that returns `generated_text` or `summary_text`."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please make sure that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
if "endpoint_url" not in values and "repo_id" not in values:
|
||||
raise ValueError(
|
||||
"Please specify an `endpoint_url` or `repo_id` for the model."
|
||||
)
|
||||
if "endpoint_url" in values and "repo_id" in values:
|
||||
raise ValueError(
|
||||
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
|
||||
)
|
||||
values["model"] = values.get("endpoint_url") or values.get("repo_id")
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that package is installed and that the API token is valid."""
|
||||
try:
|
||||
from huggingface_hub import login # type: ignore[import]
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please install it with `pip install huggingface_hub`."
|
||||
)
|
||||
try:
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
login(token=huggingfacehub_api_token)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not authenticate with huggingface_hub. "
|
||||
"Please check your API token."
|
||||
) from e
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, InferenceClient
|
||||
|
||||
values["client"] = InferenceClient(
|
||||
model=values["model"],
|
||||
timeout=values["timeout"],
|
||||
token=huggingfacehub_api_token,
|
||||
**values["server_kwargs"],
|
||||
)
|
||||
values["async_client"] = AsyncInferenceClient(
|
||||
model=values["model"],
|
||||
timeout=values["timeout"],
|
||||
token=huggingfacehub_api_token,
|
||||
**values["server_kwargs"],
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling text generation inference API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"typical_p": self.typical_p,
|
||||
"temperature": self.temperature,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"return_full_text": self.return_full_text,
|
||||
"truncate": self.truncate,
|
||||
"stop_sequences": self.stop_sequences,
|
||||
"seed": self.seed,
|
||||
"do_sample": self.do_sample,
|
||||
"watermark": self.watermark,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"endpoint_url": self.endpoint_url, "task": self.task},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
|
||||
return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint."""
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(prompt, stop, run_manager, **invocation_params):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
invocation_params["stop"] = invocation_params[
|
||||
"stop_sequences"
|
||||
] # porting 'stop_sequences' into the 'stop' argument
|
||||
response = self.client.post(
|
||||
json={"inputs": prompt, "parameters": invocation_params},
|
||||
stream=False,
|
||||
task=self.task,
|
||||
)
|
||||
response_text = json.loads(response.decode())[0]["generated_text"]
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then we remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
async for chunk in self._astream(
|
||||
prompt, stop, run_manager, **invocation_params
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
invocation_params["stop"] = invocation_params["stop_sequences"]
|
||||
response = await self.async_client.post(
|
||||
json={"inputs": prompt, "parameters": invocation_params},
|
||||
stream=False,
|
||||
task=self.task,
|
||||
)
|
||||
response_text = json.loads(response.decode())[0]["generated_text"]
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
for response in self.client.text_generation(
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
text = response
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
chunk = GenerationChunk(text=text)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
async for response in await self.async_client.text_generation(
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
text = response
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
chunk = GenerationChunk(text=text)
|
||||
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
@ -0,0 +1,299 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
VALID_TASKS = (
|
||||
"text2text-generation",
|
||||
"text-generation",
|
||||
"summarization",
|
||||
"translation",
|
||||
)
|
||||
DEFAULT_BATCH_SIZE = 4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFacePipeline(BaseLLM):
|
||||
"""HuggingFace Pipeline API.
|
||||
|
||||
To use, you should have the ``transformers`` python package installed.
|
||||
|
||||
Only supports `text-generation`, `text2text-generation`, `summarization` and
|
||||
`translation` for now.
|
||||
|
||||
Example using from_model_id:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
hf = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2",
|
||||
task="text-generation",
|
||||
pipeline_kwargs={"max_new_tokens": 10},
|
||||
)
|
||||
Example passing pipeline in directly:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
model_id = "gpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
pipe = pipeline(
|
||||
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe)
|
||||
"""
|
||||
|
||||
pipeline: Any #: :meta private:
|
||||
model_id: str = DEFAULT_MODEL_ID
|
||||
"""Model name to use."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments passed to the model."""
|
||||
pipeline_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments passed to the pipeline."""
|
||||
batch_size: int = DEFAULT_BATCH_SIZE
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
cls,
|
||||
model_id: str,
|
||||
task: str,
|
||||
backend: str = "default",
|
||||
device: Optional[int] = -1,
|
||||
device_map: Optional[str] = None,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
pipeline_kwargs: Optional[dict] = None,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
**kwargs: Any,
|
||||
) -> HuggingFacePipeline:
|
||||
"""Construct the pipeline object from model_id and task."""
|
||||
try:
|
||||
from transformers import ( # type: ignore[import]
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from transformers import pipeline as hf_pipeline # type: ignore[import]
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
|
||||
_model_kwargs = model_kwargs or {}
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
|
||||
|
||||
try:
|
||||
if task == "text-generation":
|
||||
if backend == "openvino":
|
||||
try:
|
||||
from optimum.intel.openvino import ( # type: ignore[import]
|
||||
OVModelForCausalLM,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import optimum-intel python package. "
|
||||
"Please install it with: "
|
||||
"pip install 'optimum[openvino,nncf]' "
|
||||
)
|
||||
try:
|
||||
# use local model
|
||||
model = OVModelForCausalLM.from_pretrained(
|
||||
model_id, **_model_kwargs
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# use remote model
|
||||
model = OVModelForCausalLM.from_pretrained(
|
||||
model_id, export=True, **_model_kwargs
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, **_model_kwargs
|
||||
)
|
||||
elif task in ("text2text-generation", "summarization", "translation"):
|
||||
if backend == "openvino":
|
||||
try:
|
||||
from optimum.intel.openvino import OVModelForSeq2SeqLM
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import optimum-intel python package. "
|
||||
"Please install it with: "
|
||||
"pip install 'optimum[openvino,nncf]' "
|
||||
)
|
||||
try:
|
||||
# use local model
|
||||
model = OVModelForSeq2SeqLM.from_pretrained(
|
||||
model_id, **_model_kwargs
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# use remote model
|
||||
model = OVModelForSeq2SeqLM.from_pretrained(
|
||||
model_id, export=True, **_model_kwargs
|
||||
)
|
||||
else:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_id, **_model_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"Could not load the {task} model due to missing dependencies."
|
||||
) from e
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
|
||||
if (
|
||||
(
|
||||
getattr(model, "is_loaded_in_4bit", False)
|
||||
or getattr(model, "is_loaded_in_8bit", False)
|
||||
)
|
||||
and device is not None
|
||||
and backend == "default"
|
||||
):
|
||||
logger.warning(
|
||||
f"Setting the `device` argument to None from {device} to avoid "
|
||||
"the error caused by attempting to move the model that was already "
|
||||
"loaded on the GPU using the Accelerate module to the same or "
|
||||
"another device."
|
||||
)
|
||||
device = None
|
||||
|
||||
if (
|
||||
device is not None
|
||||
and importlib.util.find_spec("torch") is not None
|
||||
and backend == "default"
|
||||
):
|
||||
import torch
|
||||
|
||||
cuda_device_count = torch.cuda.device_count()
|
||||
if device < -1 or (device >= cuda_device_count):
|
||||
raise ValueError(
|
||||
f"Got device=={device}, "
|
||||
f"device is required to be within [-1, {cuda_device_count})"
|
||||
)
|
||||
if device_map is not None and device < 0:
|
||||
device = None
|
||||
if device is not None and device < 0 and cuda_device_count > 0:
|
||||
logger.warning(
|
||||
"Device has %d GPUs available. "
|
||||
"Provide device={deviceId} to `from_model_id` to use available"
|
||||
"GPUs for execution. deviceId is -1 (default) for CPU and "
|
||||
"can be a positive integer associated with CUDA device id.",
|
||||
cuda_device_count,
|
||||
)
|
||||
if device is not None and device_map is not None and backend == "openvino":
|
||||
logger.warning("Please set device for OpenVINO through: " "'model_kwargs'")
|
||||
if "trust_remote_code" in _model_kwargs:
|
||||
_model_kwargs = {
|
||||
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
|
||||
}
|
||||
_pipeline_kwargs = pipeline_kwargs or {}
|
||||
pipeline = hf_pipeline(
|
||||
task=task,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
device_map=device_map,
|
||||
batch_size=batch_size,
|
||||
model_kwargs=_model_kwargs,
|
||||
**_pipeline_kwargs,
|
||||
)
|
||||
if pipeline.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
f"Got invalid task {pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
return cls(
|
||||
pipeline=pipeline,
|
||||
model_id=model_id,
|
||||
model_kwargs=_model_kwargs,
|
||||
pipeline_kwargs=_pipeline_kwargs,
|
||||
batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model_id": self.model_id,
|
||||
"model_kwargs": self.model_kwargs,
|
||||
"pipeline_kwargs": self.pipeline_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface_pipeline"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
# List to hold all results
|
||||
text_generations: List[str] = []
|
||||
pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
|
||||
|
||||
for i in range(0, len(prompts), self.batch_size):
|
||||
batch_prompts = prompts[i : i + self.batch_size]
|
||||
|
||||
# Process batch of prompts
|
||||
responses = self.pipeline(
|
||||
batch_prompts,
|
||||
**pipeline_kwargs,
|
||||
)
|
||||
|
||||
# Process each response in the batch
|
||||
for j, response in enumerate(responses):
|
||||
if isinstance(response, list):
|
||||
# if model returns multiple generations, pick the top one
|
||||
response = response[0]
|
||||
|
||||
if self.pipeline.task == "text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
text = response["summary_text"]
|
||||
elif self.pipeline.task in "translation":
|
||||
text = response["translation_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
|
||||
# Append the processed text to results
|
||||
text_generations.append(text)
|
||||
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=text)] for text in text_generations]
|
||||
)
|
3346
libs/partners/huggingface/poetry.lock
generated
Normal file
3346
libs/partners/huggingface/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
97
libs/partners/huggingface/pyproject.toml
Normal file
97
libs/partners/huggingface/pyproject.toml
Normal file
@ -0,0 +1,97 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-huggingface"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting Hugging Face and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/huggingface"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = ">=0.1.52,<0.3"
|
||||
tokenizers = ">=0.19.1"
|
||||
transformers = ">=4.39.0"
|
||||
sentence-transformers = ">=2.6.0"
|
||||
text-generation = "^0.7.0"
|
||||
huggingface-hub = ">=0.23.0"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-standard-tests = { path = "../../standard-tests", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
ipykernel = "^6.29.2"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
17
libs/partners/huggingface/scripts/check_imports.py
Normal file
17
libs/partners/huggingface/scripts/check_imports.py
Normal file
@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_faillure = True
|
||||
print(file) # noqa: T201
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
27
libs/partners/huggingface/scripts/check_pydantic.sh
Executable file
27
libs/partners/huggingface/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
18
libs/partners/huggingface/scripts/lint_imports.sh
Executable file
18
libs/partners/huggingface/scripts/lint_imports.sh
Executable file
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
# git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
242
libs/partners/huggingface/tests/unit_tests/test_chat_models.py
Normal file
242
libs/partners/huggingface/tests/unit_tests/test_chat_models.py
Normal file
@ -0,0 +1,242 @@
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_huggingface.chat_models import ( # type: ignore[import]
|
||||
TGI_MESSAGE,
|
||||
ChatHuggingFace,
|
||||
_convert_message_to_chat_message,
|
||||
_convert_TGI_message_to_LC_message,
|
||||
)
|
||||
from langchain_huggingface.llms.huggingface_endpoint import (
|
||||
HuggingFaceEndpoint,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "expected"),
|
||||
[
|
||||
(
|
||||
SystemMessage(content="Hello"),
|
||||
dict(role="system", content="Hello"),
|
||||
),
|
||||
(
|
||||
HumanMessage(content="Hello"),
|
||||
dict(role="user", content="Hello"),
|
||||
),
|
||||
(
|
||||
AIMessage(content="Hello"),
|
||||
dict(role="assistant", content="Hello", tool_calls=None),
|
||||
),
|
||||
(
|
||||
ChatMessage(role="assistant", content="Hello"),
|
||||
dict(role="assistant", content="Hello"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_chat_message(
|
||||
message: BaseMessage, expected: Dict[str, str]
|
||||
) -> None:
|
||||
result = _convert_message_to_chat_message(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tgi_message", "expected"),
|
||||
[
|
||||
(
|
||||
TGI_MESSAGE(role="assistant", content="Hello", tool_calls=[]),
|
||||
AIMessage(content="Hello"),
|
||||
),
|
||||
(
|
||||
TGI_MESSAGE(role="assistant", content="", tool_calls=[]),
|
||||
AIMessage(content=""),
|
||||
),
|
||||
(
|
||||
TGI_MESSAGE(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"arguments": "'function string'"}}],
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [{"function": {"arguments": '"function string"'}}]
|
||||
},
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_TGI_message_to_LC_message(
|
||||
tgi_message: TGI_MESSAGE, expected: BaseMessage
|
||||
) -> None:
|
||||
result = _convert_TGI_message_to_LC_message(tgi_message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> Mock:
|
||||
llm = Mock(spec=HuggingFaceEndpoint)
|
||||
llm.inference_server_url = "test endpoint url"
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch(
|
||||
"langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
|
||||
)
|
||||
def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
|
||||
chat_hf = ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
|
||||
return chat_hf
|
||||
|
||||
|
||||
def test_create_chat_result(chat_hugging_face: Any) -> None:
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=TGI_MESSAGE(
|
||||
role="assistant", content="test message", tool_calls=[]
|
||||
),
|
||||
finish_reason="test finish reason",
|
||||
)
|
||||
]
|
||||
mock_response.usage = {"tokens": 420}
|
||||
|
||||
result = chat_hugging_face._create_chat_result(mock_response)
|
||||
assert isinstance(result, ChatResult)
|
||||
assert result.generations[0].message.content == "test message"
|
||||
assert (
|
||||
result.generations[0].generation_info["finish_reason"] == "test finish reason" # type: ignore[index]
|
||||
)
|
||||
assert result.llm_output["token_usage"]["tokens"] == 420 # type: ignore[index]
|
||||
assert result.llm_output["model"] == chat_hugging_face.llm.inference_server_url # type: ignore[index]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, expected_error",
|
||||
[
|
||||
([], "At least one HumanMessage must be provided!"),
|
||||
(
|
||||
[HumanMessage(content="Hi"), AIMessage(content="Hello")],
|
||||
"Last message must be a HumanMessage!",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_to_chat_prompt_errors(
|
||||
chat_hugging_face: Any, messages: List[BaseMessage], expected_error: str
|
||||
) -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat_hugging_face._to_chat_prompt(messages)
|
||||
assert expected_error in str(e.value)
|
||||
|
||||
|
||||
def test_to_chat_prompt_valid_messages(chat_hugging_face: Any) -> None:
|
||||
messages = [AIMessage(content="Hello"), HumanMessage(content="How are you?")]
|
||||
expected_prompt = "Generated chat prompt"
|
||||
|
||||
chat_hugging_face.tokenizer.apply_chat_template.return_value = expected_prompt
|
||||
|
||||
result = chat_hugging_face._to_chat_prompt(messages)
|
||||
|
||||
assert result == expected_prompt
|
||||
chat_hugging_face.tokenizer.apply_chat_template.assert_called_once_with(
|
||||
[
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "expected"),
|
||||
[
|
||||
(
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
),
|
||||
(
|
||||
AIMessage(content="How can I help you?"),
|
||||
{"role": "assistant", "content": "How can I help you?"},
|
||||
),
|
||||
(
|
||||
HumanMessage(content="Hello"),
|
||||
{"role": "user", "content": "Hello"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_to_chatml_format(
|
||||
chat_hugging_face: Any, message: BaseMessage, expected: Dict[str, str]
|
||||
) -> None:
|
||||
result = chat_hugging_face._to_chatml_format(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
|
||||
message = "Invalid message type"
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat_hugging_face._to_chatml_format(message)
|
||||
assert "Unknown message type:" in str(e.value)
|
||||
|
||||
|
||||
def tool_mock() -> Dict:
|
||||
return {"function": {"name": "test_tool"}}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tools, tool_choice, expected_exception, expected_message",
|
||||
[
|
||||
([tool_mock()], ["invalid type"], ValueError, "Unrecognized tool_choice type."),
|
||||
(
|
||||
[tool_mock(), tool_mock()],
|
||||
"test_tool",
|
||||
ValueError,
|
||||
"must provide exactly one tool.",
|
||||
),
|
||||
(
|
||||
[tool_mock()],
|
||||
{"type": "function", "function": {"name": "other_tool"}},
|
||||
ValueError,
|
||||
"Tool choice {'type': 'function', 'function': {'name': 'other_tool'}} "
|
||||
"was specified, but the only provided tool was test_tool.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bind_tools_errors(
|
||||
chat_hugging_face: Any,
|
||||
tools: Dict[str, str],
|
||||
tool_choice: Any,
|
||||
expected_exception: Any,
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
with patch(
|
||||
"langchain_huggingface.chat_models.huggingface.convert_to_openai_tool",
|
||||
side_effect=lambda x: x,
|
||||
):
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
chat_hugging_face.bind_tools(tools, tool_choice=tool_choice)
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_bind_tools(chat_hugging_face: Any) -> None:
|
||||
tools = [MagicMock(spec=BaseTool)]
|
||||
with patch(
|
||||
"langchain_huggingface.chat_models.huggingface.convert_to_openai_tool",
|
||||
side_effect=lambda x: x,
|
||||
), patch("langchain_core.runnables.base.Runnable.bind") as mock_super_bind:
|
||||
chat_hugging_face.bind_tools(tools, tool_choice="auto")
|
||||
mock_super_bind.assert_called_once()
|
||||
_, kwargs = mock_super_bind.call_args
|
||||
assert kwargs["tools"] == tools
|
||||
assert kwargs["tool_choice"] == "auto"
|
Loading…
Reference in New Issue
Block a user