Compare commits

..

1 Commits

Author SHA1 Message Date
Bagatur
1812c1527f docs: version section sidebar 2024-07-31 02:49:26 -07:00
111 changed files with 1351 additions and 3446 deletions

View File

@@ -43,7 +43,7 @@
"\n",
"This is the easiest and most reliable way to get structured outputs. `with_structured_output()` is implemented for models that provide native APIs for structuring outputs, like tool/function calling or JSON mode, and makes use of these capabilities under the hood.\n",
"\n",
"This method takes a schema as input which specifies the names, types, and descriptions of the desired output attributes. The method returns a model-like Runnable, except that instead of outputting strings or Messages it outputs objects corresponding to the given schema. The schema can be specified as a TypedDict class, [JSON Schema](https://json-schema.org/) or a Pydantic class. If TypedDict or JSON Schema are used then a dictionary will be returned by the Runnable, and if a Pydantic class is used then a Pydantic object will be returned.\n",
"This method takes a schema as input which specifies the names, types, and descriptions of the desired output attributes. The method returns a model-like Runnable, except that instead of outputting strings or Messages it outputs objects corresponding to the given schema. The schema can be specified as a [JSON Schema](https://json-schema.org/) or a Pydantic class. If JSON Schema is used then a dictionary will be returned by the Runnable, and if a Pydantic class is used then Pydantic objects will be returned.\n",
"\n",
"As an example, let's get a model to generate a joke and separate the setup from the punchline:\n",
"\n",
@@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "6d55008f",
"metadata": {},
"outputs": [],
@@ -68,7 +68,7 @@
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4o\", temperature=0)"
"llm = ChatOpenAI(model=\"gpt-4-0125-preview\", temperature=0)"
]
},
{
@@ -76,24 +76,22 @@
"id": "a808a401-be1f-49f9-ad13-58dd68f7db5f",
"metadata": {},
"source": [
"### Pydantic class\n",
"\n",
"If we want the model to return a Pydantic object, we just need to pass in the desired Pydantic class. The key advantage of using Pydantic is that the model-generated output will be validated. Pydantic will raise an error if any required fields are missing or if any fields are of the wrong type."
"If we want the model to return a Pydantic object, we just need to pass in the desired Pydantic class:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "070bf702",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Joke(setup='Why was the cat sitting on the computer?', punchline='Because it wanted to keep an eye on the mouse!', rating=7)"
"Joke(setup='Why was the cat sitting on the computer?', punchline='Because it wanted to keep an eye on the mouse!', rating=8)"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -104,15 +102,12 @@
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"# Pydantic\n",
"class Joke(BaseModel):\n",
" \"\"\"Joke to tell user.\"\"\"\n",
"\n",
" setup: str = Field(description=\"The setup of the joke\")\n",
" punchline: str = Field(description=\"The punchline to the joke\")\n",
" rating: Optional[int] = Field(\n",
" default=None, description=\"How funny the joke is, from 1 to 10\"\n",
" )\n",
" rating: Optional[int] = Field(description=\"How funny the joke is, from 1 to 10\")\n",
"\n",
"\n",
"structured_llm = llm.with_structured_output(Joke)\n",
@@ -135,73 +130,12 @@
"id": "deddb6d3",
"metadata": {},
"source": [
"### TypedDict or JSON Schema\n",
"\n",
"If you don't want to use Pydantic, explicitly don't want validation of the arguments, or want to be able to stream the model outputs, you can define your schema using a TypedDict class. We can optionally use a special `Annotated` syntax supported by LangChain that allows you to specify the default value and description of a field. Note, the default value is *not* filled in automatically if the model doesn't generate it, it is only used in defining the schema that is passed to the model.\n",
"\n",
":::info Requirements\n",
"\n",
"- Core: `langchain-core>=0.2.26`\n",
"- Typing extensions: It is highly recommended to import `Annotated` and `TypedDict` from `typing_extensions` instead of `typing` to ensure consistent behavior across Python versions.\n",
"\n",
":::"
"We can also pass in a [JSON Schema](https://json-schema.org/) dict if you prefer not to use Pydantic. In this case, the response is also a dict:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "70d82891-42e8-424a-919e-07d83bcfec61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'setup': 'Why was the cat sitting on the computer?',\n",
" 'punchline': 'Because it wanted to keep an eye on the mouse!',\n",
" 'rating': 7}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing_extensions import Annotated, TypedDict\n",
"\n",
"\n",
"# TypedDict\n",
"class Joke(TypedDict):\n",
" \"\"\"Joke to tell user.\"\"\"\n",
"\n",
" setup: Annotated[str, ..., \"The setup of the joke\"]\n",
"\n",
" # Alternatively, we could have specified setup as:\n",
"\n",
" # setup: str # no default, no description\n",
" # setup: Annotated[str, ...] # no default, no description\n",
" # setup: Annotated[str, \"foo\"] # default, no description\n",
"\n",
" punchline: Annotated[str, ..., \"The punchline of the joke\"]\n",
" rating: Annotated[Optional[int], None, \"How funny the joke is, from 1 to 10\"]\n",
"\n",
"\n",
"structured_llm = llm.with_structured_output(Joke)\n",
"\n",
"structured_llm.invoke(\"Tell me a joke about cats\")"
]
},
{
"cell_type": "markdown",
"id": "e4d7b4dc-f617-4ea8-aa58-847c228791b4",
"metadata": {},
"source": [
"Equivalently, we can pass in a [JSON Schema](https://json-schema.org/) dict. This requires no imports or classes and makes it very clear exactly how each parameter is documented, at the cost of being a bit more verbose."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6700994a",
"metadata": {},
"outputs": [
@@ -210,10 +144,10 @@
"text/plain": [
"{'setup': 'Why was the cat sitting on the computer?',\n",
" 'punchline': 'Because it wanted to keep an eye on the mouse!',\n",
" 'rating': 7}"
" 'rating': 8}"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -235,7 +169,6 @@
" \"rating\": {\n",
" \"type\": \"integer\",\n",
" \"description\": \"How funny the joke is, from 1 to 10\",\n",
" \"default\": None,\n",
" },\n",
" },\n",
" \"required\": [\"setup\", \"punchline\"],\n",
@@ -252,7 +185,7 @@
"source": [
"### Choosing between multiple schemas\n",
"\n",
"The simplest way to let the model choose from multiple schemas is to create a parent schema that has a Union-typed attribute:"
"The simplest way to let the model choose from multiple schemas is to create a parent Pydantic class that has a Union-typed attribute:"
]
},
{
@@ -276,17 +209,6 @@
"from typing import Union\n",
"\n",
"\n",
"# Pydantic\n",
"class Joke(BaseModel):\n",
" \"\"\"Joke to tell user.\"\"\"\n",
"\n",
" setup: str = Field(description=\"The setup of the joke\")\n",
" punchline: str = Field(description=\"The punchline to the joke\")\n",
" rating: Optional[int] = Field(\n",
" default=None, description=\"How funny the joke is, from 1 to 10\"\n",
" )\n",
"\n",
"\n",
"class ConversationalResponse(BaseModel):\n",
" \"\"\"Respond in a conversational manner. Be kind and helpful.\"\"\"\n",
"\n",
@@ -338,7 +260,7 @@
"source": [
"### Streaming\n",
"\n",
"We can stream outputs from our structured model when the output type is a dict (i.e., when the schema is specified as a TypedDict class or JSON Schema dict). \n",
"We can stream outputs from our structured model when the output type is a dict (i.e., when the schema is specified as a JSON Schema dict). \n",
"\n",
":::info\n",
"\n",
@@ -349,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 43,
"id": "aff89877-28a3-472f-a1aa-eff893fe7736",
"metadata": {},
"outputs": [
@@ -380,24 +302,12 @@
"{'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the'}\n",
"{'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse'}\n",
"{'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse!'}\n",
"{'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse!', 'rating': 7}\n"
"{'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse!', 'rating': 8}\n"
]
}
],
"source": [
"from typing_extensions import Annotated, TypedDict\n",
"\n",
"\n",
"# TypedDict\n",
"class Joke(TypedDict):\n",
" \"\"\"Joke to tell user.\"\"\"\n",
"\n",
" setup: Annotated[str, ..., \"The setup of the joke\"]\n",
" punchline: Annotated[str, ..., \"The punchline of the joke\"]\n",
" rating: Annotated[Optional[int], None, \"How funny the joke is, from 1 to 10\"]\n",
"\n",
"\n",
"structured_llm = llm.with_structured_output(Joke)\n",
"structured_llm = llm.with_structured_output(json_schema)\n",
"\n",
"for chunk in structured_llm.stream(\"Tell me a joke about cats\"):\n",
" print(chunk)"
@@ -417,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 47,
"id": "283ba784-2072-47ee-9b2c-1119e3c69e8e",
"metadata": {},
"outputs": [
@@ -425,11 +335,11 @@
"data": {
"text/plain": [
"{'setup': 'Woodpecker',\n",
" 'punchline': \"Woodpecker who? Woodpecker who can't find a tree is just a bird with a headache!\",\n",
" 'rating': 7}"
" 'punchline': \"Woodpecker goes 'knock knock', but don't worry, they never expect you to answer the door!\",\n",
" 'rating': 8}"
]
},
"execution_count": 11,
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
@@ -467,7 +377,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 46,
"id": "d7381cb0-b2c3-4302-a319-ed72d0b9e43f",
"metadata": {},
"outputs": [
@@ -475,11 +385,11 @@
"data": {
"text/plain": [
"{'setup': 'Crocodile',\n",
" 'punchline': 'Crocodile be seeing you later, alligator!',\n",
" 'punchline': \"Crocodile 'see you later', but in a while, it becomes an alligator!\",\n",
" 'rating': 7}"
]
},
"execution_count": 12,
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
@@ -581,24 +491,23 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 6,
"id": "df0370e3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'setup': 'Why was the cat sitting on the computer?',\n",
" 'punchline': 'Because it wanted to keep an eye on the mouse!'}"
"Joke(setup='Why was the cat sitting on the computer?', punchline='Because it wanted to keep an eye on the mouse!', rating=None)"
]
},
"execution_count": 15,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"structured_llm = llm.with_structured_output(None, method=\"json_mode\")\n",
"structured_llm = llm.with_structured_output(Joke, method=\"json_mode\")\n",
"\n",
"structured_llm.invoke(\n",
" \"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys\"\n",
@@ -617,21 +526,19 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 5,
"id": "10ed2842",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_f25ZRmh8u5vHlOWfTUw8sJFZ', 'function': {'arguments': '{\"setup\":\"Why was the cat sitting on the computer?\",\"punchline\":\"Because it wanted to keep an eye on the mouse!\",\"rating\":7}', 'name': 'Joke'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 33, 'prompt_tokens': 93, 'total_tokens': 126}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_4e2b2da518', 'finish_reason': 'stop', 'logprobs': None}, id='run-d880d7e2-df08-4e9e-ad92-dfc29f2fd52f-0', tool_calls=[{'name': 'Joke', 'args': {'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse!', 'rating': 7}, 'id': 'call_f25ZRmh8u5vHlOWfTUw8sJFZ', 'type': 'tool_call'}], usage_metadata={'input_tokens': 93, 'output_tokens': 33, 'total_tokens': 126}),\n",
" 'parsed': {'setup': 'Why was the cat sitting on the computer?',\n",
" 'punchline': 'Because it wanted to keep an eye on the mouse!',\n",
" 'rating': 7},\n",
"{'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_ASK4EmZeZ69Fi3p554Mb4rWy', 'function': {'arguments': '{\"setup\":\"Why was the cat sitting on the computer?\",\"punchline\":\"Because it wanted to keep an eye on the mouse!\"}', 'name': 'Joke'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 36, 'prompt_tokens': 107, 'total_tokens': 143}, 'model_name': 'gpt-4-0125-preview', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-6491d35b-9164-4656-b75c-d7882cfb76cb-0', tool_calls=[{'name': 'Joke', 'args': {'setup': 'Why was the cat sitting on the computer?', 'punchline': 'Because it wanted to keep an eye on the mouse!'}, 'id': 'call_ASK4EmZeZ69Fi3p554Mb4rWy'}], usage_metadata={'input_tokens': 107, 'output_tokens': 36, 'total_tokens': 143}),\n",
" 'parsed': Joke(setup='Why was the cat sitting on the computer?', punchline='Because it wanted to keep an eye on the mouse!', rating=None),\n",
" 'parsing_error': None}"
]
},
"execution_count": 17,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -639,7 +546,9 @@
"source": [
"structured_llm = llm.with_structured_output(Joke, include_raw=True)\n",
"\n",
"structured_llm.invoke(\"Tell me a joke about cats\")"
"structured_llm.invoke(\n",
" \"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys\"\n",
")"
]
},
{
@@ -915,7 +824,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -929,7 +838,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.5"
}
},
"nbformat": 4,

View File

@@ -24,9 +24,10 @@
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [Chat models](/docs/concepts/#chat-models)\n",
"- [LangChain Tools](/docs/concepts/#tools)\n",
"- [Tool calling](/docs/concepts/#functiontool-calling)\n",
"- [Tools](/docs/concepts/#tools)\n",
"- [Output parsers](/docs/concepts/#output-parsers)\n",
"\n",
":::\n",
"\n",
"[Tool calling](/docs/concepts/#functiontool-calling) allows a chat model to respond to a given prompt by \"calling a tool\".\n",
@@ -37,11 +38,15 @@
"\n",
"![Diagram of calling a tool](/img/tool_call.png)\n",
"\n",
"If you want to see how to use the model-generated tool call to actually run a tool [check out this guide](/docs/how_to/tool_results_pass_to_model/).\n",
"If you want to see how to use the model-generated tool call to actually run a tool function [check out this guide](/docs/how_to/tool_results_pass_to_model/).\n",
"\n",
":::note Supported models\n",
"\n",
"Tool calling is not universal, but is supported by many popular LLM providers. You can find a [list of all models that support tool calling here](/docs/integrations/chat/).\n",
"Tool calling is not universal, but is supported by many popular LLM providers, including [Anthropic](/docs/integrations/chat/anthropic/), \n",
"[Cohere](/docs/integrations/chat/cohere/), [Google](/docs/integrations/chat/google_vertex_ai_palm/), \n",
"[Mistral](/docs/integrations/chat/mistralai/), [OpenAI](/docs/integrations/chat/openai/), and even for locally-running models via [Ollama](/docs/integrations/chat/ollama/).\n",
"\n",
"You can find a [list of all models that support tool calling here](/docs/integrations/chat/).\n",
"\n",
":::\n",
"\n",
@@ -53,12 +58,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining tool schemas\n",
"## Passing tools to chat models\n",
"\n",
"For a model to be able to call tools, we need to pass in tool schemas that describe what the tool does and what it's arguments are. Chat models that support tool calling features implement a `.bind_tools()` method for passing tool schemas to the model. Tool schemas can be passed in as Python functions (with typehints and docstrings), Pydantic models, TypedDict classes, or LangChain [Tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool). Subsequent invocations of the model will pass in these tool schemas along with the prompt.\n",
"Chat models that support tool calling features implement a `.bind_tools` method, which \n",
"receives a list of functions, Pydantic models, or LangChain [tool objects](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool) \n",
"and binds them to the chat model in its expected format. Subsequent invocations of the \n",
"chat model will include tool schemas in its calls to the LLM.\n",
"\n",
"### Python functions\n",
"Our tool schemas can be Python functions:"
"For example, below we implement simple tools for arithmetic:"
]
},
{
@@ -67,41 +74,26 @@
"metadata": {},
"outputs": [],
"source": [
"# The function name, type hints, and docstring are all part of the tool\n",
"# schema that's passed to the model. Defining good, descriptive schemas\n",
"# is an extension of prompt engineering and is an important part of\n",
"# getting models to perform well.\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Add two integers.\n",
"\n",
" Args:\n",
" a: First integer\n",
" b: Second integer\n",
" \"\"\"\n",
" \"\"\"Adds a and b.\"\"\"\n",
" return a + b\n",
"\n",
"\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiply two integers.\n",
" \"\"\"Multiplies a and b.\"\"\"\n",
" return a * b\n",
"\n",
" Args:\n",
" a: First integer\n",
" b: Second integer\n",
" \"\"\"\n",
" return a * b"
"\n",
"tools = [add, multiply]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LangChain Tool\n",
"\n",
"LangChain also implements a `@tool` decorator that allows for further control of the tool schema, such as tool names and argument descriptions. See the how-to guide [here](/docs/how_to/custom_tools/#creating-tools-from-functions) for details.\n",
"\n",
"### Pydantic class\n",
"\n",
"You can equivalently define the schemas without the accompanying functions using [Pydantic](https://docs.pydantic.dev):"
"We can also define the schemas without the accompanying functions using [Pydantic](https://docs.pydantic.dev):"
]
},
{
@@ -113,57 +105,23 @@
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class add(BaseModel):\n",
" \"\"\"Add two integers.\"\"\"\n",
"# Note that the docstrings here are crucial, as they will be passed along\n",
"# to the model along with the class name.\n",
"class Add(BaseModel):\n",
" \"\"\"Add two integers together.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"class multiply(BaseModel):\n",
" \"\"\"Multiply two integers.\"\"\"\n",
"class Multiply(BaseModel):\n",
" \"\"\"Multiply two integers together.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TypedDict class\n",
"\n",
":::info Requires `langchain-core>=0.2.25`\n",
":::\n",
"\n",
"Or using TypedDicts and annotations:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from typing_extensions import Annotated, TypedDict\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"class add(TypedDict):\n",
" \"\"\"Add two integers.\"\"\"\n",
"\n",
" # Annotations must have the type and can optionally include a default value and description (in that order).\n",
" a: Annotated[int, ..., \"First integer\"]\n",
" b: Annotated[int, ..., \"Second integer\"]\n",
"\n",
"\n",
"class multiply(BaseModel):\n",
" \"\"\"Multiply two integers.\"\"\"\n",
"\n",
" a: Annotated[int, ..., \"First integer\"]\n",
" b: Annotated[int, ..., \"Second integer\"]\n",
"\n",
"\n",
"tools = [add, multiply]"
"tools = [Add, Multiply]"
]
},
{
@@ -171,7 +129,7 @@
"metadata": {},
"source": [
"To actually bind those schemas to a chat model, we'll use the `.bind_tools()` method. This handles converting\n",
"the `add` and `multiply` schemas to the proper format for the model. The tool schema will then be passed it in each time the model is invoked.\n",
"the `Add` and `Multiply` schemas to the proper format for the model. The tool schema will then be passed it in each time the model is invoked.\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
@@ -206,16 +164,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_BwYJ4UgU5pRVCBOUmiu7NhF9', 'function': {'arguments': '{\"a\":3,\"b\":12}', 'name': 'multiply'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 80, 'total_tokens': 97}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_ba606877f9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-7f05e19e-4561-40e2-a2d0-8f4e28e9a00f-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_BwYJ4UgU5pRVCBOUmiu7NhF9', 'type': 'tool_call'}], usage_metadata={'input_tokens': 80, 'output_tokens': 17, 'total_tokens': 97})"
"AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wLTBasMppAwpdiA5CD92l9x7', 'function': {'arguments': '{\"a\":3,\"b\":12}', 'name': 'Multiply'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 89, 'total_tokens': 107}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0f03d4f0ee', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-d3f36cca-f225-416f-ac16-0217046f0b38-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_wLTBasMppAwpdiA5CD92l9x7', 'type': 'tool_call'}], usage_metadata={'input_tokens': 89, 'output_tokens': 18, 'total_tokens': 107})"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -256,23 +214,23 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'multiply',\n",
"[{'name': 'Multiply',\n",
" 'args': {'a': 3, 'b': 12},\n",
" 'id': 'call_rcdMie7E89Xx06lEKKxJyB5N',\n",
" 'id': 'call_uqJsNrDJ8ZZnFa1BHHYAllEv',\n",
" 'type': 'tool_call'},\n",
" {'name': 'add',\n",
" {'name': 'Add',\n",
" 'args': {'a': 11, 'b': 49},\n",
" 'id': 'call_nheGN8yfvSJsnIuGZaXihou3',\n",
" 'id': 'call_ud1uHAaYsdpWuxugwoJ63BDs',\n",
" 'type': 'tool_call'}]"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -294,49 +252,31 @@
"are populated in the `.invalid_tool_calls` attribute. An `InvalidToolCall` can have \n",
"a name, string arguments, identifier, and error message.\n",
"\n",
"\n",
"## Parsing\n",
"\n",
"If desired, [output parsers](/docs/how_to#output-parsers) can further process the output. For example, we can convert existing values populated on the `.tool_calls` to Pydantic objects using the\n",
"If desired, [output parsers](/docs/how_to#output-parsers) can further \n",
"process the output. For example, we can convert existing values populated on the `.tool_calls` attribute back to the original Pydantic class using the\n",
"[PydanticToolsParser](https://api.python.langchain.com/en/latest/output_parsers/langchain_core.output_parsers.openai_tools.PydanticToolsParser.html):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[multiply(a=3, b=12), add(a=11, b=49)]"
"[Multiply(a=3, b=12), Add(a=11, b=49)]"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.output_parsers import PydanticToolsParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class add(BaseModel):\n",
" \"\"\"Add two integers.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"class multiply(BaseModel):\n",
" \"\"\"Multiply two integers.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"First integer\")\n",
" b: int = Field(..., description=\"Second integer\")\n",
"\n",
"\n",
"chain = llm_with_tools | PydanticToolsParser(tools=[add, multiply])\n",
"chain = llm_with_tools | PydanticToolsParser(tools=[Multiply, Add])\n",
"chain.invoke(query)"
]
},
@@ -354,18 +294,18 @@
"\n",
"You can also check out some more specific uses of tool calling:\n",
"\n",
"- Getting [structured outputs](/docs/how_to/structured_output/) from models\n",
"- Few shot prompting [with tools](/docs/how_to/tools_few_shot/)\n",
"- Stream [tool calls](/docs/how_to/tool_streaming/)\n",
"- Pass [runtime values to tools](/docs/how_to/tool_runtime)"
"- Pass [runtime values to tools](/docs/how_to/tool_runtime)\n",
"- Getting [structured outputs](/docs/how_to/structured_output/) from models"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-311",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "poetry-venv-311"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -377,7 +317,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.5"
}
},
"nbformat": 4,

View File

@@ -17,25 +17,26 @@
"source": [
"# ChatAI21\n",
"\n",
"## Overview\n",
"\n",
"This notebook covers how to get started with AI21 chat models.\n",
"Note that different chat models support different parameters. See the [AI21 documentation](https://docs.ai21.com/reference) to learn more about the parameters in your chosen model.\n",
"Note that different chat models support different parameters. See the ",
"[AI21 documentation](https://docs.ai21.com/reference) to learn more about the parameters in your chosen model.\n",
"[See all AI21's LangChain components.](https://pypi.org/project/langchain-ai21/) \n",
"\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/__package_name_short_snake__) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatAI21](https://api.python.langchain.com/en/latest/chat_models/langchain_ai21.chat_models.ChatAI21.html#langchain_ai21.chat_models.ChatAI21) | [langchain-ai21](https://api.python.langchain.com/en/latest/ai21_api_reference.html) | ❌ | beta | ✅ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-ai21?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-ai21?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"\n",
"\n",
"## Setup"
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c3bef91",
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-15T06:50:44.929635Z",
"start_time": "2024-02-15T06:50:41.209704Z"
}
},
"outputs": [],
"source": [
"!pip install -qU langchain-ai21"
]
},
{
@@ -43,9 +44,10 @@
"id": "2b4f3e15",
"metadata": {},
"source": [
"### Credentials\n",
"## Environment Setup\n",
"\n",
"We'll need to get an [AI21 API key](https://docs.ai21.com/) and set the `AI21_API_KEY` environment variable:\n"
"We'll need to get an [AI21 API key](https://docs.ai21.com/) and set the ",
"`AI21_API_KEY` environment variable:\n"
]
},
{
@@ -63,168 +65,50 @@
"os.environ[\"AI21_API_KEY\"] = getpass()"
]
},
{
"cell_type": "markdown",
"id": "f6844fff-3702-4489-ab74-732f69f3b9d7",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c2e19d3-7c58-4470-9e1a-718b27a32056",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"id": "98e22f31-8acc-42d6-916d-415d1263c56e",
"metadata": {},
"source": [
"### Installation"
]
},
{
"cell_type": "markdown",
"id": "f9699cd9-58f2-450e-aa64-799e66906c0f",
"metadata": {},
"source": [
"!pip install -qU langchain-ai21"
]
},
{
"cell_type": "markdown",
"id": "4828829d3da430ce",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
"collapsed": false
},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
"## Usage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c40756fb-cbf8-4d44-a293-3989d707237e",
"metadata": {},
"outputs": [],
"source": [
"from langchain_ai21 import ChatAI21\n",
"\n",
"llm = ChatAI21(model=\"jamba-instruct\", temperature=0)"
]
},
{
"cell_type": "markdown",
"id": "2bdc5d68-2a19-495e-8c04-d11adc86d3ae",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "46b982dc-5d8a-46da-a711-81c03ccd6adc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore programmer.\", id='run-2e8d16d6-a06e-45cb-8d0c-1c8208645033-0')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "markdown",
"id": "10a30f84-b531-4fd5-8b5b-91512fbdc75b",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"id": "39353473fce5dd2e",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', id='run-e1bd82dc-1a7e-4b2e-bde9-ac995929ac0f-0')"
"AIMessage(content='Bonjour, comment vas-tu?')"
]
},
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_ai21 import ChatAI21\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
"chat = ChatAI21(model=\"jamba-instruct\")\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" (\"system\", \"You are a helpful assistant that translates English to French.\"),\n",
" (\"human\", \"Translate this sentence from English to French. {english_text}.\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e79de691-9dd6-4697-b57e-59a4a3cc073a",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatAI21 features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_ai21.chat_models.ChatAI21.html"
"chain = prompt | chat\n",
"chain.invoke({\"english_text\": \"Hello, how are you?\"})"
]
}
],
@@ -244,7 +128,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -115,7 +115,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
@@ -123,8 +123,8 @@
"from langchain_openai import AzureChatOpenAI\n",
"\n",
"llm = AzureChatOpenAI(\n",
" azure_deployment=\"gpt-35-turbo\", # or your deployment\n",
" api_version=\"2023-06-01-preview\", # or your api version\n",
" azure_deployment=\"YOUR-DEPLOYMENT\",\n",
" api_version=\"2024-05-01-preview\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
@@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "62e0dbc3",
"metadata": {
"tags": []
@@ -152,10 +152,10 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore la programmation.\", response_metadata={'token_usage': {'completion_tokens': 8, 'prompt_tokens': 31, 'total_tokens': 39}, 'model_name': 'gpt-35-turbo', 'system_fingerprint': None, 'prompt_filter_results': [{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}], 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}, id='run-bea4b46c-e3e1-4495-9d3a-698370ad963d-0', usage_metadata={'input_tokens': 31, 'output_tokens': 8, 'total_tokens': 39})"
"AIMessage(content=\"J'adore la programmation.\", response_metadata={'token_usage': {'completion_tokens': 8, 'prompt_tokens': 31, 'total_tokens': 39}, 'model_name': 'gpt-35-turbo', 'system_fingerprint': None, 'prompt_filter_results': [{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}], 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}, id='run-a6a732c2-cb02-4e50-9a9c-ab30eab034fc-0', usage_metadata={'input_tokens': 31, 'output_tokens': 8, 'total_tokens': 39})"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -174,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [
@@ -202,17 +202,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 26, 'total_tokens': 32}, 'model_name': 'gpt-35-turbo', 'system_fingerprint': None, 'prompt_filter_results': [{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}], 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}, id='run-cbc44038-09d3-40d4-9da2-c5910ee636ca-0', usage_metadata={'input_tokens': 26, 'output_tokens': 6, 'total_tokens': 32})"
"AIMessage(content='Ich liebe das Programmieren.', response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 26, 'total_tokens': 32}, 'model_name': 'gpt-35-turbo', 'system_fingerprint': None, 'prompt_filter_results': [{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}], 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}, id='run-084967d7-06f2-441f-b5c1-477e2a9e9d03-0', usage_metadata={'input_tokens': 26, 'output_tokens': 6, 'total_tokens': 32})"
]
},
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -264,8 +264,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2ca02d23-60d0-43eb-8d04-070f61f8fefd",
"execution_count": 5,
"id": "84c411b0-1790-4798-8bb7-47d8ece4c2dc",
"metadata": {},
"outputs": [
{
@@ -288,22 +288,22 @@
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e1b07ae2-3de7-44bd-bfdc-b76f4ba45a35",
"execution_count": 6,
"id": "21234693-d92b-4d69-8a7f-55aa062084bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Cost (USD): $0.000074\n"
"Total Cost (USD): $0.000078\n"
]
}
],
"source": [
"llm_0301 = AzureChatOpenAI(\n",
" azure_deployment=\"gpt-35-turbo\", # or your deployment\n",
" api_version=\"2023-06-01-preview\", # or your api version\n",
" azure_deployment=\"YOUR-DEPLOYMENT\",\n",
" api_version=\"2024-05-01-preview\",\n",
" model_version=\"0301\",\n",
")\n",
"with get_openai_callback() as cb:\n",
@@ -338,7 +338,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.11.9"
}
},
"nbformat": 4,

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "raw",
"id": "53fbf15f",
"id": "afaf8039",
"metadata": {},
"source": [
"---\n",
@@ -12,103 +12,129 @@
},
{
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# Cohere\n",
"# ChatCohere\n",
"\n",
"This notebook covers how to get started with [Cohere chat models](https://cohere.com/chat).\n",
"This doc will help you get started with Cohere [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatCohere features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_cohere.chat_models.ChatCohere.html).\n",
"\n",
"For an overview of all Cohere models head to the [Cohere docs](https://docs.cohere.com/docs/models).\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/cohere) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatCohere](https://api.python.langchain.com/en/latest/chat_models/langchain_cohere.chat_models.ChatCohere.html) | [langchain-cohere](https://api.python.langchain.com/en/latest/cohere_api_reference.html) | ❌ | beta | ✅ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-cohere?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-cohere?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | \n",
"\n",
"Head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.cohere.ChatCohere.html) for detailed documentation of all attributes and methods."
]
},
{
"cell_type": "markdown",
"id": "3607d67e-e56c-4102-bbba-df2edc0e109e",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"The integration lives in the `langchain-cohere` package. We can install these with:\n",
"To access Cohere models you'll need to create a Cohere account, get an API key, and install the `langchain-cohere` integration package.\n",
"\n",
"```bash\n",
"pip install -U langchain-cohere\n",
"```\n",
"### Credentials\n",
"\n",
"We'll also need to get a [Cohere API key](https://cohere.com/) and set the `COHERE_API_KEY` environment variable:"
"Head to https://dashboard.cohere.com/welcome/login to sign up to Cohere and generate an API key. Once you've done this set the COHERE_API_KEY environment variable:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
"execution_count": null,
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass(\"Enter your Cohere API key: \")"
]
},
{
"cell_type": "markdown",
"id": "cf690fbb",
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
"metadata": {},
"source": [
"It's also helpful (but not needed) to set up [LangSmith](https://smith.langchain.com/) for best-in-class observability"
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7f11de02",
"execution_count": null,
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "4c26754b-b3c9-4d93-8f36-43049bd943bf",
"id": "0730d6a1-c893-4840-9817-5e5251676d5d",
"metadata": {},
"source": [
"## Usage\n",
"### Installation\n",
"\n",
"ChatCohere supports all [ChatModel](/docs/how_to#chat-models) functionality:"
"The LangChain Cohere integration lives in the `langchain-cohere` package:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"execution_count": null,
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-cohere"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_cohere import ChatCohere\n",
"from langchain_core.messages import HumanMessage"
"\n",
"llm = ChatCohere(\n",
" model=\"command-r-plus\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"cell_type": "markdown",
"id": "2b4f3e15",
"metadata": {},
"source": [
"chat = ChatCohere()"
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"execution_count": 2,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
@@ -116,223 +142,110 @@
{
"data": {
"text/plain": [
"AIMessage(content='4 && 5 \\n6 || 7 \\n\\nWould you like to play a game of odds and evens?', additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '2076b614-52b3-4082-a259-cc92cd3d9fea', 'token_count': {'prompt_tokens': 68, 'response_tokens': 23, 'total_tokens': 91, 'billed_tokens': 77}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '2076b614-52b3-4082-a259-cc92cd3d9fea', 'token_count': {'prompt_tokens': 68, 'response_tokens': 23, 'total_tokens': 91, 'billed_tokens': 77}}, id='run-3475e0c8-c89b-4937-9300-e07d652455e1-0')"
"AIMessage(content=\"J'adore programmer.\", additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': 'd84f80f3-4611-46e6-aed0-9d8665a20a11', 'token_count': {'input_tokens': 89, 'output_tokens': 5}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': 'd84f80f3-4611-46e6-aed0-9d8665a20a11', 'token_count': {'input_tokens': 89, 'output_tokens': 5}}, id='run-514ab516-ed7e-48ac-b132-2598fb80ebef-0')"
]
},
"execution_count": 15,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
"chat.invoke(messages)"
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='4 && 5', additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': 'f0708a92-f874-46ee-9b93-334d616ad92e', 'token_count': {'prompt_tokens': 68, 'response_tokens': 3, 'total_tokens': 71, 'billed_tokens': 57}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': 'f0708a92-f874-46ee-9b93-334d616ad92e', 'token_count': {'prompt_tokens': 68, 'response_tokens': 3, 'total_tokens': 71, 'billed_tokens': 57}}, id='run-1635e63e-2994-4e7f-986e-152ddfc95777-0')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.ainvoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
},
"execution_count": 3,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4 && 5"
"J'adore programmer.\n"
]
}
],
"source": [
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "064288e4-f184-4496-9427-bcf148fa055e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content='4 && 5', additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '6770ca86-f6c3-4ba3-a285-c4772160612f', 'token_count': {'prompt_tokens': 68, 'response_tokens': 3, 'total_tokens': 71, 'billed_tokens': 57}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '6770ca86-f6c3-4ba3-a285-c4772160612f', 'token_count': {'prompt_tokens': 68, 'response_tokens': 3, 'total_tokens': 71, 'billed_tokens': 57}}, id='run-8d6fade2-1b39-4e31-ab23-4be622dd0027-0')]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat.batch([messages])"
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"id": "f1c56460",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"You can also easily combine with a prompt template for easy structuring of user input. We can do this using [LCEL](/docs/concepts#langchain-expression-language-lcel)"
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0851b103",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
"chain = prompt | chat"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ae950c0f-1691-47f1-b609-273033cae707",
"execution_count": 4,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='What color socks do bears wear?\\n\\nThey dont wear socks, they have bear feet. \\n\\nHope you laughed! If not, maybe this will help: laughter is the best medicine, and a good sense of humor is infectious!', additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '6edccf44-9bc8-4139-b30e-13b368f3563c', 'token_count': {'prompt_tokens': 68, 'response_tokens': 51, 'total_tokens': 119, 'billed_tokens': 108}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '6edccf44-9bc8-4139-b30e-13b368f3563c', 'token_count': {'prompt_tokens': 68, 'response_tokens': 51, 'total_tokens': 119, 'billed_tokens': 108}}, id='run-ef7f9789-0d4d-43bf-a4f7-f2a0e27a5320-0')"
"AIMessage(content='Ich liebe Programmierung.', additional_kwargs={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '053bebde-4e1d-4d06-8ee6-3446e7afa25e', 'token_count': {'input_tokens': 84, 'output_tokens': 6}}, response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '053bebde-4e1d-4d06-8ee6-3446e7afa25e', 'token_count': {'input_tokens': 84, 'output_tokens': 6}}, id='run-53700708-b7fb-417b-af36-1a6fcde38e7d-0')"
]
},
"execution_count": 20,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"topic\": \"bears\"})"
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "12db8d69",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## Tool calling\n",
"## API reference\n",
"\n",
"Cohere supports tool calling functionalities!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "337e24af",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import (\n",
" HumanMessage,\n",
" ToolMessage,\n",
")\n",
"from langchain_core.tools import tool"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "74d292e7",
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"def magic_function(number: int) -> int:\n",
" \"\"\"Applies a magic operation to an integer\n",
" Args:\n",
" number: Number to have magic operation performed on\n",
" \"\"\"\n",
" return number + 10\n",
"\n",
"\n",
"def invoke_tools(tool_calls, messages):\n",
" for tool_call in tool_calls:\n",
" selected_tool = {\"magic_function\": magic_function}[tool_call[\"name\"].lower()]\n",
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
" return messages\n",
"\n",
"\n",
"tools = [magic_function]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ecafcbc6",
"metadata": {},
"outputs": [],
"source": [
"llm_with_tools = chat.bind_tools(tools=tools)\n",
"messages = [HumanMessage(content=\"What is the value of magic_function(2)?\")]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "aa34fc39",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The value of magic_function(2) is 12.', additional_kwargs={'documents': [{'id': 'magic_function:0:2:0', 'output': '12', 'tool_name': 'magic_function'}], 'citations': [ChatCitation(start=34, end=36, text='12', document_ids=['magic_function:0:2:0'])], 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '96a55791-0c58-4e2e-bc2a-8550e137c46d', 'token_count': {'input_tokens': 998, 'output_tokens': 59}}, response_metadata={'documents': [{'id': 'magic_function:0:2:0', 'output': '12', 'tool_name': 'magic_function'}], 'citations': [ChatCitation(start=34, end=36, text='12', document_ids=['magic_function:0:2:0'])], 'search_results': None, 'search_queries': None, 'is_search_required': None, 'generation_id': '96a55791-0c58-4e2e-bc2a-8550e137c46d', 'token_count': {'input_tokens': 998, 'output_tokens': 59}}, id='run-f318a9cf-55c8-44f4-91d1-27cf46c6a465-0')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res = llm_with_tools.invoke(messages)\n",
"while res.tool_calls:\n",
" messages.append(res)\n",
" messages = invoke_tools(res.tool_calls, messages)\n",
" res = llm_with_tools.invoke(messages)\n",
"\n",
"res"
"For detailed documentation of all ChatCohere features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_cohere.chat_models.ChatCohere.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "poetry-venv-2",
"language": "python",
"name": "python3"
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
@@ -344,7 +257,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.9"
}
},
"nbformat": 4,

View File

@@ -4,67 +4,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Hugging Face\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatHuggingFace\n",
"# Hugging Face\n",
"\n",
"## Overview\n",
"\n",
"This notebook shows how to get started using Hugging Face LLMs as chat models.\n",
"This notebook shows how to get started using `Hugging Face` LLM's as chat models.\n",
"\n",
"In particular, we will:\n",
"1. Utilize the [HuggingFaceEndpoint](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_endpoint.py) integrations to instantiate an LLM.\n",
"1. Utilize the [HuggingFaceEndpoint](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/huggingface_endpoint.py) integrations to instantiate an `LLM`.\n",
"2. Utilize the `ChatHuggingFace` class to enable any of these LLMs to interface with LangChain's [Chat Messages](/docs/concepts/#message-types) abstraction.\n",
"3. Explore tool calling with the `ChatHuggingFace`.\n",
"4. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n",
"\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatHuggingFace](https://api.python.langchain.com/en/latest/chat_models/langchain_huggingface.chat_models.huggingface.ChatHuggingFace.html) | [langchain-huggingface](https://api.python.langchain.com/en/latest/huggingface_api_reference.html) | ✅ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_huggingface?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_huggingface?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"To access Hugging Face models you'll need to create a Hugging Face account, get an API key, and install the `langchain-huggingface` integration package.\n",
"\n",
"### Credentials\n",
"\n",
"Generate a [Hugging Face Access Token](https://huggingface.co/docs/hub/security-tokens) and store it as an environment variable: `HUGGINGFACEHUB_API_TOKEN`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.getenv(\"HUGGINGFACEHUB_API_TOKEN\"):\n",
" os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = getpass.getpass(\"Enter your token: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"Below we install additional packages as well for demonstration purposes:"
"> Note: To get started, you'll need to have a [Hugging Face Access Token](https://huggingface.co/docs/hub/security-tokens) saved as an environment variable: `HUGGINGFACEHUB_API_TOKEN`."
]
},
{
@@ -80,7 +31,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation"
"## 1. Instantiate an LLM"
]
},
{
@@ -167,7 +118,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation"
"## 2. Instantiate the `ChatHuggingFace` to apply chat templates"
]
},
{
@@ -298,44 +249,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool calling with `ChatHuggingFace`\n",
"## 3. Explore the tool calling with `ChatHuggingFace`\n",
"\n",
"`text-generation-inference` supports tool with open source LLMs starting from v2.0.1"
]
@@ -399,7 +313,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use with agents\n",
"## 4. Take it for a spin as an agent!\n",
"\n",
"Here we'll test out `Zephyr-7B-beta` as a zero-shot `ReAct` Agent. \n",
"\n",
@@ -544,15 +458,6 @@
"\n",
"It's exciting to see how far open-source LLM's can go as general purpose reasoning agents. Give it a try yourself!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatHuggingFace features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_huggingface.chat_models.huggingface.ChatHuggingFace.html"
]
}
],
"metadata": {
@@ -571,7 +476,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -12,87 +12,43 @@
},
{
"cell_type": "markdown",
"id": "a14c83bf-af26-4f22-8c1a-d632c5795ecf",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# MistralAI\n",
"\n",
"This will help you getting started with Mistral [chat models](/docs/concepts/#chat-models), accessed via their [API](https://docs.mistral.ai/api/). For detailed documentation of all ChatMistralAI features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_mistralai.chat_models.ChatMistralAI.html).\n",
"This notebook covers how to get started with MistralAI chat models, via their [API](https://docs.mistral.ai/api/).\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/mistral) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatMistralAI](https://api.python.langchain.com/en/latest/chat_models/langchain_mistralai.chat_models.ChatMistralAI.html) | [langchain_mistralai](https://api.python.langchain.com/en/latest/mistralai_api_reference.html) | ❌ | beta | ✅ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_mistralai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_mistralai?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"A valid [API key](https://console.mistral.ai/users/api-keys/) is needed to communicate with the API.\n",
"\n",
"Head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_mistralai.chat_models.ChatMistralAI.html) for detailed documentation of all attributes and methods."
]
},
{
"cell_type": "markdown",
"id": "cc686b8f",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"To access Mistral models you'll need to create a Mistral account, get an API key, and install the `langchain-mistralai` integration package.\n",
"You will need the `langchain-core` and `langchain-mistralai` package to use the API. You can install these with:\n",
"\n",
"### Credentials\n",
"```bash\n",
"pip install -U langchain-core langchain-mistralai\n",
"\n",
"A valid [API key](https://console.mistral.ai/users/api-keys/) is needed to communicate with the API. Once you've obtained an API key, store it in the `MISTRAL_API_KEY` environment variable:"
"We'll also need to get a [Mistral API key](https://console.mistral.ai/users/api-keys/)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9acd8340-09d4-4ece-871a-a35b0732c7d8",
"execution_count": 7,
"id": "c3fd4184",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.getenv(\"__MODULE_NAME___API_KEY\"):\n",
" os.environ[\"__MODULE_NAME___API_KEY\"] = getpass.getpass(\n",
" \"Enter your __ModuleName__ API key: \"\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "42c979b1-df49-4f6c-9fe6-d9dbf3ea8c2a",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc4f11ec-5cb3-4caf-b3cd-7a20c41b0cfe",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"id": "0fc42221-97b2-466b-95db-10368e17ca56",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain MistralAI integration lives in the `langchain-mistralai` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85cb1ab8-9f2c-4b93-8415-ad65819dcb38",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-mistralai"
"api_key = getpass.getpass()"
]
},
{
@@ -100,76 +56,57 @@
"id": "502127fd",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2dfa801a-d040-4c09-9634-58604e8eaf16",
"metadata": {},
"outputs": [],
"source": [
"from langchain_mistralai.chat_models import ChatMistralAI\n",
"\n",
"llm = ChatMistralAI(model=\"mistral-large-latest\")"
]
},
{
"cell_type": "markdown",
"id": "f668acff-eb14-4b3a-959a-df5bfc02968b",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "86e3f9e6-67ec-4fbf-8ff1-85331200f412",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore la programmation.\", response_metadata={'token_usage': {'prompt_tokens': 27, 'total_tokens': 36, 'completion_tokens': 9}, 'model': 'mistral-large-latest', 'finish_reason': 'stop'}, id='run-d6196c33-9410-413b-b454-4ed0bec1f0c7-0', usage_metadata={'input_tokens': 27, 'output_tokens': 9, 'total_tokens': 36})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
"## Usage"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f8a24bc-b7f0-4d3a-b310-8a4e0ba125dd",
"metadata": {},
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_core.messages import HumanMessage\n",
"from langchain_mistralai.chat_models import ChatMistralAI"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
"chat = ChatMistralAI(api_key=api_key)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"J'adore la programmation.\n"
]
"data": {
"text/plain": [
"AIMessage(content=\"Who's there? I was just about to ask the same thing! How can I assist you today?\")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(ai_msg.content)"
"messages = [HumanMessage(content=\"knock knock\")]\n",
"chat.invoke(messages)"
]
},
{
@@ -182,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
@@ -191,16 +128,16 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"J'aime programmer.\", response_metadata={'token_usage': {'prompt_tokens': 27, 'total_tokens': 34, 'completion_tokens': 7}, 'model': 'mistral-large-latest', 'finish_reason': 'stop'}, id='run-1873888a-186f-49a8-ab81-24335bd3099b-0', usage_metadata={'input_tokens': 27, 'output_tokens': 7, 'total_tokens': 34})"
"AIMessage(content='Who\\'s there?\\n\\n(You can then continue the \"knock knock\" joke by saying the name of the person or character who should be responding. For example, if I say \"Banana,\" you could respond with \"Banana who?\" and I would say \"Banana bunch! Get it? Because a group of bananas is called a \\'bunch\\'!\" and then we would both laugh and have a great time. But really, you can put anything you want in the spot where I put \"Banana\" and it will still technically be a \"knock knock\" joke. The possibilities are endless!)')"
]
},
"execution_count": 4,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.ainvoke(messages)"
"await chat.ainvoke(messages)"
]
},
{
@@ -213,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
@@ -223,12 +160,32 @@
"name": "stdout",
"output_type": "stream",
"text": [
"J'adore programmer."
"Who's there?\n",
"\n",
"(After this, the conversation can continue as a call and response \"who's there\" joke. Here is an example of how it could go:\n",
"\n",
"You say: Orange.\n",
"I say: Orange who?\n",
"You say: Orange you glad I didn't say banana!?)\n",
"\n",
"But since you asked for a knock knock joke specifically, here's one for you:\n",
"\n",
"Knock knock.\n",
"\n",
"Me: Who's there?\n",
"\n",
"You: Lettuce.\n",
"\n",
"Me: Lettuce who?\n",
"\n",
"You: Lettuce in, it's too cold out here!\n",
"\n",
"I hope this brings a smile to your face! Do you have a favorite knock knock joke you'd like to share? I'd love to hear it."
]
}
],
"source": [
"for chunk in llm.stream(messages):\n",
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\")"
]
},
@@ -242,23 +199,23 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"id": "e63aebcb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=\"J'adore la programmation.\", response_metadata={'token_usage': {'prompt_tokens': 27, 'total_tokens': 36, 'completion_tokens': 9}, 'model': 'mistral-large-latest', 'finish_reason': 'stop'}, id='run-2aa2a189-c405-4cf5-bd31-e9025e4c8536-0', usage_metadata={'input_tokens': 27, 'output_tokens': 9, 'total_tokens': 36})]"
"[AIMessage(content=\"Who's there? I was just about to ask the same thing! Go ahead and tell me who's there. I love a good knock-knock joke.\")]"
]
},
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.batch([messages])"
"chat.batch([messages])"
]
},
{
@@ -273,52 +230,36 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"id": "ee43a1ae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
"chain = prompt | chat"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0dc49212",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe Programmieren.', response_metadata={'token_usage': {'prompt_tokens': 21, 'total_tokens': 28, 'completion_tokens': 7}, 'model': 'mistral-large-latest', 'finish_reason': 'stop'}, id='run-409ebc9a-b4a0-4734-ab6f-e11f6b4f808f-0', usage_metadata={'input_tokens': 21, 'output_tokens': 7, 'total_tokens': 28})"
"AIMessage(content='Why do bears hate shoes so much? They like to run around in their bear feet.')"
]
},
"execution_count": 7,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "eb7e01fb-a433-48b1-a4c2-e6009523a896",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatMistralAI features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_mistralai.chat_models.ChatMistralAI.html"
"chain.invoke({\"topic\": \"bears\"})"
]
}
],
@@ -338,7 +279,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -2,24 +2,13 @@
"cells": [
{
"cell_type": "markdown",
"id": "1f666798-8635-4bc0-a515-04d318588d67",
"metadata": {},
"id": "cc6caafa",
"metadata": {
"id": "cc6caafa"
},
"source": [
"---\n",
"sidebar_label: NVIDIA AI Endpoints\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "fa8eb20e-4db8-45e3-9e79-c595f4f274da",
"metadata": {},
"source": [
"# ChatNVIDIA\n",
"# NVIDIA NIMs\n",
"\n",
"This will help you getting started with NVIDIA [chat models](/docs/concepts/#chat-models). For detailed documentation of all `ChatNVIDIA` features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_nvidia_ai_endpoints.chat_models.ChatNVIDIA.html).\n",
"\n",
"## Overview\n",
"The `langchain-nvidia-ai-endpoints` package contains LangChain integrations building applications with models on \n",
"NVIDIA NIM inference microservice. NIM supports models across domains like chat, embedding, and re-ranking models \n",
"from the community as well as NVIDIA. These models are optimized by NVIDIA to deliver the best performance on NVIDIA \n",
@@ -35,66 +24,7 @@
"\n",
"This example goes over how to use LangChain to interact with NVIDIA supported via the `ChatNVIDIA` class.\n",
"\n",
"For more information on accessing the chat models through this api, check out the [ChatNVIDIA](https://python.langchain.com/docs/integrations/chat/nvidia_ai_endpoints/) documentation.\n",
"\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatNVIDIA](https://api.python.langchain.com/en/latest/chat_models/langchain_nvidia_ai_endpoints.chat_models.ChatNVIDIA.html) | [langchain_nvidia_ai_endpoints](https://api.python.langchain.com/en/latest/nvidia_ai_endpoints_api_reference.html) | ✅ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_nvidia_ai_endpoints?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_nvidia_ai_endpoints?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"**To get started:**\n",
"\n",
"1. Create a free account with [NVIDIA](https://build.nvidia.com/), which hosts NVIDIA AI Foundation models.\n",
"\n",
"2. Click on your model of choice.\n",
"\n",
"3. Under `Input` select the `Python` tab, and click `Get API Key`. Then click `Generate Key`.\n",
"\n",
"4. Copy and save the generated key as `NVIDIA_API_KEY`. From there, you should have access to the endpoints.\n",
"\n",
"### Credentials\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "208b72da-1535-4249-bbd3-2500028e25e9",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.getenv(\"NVIDIA_API_KEY\"):\n",
" # Note: the API key should start with \"nvapi-\"\n",
" os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")"
]
},
{
"cell_type": "markdown",
"id": "52dc8dcb-0a48-4a4e-9947-764116d2ffd4",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cd9cb12-6ca5-432a-9e42-8a57da073c7e",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
"For more information on accessing the chat models through this api, check out the [ChatNVIDIA](https://python.langchain.com/docs/integrations/chat/nvidia_ai_endpoints/) documentation."
]
},
{
@@ -102,9 +32,7 @@
"id": "f2be90a9",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain NVIDIA AI Endpoints integration lives in the `langchain_nvidia_ai_endpoints` package:"
"## Installation"
]
},
{
@@ -117,14 +45,51 @@
"%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
]
},
{
"cell_type": "markdown",
"id": "ccff689e",
"metadata": {
"id": "ccff689e"
},
"source": [
"## Setup\n",
"\n",
"**To get started:**\n",
"\n",
"1. Create a free account with [NVIDIA](https://build.nvidia.com/), which hosts NVIDIA AI Foundation models.\n",
"\n",
"2. Click on your model of choice.\n",
"\n",
"3. Under `Input` select the `Python` tab, and click `Get API Key`. Then click `Generate Key`.\n",
"\n",
"4. Copy and save the generated key as `NVIDIA_API_KEY`. From there, you should have access to the endpoints."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "686c4d2f",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"# del os.environ['NVIDIA_API_KEY'] ## delete key and reset\n",
"if os.environ.get(\"NVIDIA_API_KEY\", \"\").startswith(\"nvapi-\"):\n",
" print(\"Valid NVIDIA_API_KEY already in environment. Delete to reset\")\n",
"else:\n",
" nvapi_key = getpass.getpass(\"NVAPI Key (starts with nvapi-): \")\n",
" assert nvapi_key.startswith(\"nvapi-\"), f\"{nvapi_key[:5]}... is not a valid key\"\n",
" os.environ[\"NVIDIA_API_KEY\"] = nvapi_key"
]
},
{
"cell_type": "markdown",
"id": "af0ce26b",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can access models in the NVIDIA API Catalog:"
"## Working with NVIDIA API Catalog"
]
},
{
@@ -143,24 +108,7 @@
"## Core LC Chat Interface\n",
"from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
"\n",
"llm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")"
]
},
{
"cell_type": "markdown",
"id": "469c8c7f-de62-457f-a30f-674763a8b717",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9512c81b-1f3a-4eca-9470-f52cedff5c74",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\")\n",
"result = llm.invoke(\"Write a ballad about LangChain.\")\n",
"print(result.content)"
]
@@ -682,55 +630,6 @@
"source": [
"See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples."
]
},
{
"cell_type": "markdown",
"id": "a9a3c438-121d-46eb-8fb5-b8d5a13cd4a4",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af585c6b-fe0a-4833-9860-a4209a71b3c6",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f2f25dd3-0b4a-465f-a53e-95521cdc253c",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all `ChatNVIDIA` features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_nvidia_ai_endpoints.chat_models.ChatNVIDIA.html"
]
}
],
"metadata": {
@@ -752,7 +651,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.13"
}
},
"nbformat": 4,

View File

@@ -12,83 +12,14 @@
},
{
"cell_type": "markdown",
"id": "8f82e243-f4ee-44e2-b417-099b6401ae3e",
"id": "eb7e5679-aa06-47e4-a1a3-b6b70e604017",
"metadata": {},
"source": [
"# vLLM Chat\n",
"\n",
"vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. This server can be queried in the same format as OpenAI API.\n",
"\n",
"## Overview\n",
"This will help you getting started with vLLM [chat models](/docs/concepts/#chat-models), which leverage the `langchain-openai` package. For detailed documentation of all `ChatOpenAI` features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html).\n",
"\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatOpenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html) | [langchain_openai](https://api.python.langchain.com/en/latest/langchain_openai.html) | ✅ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_openai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_openai?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"Specific model features-- such as tool calling, support for multi-modal inputs, support for token-level streaming, etc.-- will depend on the hosted model.\n",
"\n",
"## Setup\n",
"\n",
"See the vLLM docs [here](https://docs.vllm.ai/en/latest/).\n",
"\n",
"To access vLLM models through LangChain, you'll need to install the `langchain-openai` integration package.\n",
"\n",
"### Credentials\n",
"\n",
"Authentication will depend on specifics of the inference server."
]
},
{
"cell_type": "markdown",
"id": "c3b1707a-cf2c-4367-94e3-436c43402503",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e40bd5e-cbaa-41ef-aaf9-0858eb207184",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"id": "0739b647-609b-46d3-bdd3-e86fe4463288",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain vLLM integration can be accessed via the `langchain-openai` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7afcfbdc-56aa-4529-825a-8acbe7aa5241",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-openai"
]
},
{
"cell_type": "markdown",
"id": "2cf576d6-7b67-4937-bf99-39071e85720c",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
"This notebook covers how to get started with vLLM chat models using langchain's `ChatOpenAI` **as it is**."
]
},
{
@@ -120,7 +51,7 @@
"source": [
"inference_server_url = \"http://localhost:8000/v1\"\n",
"\n",
"llm = ChatOpenAI(\n",
"chat = ChatOpenAI(\n",
" model=\"mosaicml/mpt-7b\",\n",
" openai_api_key=\"EMPTY\",\n",
" openai_api_base=inference_server_url,\n",
@@ -129,14 +60,6 @@
")"
]
},
{
"cell_type": "markdown",
"id": "34b18328-5e8b-4ff2-9b89-6fbb76b5c7f0",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 15,
@@ -165,66 +88,82 @@
" content=\"Translate the following sentence from English to Italian: I love programming.\"\n",
" ),\n",
"]\n",
"llm.invoke(messages)"
"chat(messages)"
]
},
{
"cell_type": "markdown",
"id": "a580a1e4-11a3-4277-bfba-bfb414ac7201",
"id": "55fc7046-a6dc-4720-8c0c-24a6db76a4f4",
"metadata": {},
"source": [
"## Chaining\n",
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use ChatPromptTemplate's format_prompt -- this returns a `PromptValue`, which you can convert to a string or `Message` object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "123980e9-0dee-4ce5-bde6-d964dd90129c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"template = (\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
")\n",
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
"human_template = \"{text}\"\n",
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b2fb8c59-8892-4270-85a2-4f8ab276b75d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' I love programming too.', additional_kwargs={}, example=False)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_prompt = ChatPromptTemplate.from_messages(\n",
" [system_message_prompt, human_message_prompt]\n",
")\n",
"\n",
"# get a chat completion from the formatted messages\n",
"chat(\n",
" chat_prompt.format_prompt(\n",
" input_language=\"English\", output_language=\"Italian\", text=\"I love programming.\"\n",
" ).to_messages()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd0f4043-48bd-4245-8bdb-e7669666a277",
"id": "0bbd9861-2b94-4920-8708-b690004f4c4d",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "265f5d51-0a76-4808-8d13-ef598ee6e366",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all features and configurations exposed via `langchain-openai`, head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html\n",
"\n",
"Refer to the vLLM [documentation](https://docs.vllm.ai/en/latest/) as well."
]
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "conda_pytorch_p310",
"language": "python",
"name": "python3"
"name": "conda_pytorch_p310"
},
"language_info": {
"codemirror_mode": {
@@ -236,7 +175,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -108,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Cohere(max_tokens=256, temperature=0.75)"
"model = Cohere(model=\"command\", max_tokens=256, temperature=0.75)"
]
},
{

View File

@@ -46,55 +46,6 @@ print(llm.invoke("Come up with a pet name"))
```
Usage of the Cohere (legacy) [LLM model](/docs/integrations/llms/cohere)
### Tool calling
```python
from langchain_cohere import ChatCohere
from langchain_core.messages import (
HumanMessage,
ToolMessage,
)
from langchain_core.tools import tool
@tool
def magic_function(number: int) -> int:
"""Applies a magic operation to an integer
Args:
number: Number to have magic operation performed on
"""
return number + 10
def invoke_tools(tool_calls, messages):
for tool_call in tool_calls:
selected_tool = {"magic_function":magic_function}[
tool_call["name"].lower()
]
tool_output = selected_tool.invoke(tool_call["args"])
messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))
return messages
tools = [magic_function]
llm = ChatCohere()
llm_with_tools = llm.bind_tools(tools=tools)
messages = [
HumanMessage(
content="What is the value of magic_function(2)?"
)
]
res = llm_with_tools.invoke(messages)
while res.tool_calls:
messages.append(res)
messages = invoke_tools(res.tool_calls, messages)
res = llm_with_tools.invoke(messages)
print(res.content)
```
Tool calling with Cohere LLM can be done by binding the necessary tools to the llm as seen above.
An alternative, is to support multi hop tool calling with the ReAct agent as seen below.
### ReAct Agent
The agent is based on the paper
@@ -126,7 +77,6 @@ agent_executor.invoke({
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
})
```
The ReAct agent can be used to call multiple tools in sequence.
### RAG Retriever

View File

@@ -34,7 +34,8 @@
},
"outputs": [],
"source": [
"from langchain_cohere import ChatCohere, CohereRagRetriever\n",
"from langchain_cohere import ChatCohere\n",
"from langchain_community.retrievers import CohereRagRetriever\n",
"from langchain_core.documents import Document"
]
},
@@ -199,7 +200,7 @@
"source": [
"docs = rag.invoke(\n",
" \"Does langchain support cohere RAG?\",\n",
" documents=[\n",
" source_documents=[\n",
" Document(page_content=\"Langchain supports cohere RAG!\"),\n",
" Document(page_content=\"The sky is blue!\"),\n",
" ],\n",
@@ -207,14 +208,6 @@
"_pretty_print(docs)"
]
},
{
"cell_type": "markdown",
"id": "45a9470f",
"metadata": {},
"source": [
"Please note that connectors and documents cannot be used simultaneously. If you choose to provide documents in the `invoke` method, they will take precedence, and connectors will not be utilized for that particular request, as shown in the snippet above!"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -1,27 +0,0 @@
---
sidebar_position: 0
sidebar_class_name: hidden
---
# Retrievers
A **retriever** is an interface that returns documents given an unstructured query.
It is more general than a vector store.
A retriever does not need to be able to store documents, only to return (or retrieve) them.
Retrievers can be created from vector stores, but are also broad enough to include [Wikipedia search](/docs/integrations/retrievers/wikipedia/) and [Amazon Kendra](/docs/integrations/retrievers/amazon_kendra_retriever/).
Retrievers accept a string query as input and return a list of Document's as output.
For specifics on how to use retrievers, see the [relevant how-to guides here](/docs/how_to/#retrievers).
This table lists common retrievers.
| Retriever | Namespace | Native async | Local |
|-----------|-----------|---------------|------|
| [AmazonKnowledgeBasesRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_aws.retrievers.bedrock.AmazonKnowledgeBasesRetriever.html) | langchain_aws.retrievers | ❌ | ❌ |
| [AzureAISearchRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_community.retrievers.azure_ai_search.AzureAISearchRetriever.html) | langchain_community.retrievers | ✅ | ❌ |
| [ElasticsearchRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_elasticsearch.retrievers.ElasticsearchRetriever.html) | langchain_elasticsearch | ❌ | ❌ |
| [MilvusCollectionHybridSearchRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_milvus.retrievers.milvus_hybrid_search.MilvusCollectionHybridSearchRetriever.html) | langchain_milvus | ❌ | ❌ |
| [TavilySearchAPIRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_community.retrievers.tavily_search_api.TavilySearchAPIRetriever.html) | langchain_community.retrievers | ❌ | ❌ |
| [VertexAISearchRetriever](https://api.python.langchain.com/en/latest/vertex_ai_search/langchain_google_community.vertex_ai_search.VertexAISearchRetriever.html) | langchain_google_community.vertex_ai_search | ❌ | ❌ |

View File

@@ -4,70 +4,20 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"sidebar_label: TavilySearchAPI\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TavilySearchAPIRetriever\n",
"# Tavily Search API\n",
"\n",
"## Overview\n",
">[Tavily's Search API](https://tavily.com) is a search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.\n",
"\n",
"We can use this as a [retriever](/docs/how_to#retrievers). It will show functionality specific to this integration. After going through, it may be useful to explore [relevant use-case pages](/docs/how_to#qa-with-rag) to learn how to use this vectorstore as part of a larger chain.\n",
"\n",
"### Integration details\n",
"## Setup\n",
"\n",
"| Retriever | Namespace | Native async | Local |\n",
"| :--- | :--- | :---: | :---: |\n",
"[TavilySearchAPIRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain_community.retrievers.tavily_search_api.TavilySearchAPIRetriever.html) | langchain_community.retrievers | ❌ | ❌ |\n",
"The integration lives in the `langchain-community` package. We also need to install the `tavily-python` package itself.\n",
"\n",
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to get automated tracing from runs of individual tools, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"```bash\n",
"pip install -U langchain-community tavily-python\n",
"```\n",
"\n",
"The integration lives in the `langchain-community` package. We also need to install the `tavily-python` package itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community tavily-python"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also need to set our Tavily API key."
]
},
@@ -87,20 +37,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our retriever:"
"It's also helpful (but not needed) to set up [LangSmith](https://smith.langchain.com/) for best-in-class observability"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.retrievers import TavilySearchAPIRetriever\n",
"\n",
"retriever = TavilySearchAPIRetriever(k=3)"
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
]
},
{
@@ -112,40 +59,42 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(metadata={'title': 'The Legend of Zelda: Breath of the Wild - Nintendo Switch Wiki', 'source': 'https://nintendo-switch.fandom.com/wiki/The_Legend_of_Zelda:_Breath_of_the_Wild', 'score': 0.9961155, 'images': []}, page_content='The Legend of Zelda: Breath of the Wild is an open world action-adventure game published by Nintendo for the Wii U and as a launch title for the Nintendo Switch, and was released worldwide on March 3, 2017. It is the nineteenth installment of the The Legend of Zelda series and the first to be developed with a HD resolution. The game features a gigantic open world, with the player being able to ...'),\n",
" Document(metadata={'title': 'The Legend of Zelda: Breath of the Wild - Zelda Wiki', 'source': 'https://zelda.fandom.com/wiki/The_Legend_of_Zelda:_Breath_of_the_Wild', 'score': 0.9804313, 'images': []}, page_content='[]\\nReferences\\nThe Legend of Zelda \\xa0·\\nThe Adventure of Link \\xa0·\\nA Link to the Past (& Four Swords) \\xa0·\\nLink\\'s Awakening (DX; Nintendo Switch) \\xa0·\\nOcarina of Time (Master Quest; 3D) \\xa0·\\nMajora\\'s Mask (3D) \\xa0·\\nOracle of Ages \\xa0·\\nOracle of Seasons \\xa0·\\nFour Swords (Anniversary Edition) \\xa0·\\nThe Wind Waker (HD) \\xa0·\\nFour Swords Adventures \\xa0·\\nThe Minish Cap \\xa0·\\nTwilight Princess (HD) \\xa0·\\nPhantom Hourglass \\xa0·\\nSpirit Tracks \\xa0·\\nSkyward Sword (HD) \\xa0·\\nA Link Between Worlds \\xa0·\\nTri Force Heroes \\xa0·\\nBreath of the Wild \\xa0·\\nTears of the Kingdom\\nZelda (Game & Watch) \\xa0·\\nThe Legend of Zelda Game Watch \\xa0·\\nLink\\'s Crossbow Training \\xa0·\\nMy Nintendo Picross: Twilight Princess \\xa0·\\nCadence of Hyrule \\xa0·\\nGame & Watch: The Legend of Zelda\\nCD-i Games\\n Listings[]\\nCharacters[]\\nBosses[]\\nEnemies[]\\nDungeons[]\\nLocations[]\\nItems[]\\nTranslations[]\\nCredits[]\\nReception[]\\nSales[]\\nEiji Aonuma and Hidemaro Fujibayashi accepting the \"Game of the Year\" award for Breath of the Wild at The Game Awards 2017\\nBreath of the Wild was estimated to have sold approximately 1.3 million copies in its first three weeks and around 89% of Switch owners were estimated to have also purchased the game.[52] Sales of the game have remained strong and as of June 30, 2022, the Switch version has sold 27.14 million copies worldwide while the Wii U version has sold 1.69 million copies worldwide as of December 31, 2019,[53][54] giving Breath of the Wild a cumulative total of 28.83 million copies sold.\\n It also earned a Metacritic score of 97 from more than 100 critics, placing it among the highest-rated games of all time.[59][60] Notably, the game received the most perfect review scores for any game listed on Metacritic up to that point.[61]\\nIn 2022, Breath of the Wild was chosen as the best Legend of Zelda game of all time in their \"Top 10 Best Zelda Games\" list countdown; but was then placed as the \"second\" best Zelda game in their new revamped version of their \"Top 10 Best Zelda Games\" list in 2023, right behind it\\'s successor Tears of Video Game Canon ranks Breath of the Wild as one of the best video games of all time.[74] Metacritic ranked Breath of the Wild as the single best game of the 2010s.[75]\\nFan Reception[]\\nWatchMojo placed Breath of the Wild at the #2 spot in their \"Top 10 Legend of Zelda Games of All Time\" list countdown, right behind Ocarina of Time.[76] The Faces of Evil \\xa0·\\nThe Wand of Gamelon \\xa0·\\nZelda\\'s Adventure\\nHyrule Warriors Series\\nHyrule Warriors (Legends; Definitive Edition) \\xa0·\\nHyrule Warriors: Age of Calamity\\nSatellaview Games\\nBS The Legend of Zelda \\xa0·\\nAncient Stone Tablets\\nTingle Series\\nFreshly-Picked Tingle\\'s Rosy Rupeeland \\xa0·\\nTingle\\'s Balloon Fight DS \\xa0·\\n'),\n",
" Document(metadata={'title': 'The Legend of Zelda: Breath of the Wild - Zelda Wiki', 'source': 'https://zeldawiki.wiki/wiki/The_Legend_of_Zelda:_Breath_of_the_Wild', 'score': 0.9627432, 'images': []}, page_content='The Legend of Zelda\\xa0•\\nThe Adventure of Link\\xa0•\\nA Link to the Past (& Four Swords)\\xa0•\\nLink\\'s Awakening (DX; Nintendo Switch)\\xa0•\\nOcarina of Time (Master Quest; 3D)\\xa0•\\nMajora\\'s Mask (3D)\\xa0•\\nOracle of Ages\\xa0•\\nOracle of Seasons\\xa0•\\nFour Swords (Anniversary Edition)\\xa0•\\nThe Wind Waker (HD)\\xa0•\\nFour Swords Adventures\\xa0•\\nThe Minish Cap\\xa0•\\nTwilight Princess (HD)\\xa0•\\nPhantom Hourglass\\xa0•\\nSpirit Tracks\\xa0•\\nSkyward Sword (HD)\\xa0•\\nA Link Between Worlds\\xa0•\\nTri Force Heroes\\xa0•\\nBreath of the Wild\\xa0•\\nTears of the Kingdom\\nZelda (Game & Watch)\\xa0•\\nThe Legend of Zelda Game Watch\\xa0•\\nHeroes of Hyrule\\xa0•\\nLink\\'s Crossbow Training\\xa0•\\nMy Nintendo Picross: Twilight Princess\\xa0•\\nCadence of Hyrule\\xa0•\\nVermin\\nThe Faces of Evil\\xa0•\\nThe Wand of Gamelon\\xa0•\\nZelda\\'s Adventure\\nHyrule Warriors (Legends; Definitive Edition)\\xa0•\\nHyrule Warriors: Age of Calamity\\nBS The Legend of Zelda\\xa0•\\nAncient Stone Tablets\\nFreshly-Picked Tingle\\'s Rosy Rupeeland\\xa0•\\nTingle\\'s Balloon Fight DS\\xa0•\\nToo Much Tingle Pack\\xa0•\\nRipened Tingle\\'s Balloon Trip of Love\\nSoulcalibur II\\xa0•\\nWarioWare Series\\xa0•\\nCaptain Rainbow\\xa0•\\nNintendo Land\\xa0•\\nScribblenauts Unlimited\\xa0•\\nMario Kart 8\\xa0•\\nSplatoon 3\\nSuper Smash Bros (Series)\\nSuper Smash Bros.\\xa0•\\nSuper Smash Bros. Melee\\xa0•\\nSuper Smash Bros. Brawl\\xa0•\\nSuper Smash Bros. for Nintendo 3DS / Wii U\\xa0•\\n It also earned a Metacritic score of 97 from more than 100 critics, placing it among the highest-rated games of all time.[60][61] Notably, the game received the most perfect review scores for any game listed on Metacritic up to that point.[62]\\nAwards\\nThroughout 2016, Breath of the Wild won several awards as a highly anticipated game, including IGN\\'s and Destructoid\\'s Best of E3,[63][64] at the Game Critic Awards 2016,[65] and at The Game Awards 2016.[66] Following its release, Breath of the Wild received the title of \"Game of the Year\" from the Japan Game Awards 2017,[67] the Golden Joystick Awards 2017,<ref\"Our final award is for the Ultimate Game of the Year. Official website(s)\\nOfficial website(s)\\nCanonicity\\nCanonicity\\nCanon[citation needed]\\nPredecessor\\nPredecessor\\nTri Force Heroes\\nSuccessor\\nSuccessor\\nTears of the Kingdom\\nThe Legend of Zelda: Breath of the Wild guide at StrategyWiki\\nBreath of the Wild Guide at Zelda Universe\\nThe Legend of Zelda: Breath of the Wild is the nineteenth main installment of The Legend of Zelda series. Listings\\nCharacters\\nBosses\\nEnemies\\nDungeons\\nLocations\\nItems\\nTranslations\\nCredits\\nReception\\nSales\\nBreath of the Wild was estimated to have sold approximately 1.3 million copies in its first three weeks and around 89% of Switch owners were estimated to have also purchased the game.[53] Sales of the game have remained strong and as of September 30, 2023, the Switch version has sold 31.15 million copies worldwide while the Wii U version has sold 1.7 million copies worldwide as of December 31, 2021,[54][55] giving Breath of the Wild a cumulative total of 32.85 million copies sold.\\n The Legend of Zelda: Breath of the Wild\\nThe Legend of Zelda: Breath of the Wild\\nThe Legend of Zelda: Breath of the Wild\\nDeveloper(s)\\nDeveloper(s)\\nPublisher(s)\\nPublisher(s)\\nNintendo\\nDesigner(s)\\nDesigner(s)\\n')]"
"[Document(page_content='Trending topics\\nTrending topics\\nThe Legend of Zelda: Breath of the Wild\\nSelect a product\\nThe Legend of Zelda™: Breath of the Wild\\nThe Legend of Zelda™: Breath of the Wild\\nThe Legend of Zelda™: Breath of the Wild and The Legend of Zelda™: Breath of the Wild Expansion Pass Bundle\\nThis item will be sent to your system automatically after purchase or Nintendo Switch Game Voucher redemption. The Legend of Zelda: Breath of the Wild Expansion Pass\\nMore like this\\nSuper Mario Odyssey™\\nThe Legend of Zelda™: Tears of the Kingdom\\nMario + Rabbids® Kingdom Battle\\nThe Legend of Zelda™: Links Awakening\\nHollow Knight\\nThe Legend of Zelda™: Skyward Sword HD\\nStarlink: Battle for Atlas™ Digital Edition\\nDRAGON QUEST BUILDERS™ 2\\nDragon Quest Builders™\\nWARNING: If you have epilepsy or have had seizures or other unusual reactions to flashing lights or patterns, consult a doctor before playing video games. Saddle up with a herd of horse-filled games!\\nESRB rating\\nSupported play modes\\nTV\\nTabletop\\nHandheld\\nProduct information\\nRelease date\\nNo. of players\\nGenre\\nPublisher\\nESRB rating\\nSupported play modes\\nGame file size\\nSupported languages\\nPlay online, access classic NES™ and Super NES™ games, and more with a Nintendo Switch Online membership.\\n Two Game Boy games are now available for Nintendo Switch Online members\\n02/01/23\\nNintendo Switch Online member exclusive: Save on two digital games\\n09/13/22\\nOut of the Shadows … the Legend of Zelda: About Nintendo\\nShop\\nMy Nintendo Store orders\\nSupport\\nParents\\nCommunity\\nPrivacy\\n© Nintendo.', metadata={'title': 'The Legend of Zelda™: Breath of the Wild - Nintendo', 'source': 'https://www.nintendo.com/us/store/products/the-legend-of-zelda-breath-of-the-wild-switch/', 'score': 0.97451, 'images': None}),\n",
" Document(page_content='The Legend of Zelda: Breath of the Wild is a masterpiece of open-world design and exploration, released on March 3, 2017 for Nintendo Switch. Find out the latest news, reviews, guides, videos, and more for this award-winning game on IGN.', metadata={'title': 'The Legend of Zelda: Breath of the Wild - IGN', 'source': 'https://www.ign.com/games/the-legend-of-zelda-breath-of-the-wild', 'score': 0.94496, 'images': None}),\n",
" Document(page_content='Reviewers also commented on the unexpected permutations of interactions between Link, villagers, pets, and enemies,[129][130][131] many of which were shared widely on social media.[132] A tribute to former Nintendo president Satoru Iwata, who died during development, also attracted praise.[129][134]\\nJim Sterling was more critical than most, giving Breath of the Wild a 7/10 score, criticizing the difficulty, weapon durability, and level design, but praising the open world and variety of content.[135] Other criticism focused on the unstable frame rate and the low resolution of 900p;[136] updates addressed some of these problems.[137][138]\\nSales\\nBreath of the Wild broke sales records for a Nintendo launch game in multiple regions.[139][140] In Japan, the Switch and Wii U versions sold a combined 230,000 copies in the first week of release, with the Switch version becoming the top-selling game released that week.[141] Nintendo reported that Breath of the Wild sold more than one million copies in the US that month—925,000 of which were for Switch, outselling the Switch itself.[145][146][147][148] Nintendo president Tatsumi Kimishima said that the attach rate on the Switch was \"unprecedented\".[149] Breath of the Wild had sold 31.15 million copies on the Switch by September 2023 and 1.70 million copies on the Wii U by December 2020.[150][151]\\nAwards\\nFollowing its demonstration at E3 2016, Breath of the Wild received several accolades from the Game Critics Awards[152] and from publications such as IGN and Destructoid.[153][154] It was listed among the best games at E3 by Eurogamer,[81] The game, he continued, would challenge the series\\' conventions, such as the requirement that players complete dungeons in a set order.[2][73] The next year, Nintendo introduced the game\\'s high-definition, cel-shaded visual style with in-game footage at its E3 press event.[74][75] Once planned for release in 2015, the game was delayed early in the year and did not show at that year\\'s E3.[76][77] Zelda series creator Shigeru Miyamoto reaffirmed that the game would still release for the Wii U despite the development of Nintendo\\'s next console, the Nintendo Switch.[78] The Switch version also has higher-quality environmental sounds.[53][54] Certain ideas that were planned for the game, like flying and underground dungeons were not implemented due to the Wii Us limitations; they would eventually resurface in the game\\'s sequel.[55] Aonuma stated that the art design was inspired by gouache and en plein air art to help identify the vast world.[56] Takizawa has also cited the Jōmon period as an inspiration for the ancient Sheikah technology and architecture that is found in the game, due to the mystery surrounding the period.[57] Journalists commented on unexpected interactions between game elements,[129][130][131] with serendipitous moments proving popular on social media.[132] Chris Plante of The Verge predicted that whereas prior open-world games tended to feature prescribed challenges, Zelda would influence a new generation of games with open-ended problem-solving.[132] Digital Trends wrote that the game\\'s level of experimentation allowed players to interact with and exploit the environment in creative ways, resulting in various \"tricks\" still discovered years after release.[127]\\nReviewers lauded the sense of detail and immersion.[133][129] Kotaku recommended turning off UI elements in praise of the indirect cues that contextually indicate the same information, such as Link shivering in the cold or waypoints appearing when using the scope.[133]', metadata={'title': 'The Legend of Zelda: Breath of the Wild - Wikipedia', 'source': 'https://en.wikipedia.org/wiki/The_Legend_of_Zelda:_Breath_of_the_Wild', 'score': 0.93348, 'images': None})]"
]
},
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"what year was breath of the wild released?\"\n",
"from langchain_community.retrievers import TavilySearchAPIRetriever\n",
"\n",
"retriever.invoke(query)"
"retriever = TavilySearchAPIRetriever(k=3)\n",
"\n",
"retriever.invoke(\"what year was breath of the wild released?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use within a chain\n",
"## Chaining\n",
"\n",
"We can easily combine this retriever in to a chain."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -161,50 +110,40 @@
"\n",
"Question: {question}\"\"\"\n",
")\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\")\n",
"\n",
"\n",
"def format_docs(docs):\n",
" return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
"\n",
"chain = (\n",
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
" RunnablePassthrough.assign(context=(lambda x: x[\"question\"]) | retriever)\n",
" | prompt\n",
" | llm\n",
" | ChatOpenAI(model=\"gpt-4-1106-preview\")\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'As of August 2020, The Legend of Zelda: Breath of the Wild had sold over 20.1 million copies worldwide on Nintendo Switch and Wii U.'"
"'As of the end of 2020, \"The Legend of Zelda: Breath of the Wild\" sold over 21.45 million copies worldwide.'"
]
},
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke(\"how many units did bretch of the wild sell in 2020\")"
"chain.invoke({\"question\": \"how many units did bretch of the wild sell in 2020\"})"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all `TavilySearchAPIRetriever` features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/retrievers/langchain_community.retrievers.tavily_search_api.TavilySearchAPIRetriever.html)."
]
"outputs": [],
"source": []
}
],
"metadata": {
@@ -223,7 +162,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -68,13 +68,11 @@ module.exports = {
},
{
type: "category",
label: "Versions",
collapsed: false,
collapsible: false,
label: "Versioning",
collapsed: true,
items: [
"versions/overview",
"versions/release_policy",
"versions/packages",
{
type: 'doc',
id: "how_to/pydantic_compatibility",
@@ -82,26 +80,20 @@ module.exports = {
},
{
type: "category",
label: "v0.2",
label: "Upgrading to v0.2",
link: {type: 'doc', id: 'versions/v0_2/index'},
collapsible: false,
collapsed: false,
items: [{
type: 'autogenerated',
dirName: 'versions/v0_2',
className: 'hidden',
}],
},
{
type: "category",
label: "Migrating to LCEL",
link: {type: 'doc', id: 'versions/migrating_chains/index'},
collapsible: false,
collapsed: false,
items: [{
type: 'autogenerated',
dirName: 'versions/migrating_chains',
className: 'hidden',
}],
},
],

View File

@@ -8,7 +8,7 @@ import CodeBlock from "@theme-original/CodeBlock";
* @typedef {Object} ChatModelTabsProps - Component props.
* @property {string} [openaiParams] - Parameters for OpenAI chat model. Defaults to `model="gpt-3.5-turbo-0125"`
* @property {string} [anthropicParams] - Parameters for Anthropic chat model. Defaults to `model="claude-3-sonnet-20240229"`
* @property {string} [cohereParams] - Parameters for Cohere chat model. Defaults to `model="command-r-plus"`
* @property {string} [cohereParams] - Parameters for Cohere chat model. Defaults to `model="command-r"`
* @property {string} [fireworksParams] - Parameters for Fireworks chat model. Defaults to `model="accounts/fireworks/models/mixtral-8x7b-instruct"`
* @property {string} [groqParams] - Parameters for Groq chat model. Defaults to `model="llama3-8b-8192"`
* @property {string} [mistralParams] - Parameters for Mistral chat model. Defaults to `model="mistral-large-latest"`

View File

@@ -1,217 +0,0 @@
{
"cells": [
{
"cell_type": "raw",
"id": "afaf8039",
"metadata": {},
"source": [
"---\n",
"sidebar_label: __ModuleName__\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# __ModuleName__Retriever\n",
"\n",
"## Overview\n",
"- TODO: Make sure API reference link is correct.\n",
"\n",
"This will help you getting started with the __ModuleName__ [retriever](/docs/concepts/#retrievers). For detailed documentation of all __ModuleName__Retriever features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/retrievers/__module_name__.retrievers.__ModuleName__.__ModuleName__Retriever.html).\n",
"\n",
"### Integration details\n",
"\n",
"| Retriever | Namespace | Native async | Local |\n",
"| :--- | :--- | :---: | :---: |\n",
"[__ModuleName__Retriever](https://api.python.langchain.com/en/latest/retrievers/__package_name__.retrievers.__module_name__.__ModuleName__Retriever.html) | __package_name__.retrievers | ❌ | ❌ |\n",
"\n",
"\n",
"## Setup\n",
"\n",
"- TODO: Update with relevant info."
]
},
{
"cell_type": "markdown",
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
"metadata": {},
"source": [
"If you want to get automated tracing from runs of individual tools, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "0730d6a1-c893-4840-9817-5e5251676d5d",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"This retriever lives in the `__package_name__` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU __package_name__"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our retriever:\n",
"\n",
"- TODO: Update model instantiation with relevant params."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70cc8e65-2a02-408a-bbc6-8ef649057d82",
"metadata": {},
"outputs": [],
"source": [
"from __module_name__ import __ModuleName__Retriever\n",
"\n",
"retriever = __ModuleName__Retriever(\n",
" # ...\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5c5f2839-4020-424e-9fc9-07777eede442",
"metadata": {},
"source": [
"## Usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51a60dbe-9f2e-4e04-bb62-23968f17164a",
"metadata": {},
"outputs": [],
"source": [
"query = \"...\"\n",
"\n",
"retriever.invoke(query)"
]
},
{
"cell_type": "markdown",
"id": "dfe8aad4-8626-4330-98a9-7ea1ca5d2e0e",
"metadata": {},
"source": [
"## Use within a chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23e11cc9-abd6-4855-a7eb-799f45ca01ae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\n",
" \"\"\"Answer the question based only on the context provided.\n",
"\n",
"Context: {context}\n",
"\n",
"Question: {question}\"\"\"\n",
")\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\")\n",
"\n",
"\n",
"def format_docs(docs):\n",
" return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
"\n",
"chain = (\n",
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | llm\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d47c37dd-5c11-416c-a3b6-bec413cd70e8",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"...\")"
]
},
{
"cell_type": "markdown",
"id": "d1ee55bc-ffc8-4cfa-801c-993953a08cfd",
"metadata": {},
"source": [
"## TODO: Any functionality or considerations specific to this retriever\n",
"\n",
"Fill in or delete if not relevant."
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all __ModuleName__Retriever features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/retrievers/__module_name__.retrievers.__ModuleName__.__ModuleName__Retriever.html)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,89 +0,0 @@
"""__ModuleName__ retrievers."""
from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class __ModuleName__Retriever(BaseRetriever):
# TODO: Replace all TODOs in docstring. See example docstring:
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/retrievers/tavily_search_api.py#L17
"""__ModuleName__ retriever.
# TODO: Replace with relevant packages, env vars, etc.
Setup:
Install ``__package_name__`` and set environment variable ``__MODULE_NAME___API_KEY``.
.. code-block:: bash
pip install -U __package_name__
export __MODULE_NAME___API_KEY="your-api-key"
# TODO: Populate with relevant params.
Key init args:
arg 1: type
description
arg 2: type
description
# TODO: Replace with relevant init params.
Instantiate:
.. code-block:: python
from __package_name__ import __ModuleName__Retriever
retriever = __ModuleName__Retriever(
# ...
)
Usage:
.. code-block:: python
query = "..."
retriever.invoke(query)
.. code-block:: python
# TODO: Example output.
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("...")
.. code-block:: python
# TODO: Example output.
""" # noqa: E501
# TODO: This method must be implemented to retrieve documents.
def _get_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError()

View File

@@ -1,4 +1,4 @@
"""__ModuleName__ toolkits."""
"""__ModuleName__ chat models."""
from typing import List
from langchain_core.tools import BaseTool, BaseToolKit
@@ -6,7 +6,7 @@ from langchain_core.tools import BaseTool, BaseToolKit
class __ModuleName__Toolkit(BaseToolKit):
# TODO: Replace all TODOs in docstring. See example docstring:
# https://github.com/langchain-ai/langchain/blob/c123cb2b304f52ab65db4714eeec46af69a861ec/libs/community/langchain_community/agent_toolkits/sql/toolkit.py#L19
# https://github.com/langchain-ai/langchain/blob/a6d1fb4275801a4850e62b6209cfbf096a24f93f/libs/community/langchain_community/agent_toolkits/sql/toolkit.py#L20
"""__ModuleName__ toolkit.
# TODO: Replace with relevant packages, env vars, etc.
@@ -66,6 +66,6 @@ class __ModuleName__Toolkit(BaseToolKit):
""" # noqa: E501
# TODO: This method must be implemented to list tools.
# TODO: This method must be implemented to generate chat responses.
def get_tools(self) -> List[BaseTool]:
raise NotImplementedError()

View File

@@ -58,7 +58,7 @@ class UpstashRatelimitHandler(BaseCallbackHandler):
every time you invoke.
"""
raise_error: bool = True
raise_error = True
_checked: bool = False
def __init__(

View File

@@ -8,7 +8,7 @@ import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import requests # type: ignore
from langchain.chains.base import Chain
@@ -37,7 +37,6 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_GOV_URL,
PROMPT_URL,
get_runtime,
)
@@ -76,12 +75,10 @@ class PebbloRetrievalQA(Chain):
"""Classifier endpoint."""
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent: bool = False #: :meta private:
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
def _call(
self,
@@ -105,8 +102,6 @@ class PebbloRetrievalQA(Chain):
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
_, prompt_entities = self._check_prompt_validity(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
@@ -138,12 +133,7 @@ class PebbloRetrievalQA(Chain):
for doc in docs
if isinstance(doc, Document)
],
"prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
},
"prompt": {"data": question},
"response": {
"data": answer,
},
@@ -154,7 +144,6 @@ class PebbloRetrievalQA(Chain):
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
@@ -186,9 +175,6 @@ class PebbloRetrievalQA(Chain):
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
_, prompt_entities = self._check_prompt_validity(question)
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
@@ -527,66 +513,6 @@ class PebbloRetrievalQA(Chain):
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args:
question (str): The prompt question to be validated.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)
logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
return is_valid_prompt, prompt_entities
@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__

View File

@@ -133,10 +133,7 @@ class Context(BaseModel):
class Prompt(BaseModel):
data: Optional[Union[list, str]]
entityCount: Optional[int]
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
data: str
class Qa(BaseModel):

View File

@@ -15,7 +15,6 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/prompt/governance"
APP_DISCOVER_URL = "/v1/app/discover"

View File

@@ -111,7 +111,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
from langchain_community.chat_models import ChatCohere
from langchain_core.messages import HumanMessage
chat = ChatCohere(max_tokens=256, temperature=0.75)
chat = ChatCohere(model="command", max_tokens=256, temperature=0.75)
messages = [HumanMessage(content="knock knock")]
chat.invoke(messages)

View File

@@ -1,7 +1,13 @@
import base64
import hashlib
import hmac
import json
import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
from urllib.parse import urlparse
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
@@ -28,15 +34,18 @@ from langchain_core.utils import (
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://hunyuan.cloud.tencent.com"
DEFAULT_PATH = "/hyllm/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"Role": message.role, "Content": message.content}
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"Role": "user", "Content": message.content}
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"Role": "assistant", "Content": message.content}
message_dict = {"role": "assistant", "content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
@@ -44,20 +53,20 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["Role"]
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["Content"])
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("Content", "") or "")
return AIMessage(content=_dict.get("content", "") or "")
else:
return ChatMessage(content=_dict["Content"], role=role)
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("Role")
content = _dict.get("Content") or ""
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
@@ -69,13 +78,43 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore[call-arg]
# signature generation
# https://cloud.tencent.com/document/product/1729/97732#532252ce-e960-48a7-8821-940a9ce2ccf3
def _signature(secret_key: SecretStr, url: str, payload: Dict[str, Any]) -> str:
sorted_keys = sorted(payload.keys())
url_info = urlparse(url)
sign_str = url_info.netloc + url_info.path + "?"
for key in sorted_keys:
value = payload[key]
if isinstance(value, list) or isinstance(value, dict):
value = json.dumps(value, separators=(",", ":"), ensure_ascii=False)
elif isinstance(value, float):
value = "%g" % value
sign_str = sign_str + key + "=" + str(value) + "&"
sign_str = sign_str[:-1]
hmacstr = hmac.new(
key=secret_key.get_secret_value().encode("utf-8"),
msg=sign_str.encode("utf-8"),
digestmod=hashlib.sha1,
).digest()
return base64.b64encode(hmacstr).decode("utf-8")
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for choice in response["Choices"]:
message = _convert_dict_to_message(choice["Message"])
for choice in response["choices"]:
message = _convert_dict_to_message(choice["messages"])
generations.append(ChatGeneration(message=message))
token_usage = response["Usage"]
token_usage = response["usage"]
llm_output = {"token_usage": token_usage}
return ChatResult(generations=generations, llm_output=llm_output)
@@ -98,6 +137,8 @@ class ChatHunyuan(BaseChatModel):
def lc_serializable(self) -> bool:
return True
hunyuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Hunyuan custom endpoints"""
hunyuan_app_id: Optional[int] = None
"""Hunyuan App ID"""
hunyuan_secret_id: Optional[str] = None
@@ -108,26 +149,13 @@ class ChatHunyuan(BaseChatModel):
"""Whether to stream the results or not."""
request_timeout: int = 60
"""Timeout for requests to Hunyuan API. Default is 60 seconds."""
query_id: Optional[str] = None
"""Query id for troubleshooting"""
temperature: float = 1.0
"""What sampling temperature to use."""
top_p: float = 1.0
"""What probability mass to use."""
model: str = "hunyuan-lite"
"""What Model to use.
Optional model:
- hunyuan-lite、
- hunyuan-standard
- hunyuan-standard-256K
- hunyuan-pro
- hunyuan-code
- hunyuan-role
- hunyuan-functioncall
- hunyuan-vision
"""
stream_moderation: bool = False
"""Whether to review the results or not when streaming is true."""
enable_enhancement: bool = True
"""Whether to enhancement the results or not."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
@@ -165,6 +193,12 @@ class ChatHunyuan(BaseChatModel):
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
values["hunyuan_api_base"] = get_from_dict_or_env(
values,
"hunyuan_api_base",
"HUNYUAN_API_BASE",
DEFAULT_API_BASE,
)
values["hunyuan_app_id"] = get_from_dict_or_env(
values,
"hunyuan_app_id",
@@ -182,19 +216,22 @@ class ChatHunyuan(BaseChatModel):
"HUNYUAN_SECRET_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Hunyuan API."""
normal_params = {
"Temperature": self.temperature,
"TopP": self.top_p,
"Model": self.model,
"Stream": self.streaming,
"StreamModeration": self.stream_moderation,
"EnableEnhancement": self.enable_enhancement,
"app_id": self.hunyuan_app_id,
"secret_id": self.hunyuan_secret_id,
"temperature": self.temperature,
"top_p": self.top_p,
}
if self.query_id is not None:
normal_params["query_id"] = self.query_id
return {**normal_params, **self.model_kwargs}
def _generate(
@@ -211,7 +248,13 @@ class ChatHunyuan(BaseChatModel):
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
return _create_chat_result(json.loads(res.to_json_string()))
response = res.json()
if "error" in response:
raise ValueError(f"Error from Hunyuan api response: {response}")
return _create_chat_result(response)
def _stream(
self,
@@ -223,17 +266,19 @@ class ChatHunyuan(BaseChatModel):
res = self._chat(messages, **kwargs)
default_chunk_class = AIMessageChunk
for chunk in res:
chunk = chunk.get("data", "")
for chunk in res.iter_lines():
chunk = chunk.decode(encoding="UTF-8", errors="strict").replace(
"data: ", ""
)
if len(chunk) == 0:
continue
response = json.loads(chunk)
if "error" in response:
raise ValueError(f"Error from Hunyuan api response: {response}")
for choice in response["Choices"]:
for choice in response["choices"]:
chunk = _convert_delta_to_message_chunk(
choice["Delta"], default_chunk_class
choice["delta"], default_chunk_class
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
@@ -241,32 +286,42 @@ class ChatHunyuan(BaseChatModel):
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.hunyuan_secret_key is None:
raise ValueError("Hunyuan secret key is not set.")
try:
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
except ImportError:
raise ImportError(
"Could not import tencentcloud python package. "
"Please install it with `pip install tencentcloud-sdk-python`."
)
parameters = {**self._default_params, **kwargs}
cred = credential.Credential(
self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
)
client = hunyuan_client.HunyuanClient(cred, "")
req = models.ChatCompletionsRequest()
params = {
"Messages": [_convert_message_to_dict(m) for m in messages],
headers = parameters.pop("headers", {})
timestamp = parameters.pop("timestamp", int(time.time()))
expired = parameters.pop("expired", timestamp + 24 * 60 * 60)
payload = {
"timestamp": timestamp,
"expired": expired,
"messages": [_convert_message_to_dict(m) for m in messages],
**parameters,
}
req.from_json_string(json.dumps(params))
resp = client.ChatCompletions(req)
return resp
if self.streaming:
payload["stream"] = 1
url = self.hunyuan_api_base + DEFAULT_PATH
res = requests.post(
url=url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
"Authorization": _signature(
secret_key=self.hunyuan_secret_key, url=url, payload=payload
),
**headers,
},
json=payload,
stream=self.streaming,
)
return res
@property
def _llm_type(self) -> str:

View File

@@ -9,7 +9,7 @@ import os
import re
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from langchain_core.utils import pre_init
@@ -164,7 +164,7 @@ class _KineticaLlmFileContextParser:
"""Parser for Kinetica LLM context datafiles."""
# parse line into a dict containing role and content
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
@classmethod
def _removesuffix(cls, text: str, suffix: str) -> str:

View File

@@ -135,7 +135,7 @@ class Provider(ABC):
class CohereProvider(Provider):
stop_sequence_key: str = "stop_sequences"
stop_sequence_key = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
@@ -364,7 +364,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider):
stop_sequence_key: str = "stop"
stop_sequence_key = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models

View File

@@ -1,7 +1,7 @@
# LLM Lingua Document Compressor
import re
from typing import Any, Dict, List, Optional, Pattern, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
@@ -24,8 +24,8 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
# Pattern to match ref tags at the beginning or end of the string,
# allowing for malformed tags
_pattern_beginning: Pattern = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
_pattern_ending: Pattern = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
_pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
_pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
model_name: str = "NousResearch/Llama-2-7b-hf"
"""The hugging face model to use"""

View File

@@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import Iterator, Pattern, Union
from typing import Iterator, Union
from langchain_core.documents import Document
@@ -10,9 +10,7 @@ from langchain_community.document_loaders.base import BaseLoader
class AcreomLoader(BaseLoader):
"""Load `acreom` vault from a directory."""
FRONT_MATTER_REGEX: Pattern = re.compile(
r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL
)
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL)
"""Regex to match front matter metadata in markdown files."""
def __init__(

View File

@@ -44,13 +44,13 @@ class DocugamiLoader(BaseLoader, BaseModel):
access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
"""The Docugami API access token to use."""
max_text_length: int = 4096
max_text_length = 4096
"""Max length of chunk text returned."""
min_text_length: int = 32
"""Threshold under which chunks are appended to next to avoid over-chunking."""
max_metadata_length: int = 512
max_metadata_length = 512
"""Max length of metadata text returned."""
include_xml_tags: bool = False

View File

@@ -36,8 +36,8 @@ class HuggingFaceModelLoader(BaseLoader):
print(doc.metadata) # Metadata of the model
"""
BASE_URL: str = "https://huggingface.co/api/models"
README_BASE_URL: str = "https://huggingface.co/{model_id}/raw/main/README.md"
BASE_URL = "https://huggingface.co/api/models"
README_BASE_URL = "https://huggingface.co/{model_id}/raw/main/README.md"
def __init__(
self,

View File

@@ -2,7 +2,7 @@ import functools
import logging
import re
from pathlib import Path
from typing import Any, Dict, Iterator, Pattern, Union
from typing import Any, Dict, Iterator, Union
import yaml
from langchain_core.documents import Document
@@ -15,16 +15,12 @@ logger = logging.getLogger(__name__)
class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory."""
FRONT_MATTER_REGEX: Pattern = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX: Pattern = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX: Pattern = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX: Pattern = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX: Pattern = re.compile(
r"\[(\w+)::\s*(.*)\]", re.MULTILINE
)
DATAVIEW_INLINE_PAREN_REGEX: Pattern = re.compile(
r"\((\w+)::\s*(.*)\)", re.MULTILINE
)
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE)
def __init__(
self,

View File

@@ -39,7 +39,7 @@ class OneNoteLoader(BaseLoader, BaseModel):
"""Personal access token"""
onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
"""URL of Microsoft Graph API for OneNote"""
authority_url: str = "https://login.microsoftonline.com/consumers/"
authority_url = "https://login.microsoftonline.com/consumers/"
"""A URL that identifies a token authority"""
token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
"""Path to the file where the access token is stored"""

View File

@@ -1,5 +1,5 @@
import re
from typing import Callable, List, Pattern
from typing import Callable, List
from langchain_community.document_loaders.parsers.language.code_segmenter import (
CodeSegmenter,
@@ -9,11 +9,11 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
class CobolSegmenter(CodeSegmenter):
"""Code segmenter for `COBOL`."""
PARAGRAPH_PATTERN: Pattern = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN: Pattern = re.compile(
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN = re.compile(
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
)
SECTION_PATTERN: Pattern = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
def __init__(self, code: str):
super().__init__(code)

View File

@@ -13,7 +13,6 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.pebblo import (
APP_DISCOVER_URL,
BATCH_SIZE_BYTES,
CLASSIFIER_URL,
LOADER_DOC_URL,
PEBBLO_CLOUD_URL,
@@ -21,7 +20,6 @@ from langchain_community.utilities.pebblo import (
App,
Doc,
IndexedDocument,
generate_size_based_batches,
get_full_path,
get_loader_full_path,
get_loader_type,
@@ -70,7 +68,6 @@ class PebbloSafeLoader(BaseLoader):
self.source_aggregate_size = 0
self.classifier_url = classifier_url or CLASSIFIER_URL
self.classifier_location = classifier_location
self.batch_size = BATCH_SIZE_BYTES
self.loader_details = {
"loader": loader_name,
"source_path": self.source_path,
@@ -92,37 +89,15 @@ class PebbloSafeLoader(BaseLoader):
list: Documents fetched from load method of the wrapped `loader`.
"""
self.docs = self.loader.load()
# Classify docs in batches
self.classify_in_batches()
self.docs_with_id = self._index_docs()
classified_docs = self._classify_doc(loading_end=True)
self._add_pebblo_specific_metadata(classified_docs)
if self.load_semantic:
self.docs = self._add_semantic_to_docs(classified_docs)
else:
self.docs = self._unindex_docs() # type: ignore
return self.docs
def classify_in_batches(self) -> None:
"""
Classify documents in batches.
This is to avoid API timeouts when sending large number of documents.
Batches are generated based on the page_content size.
"""
batches: List[List[Document]] = generate_size_based_batches(
self.docs, self.batch_size
)
processed_docs: List[Document] = []
total_batches = len(batches)
for i, batch in enumerate(batches):
is_last_batch: bool = i == total_batches - 1
self.docs = batch
self.docs_with_id = self._index_docs()
classified_docs = self._classify_doc(loading_end=is_last_batch)
self._add_pebblo_specific_metadata(classified_docs)
if self.load_semantic:
batch_processed_docs = self._add_semantic_to_docs(classified_docs)
else:
batch_processed_docs = self._unindex_docs()
processed_docs.extend(batch_processed_docs)
self.docs = processed_docs
def lazy_load(self) -> Iterator[Document]:
"""Load documents in lazy fashion.
@@ -556,6 +531,7 @@ class PebbloSafeLoader(BaseLoader):
"full_path", doc_metadata.get("source", self.source_path)
)
)
doc_metadata["pb_id"] = doc.pb_id
doc_metadata["pb_checksum"] = classified_docs.get(doc.pb_id, {}).get(
"pb_checksum", None
)

View File

@@ -27,7 +27,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
chunk_size: int = 16
model_name: str = "ErnieBot-Embedding-V1"
model_name = "ErnieBot-Embedding-V1"
_lock = threading.Lock()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import re
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
@@ -63,7 +63,7 @@ class AGEGraph(GraphStore):
}
# precompiled regex for checking chars in graph labels
label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")
label_regex = re.compile("[^0-9a-zA-Z]+")
def __init__(
self, graph_name: str, conf: Dict[str, Any], create: bool = True

View File

@@ -75,7 +75,7 @@ class Arcee(LLM):
model_name=self.model,
)
@root_validator(pre=True)
@root_validator(pre=False)
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""

View File

@@ -39,9 +39,9 @@ class MoonshotCommon(BaseModel):
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
model_name: str = Field(default="moonshot-v1-8k", alias="model")
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
max_tokens: int = 1024
max_tokens = 1024
"""Maximum number of tokens to generate."""
temperature: float = 0.3
temperature = 0.3
"""Temperature parameter (higher values make the model more creative)."""
class Config:

View File

@@ -244,7 +244,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
Defaults to True."""
return_full_text: bool = False
return_full_text = False
"""Whether to prepend the prompt to the generated text. Defaults to False."""
@property

View File

@@ -26,7 +26,7 @@ class Provider(ABC):
class CohereProvider(Provider):
stop_sequence_key: str = "stop_sequences"
stop_sequence_key = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
@@ -38,7 +38,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider):
stop_sequence_key: str = "stop"
stop_sequence_key = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models

View File

@@ -16,7 +16,7 @@ class SVEndpointHandler:
:param str host_url: Base URL of the DaaS API service
"""
API_BASE_PATH: str = "/api/predict"
API_BASE_PATH = "/api/predict"
def __init__(self, host_url: str):
"""

View File

@@ -41,7 +41,7 @@ class SolarCommon(BaseModel):
model_name: str = Field(default="solar-1-mini-chat", alias="model")
"""Model name. Available models listed here: https://console.upstage.ai/services/solar"""
max_tokens: int = Field(default=1024)
temperature: float = 0.3
temperature = 0.3
class Config:
allow_population_by_field_name = True

View File

@@ -27,7 +27,7 @@ class SupabaseVectorTranslator(Visitor):
]
"""Subset of allowed logical comparators."""
metadata_column: str = "metadata"
metadata_column = "metadata"
def _map_comparator(self, comparator: Comparator) -> str:
"""

View File

@@ -15,71 +15,7 @@ class SearchDepth(Enum):
class TavilySearchAPIRetriever(BaseRetriever):
"""Tavily Search API retriever.
Setup:
Install ``langchain-community`` and set environment variable ``TAVILY_API_KEY``.
.. code-block:: bash
pip install -U langchain-community
export TAVILY_API_KEY="your-api-key"
Key init args:
k: int
Number of results to include.
include_generated_answer: bool
Include a generated answer with results
include_raw_content: bool
Include raw content with results.
include_images: bool
Return images in addition to text.
Instantiate:
.. code-block:: python
from langchain_community.retrievers import TavilySearchAPIRetriever
retriever = TavilySearchAPIRetriever(k=3)
Usage:
.. code-block:: python
query = "what year was breath of the wild released?"
retriever.invoke(query)
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("how many units did bretch of the wild sell in 2020")
""" # noqa: E501
"""Tavily Search API retriever."""
k: int = 10
include_generated_answer: bool = False

View File

@@ -74,8 +74,8 @@ class BearlyInterpreterTool:
"""Tool for evaluating python code in a sandbox environment."""
api_key: str
endpoint: str = "https://exec.bearly.ai/v1/interpreter"
name: str = "bearly_interpreter"
endpoint = "https://exec.bearly.ai/v1/interpreter"
name = "bearly_interpreter"
args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
files: Dict[str, FileInfo] = {}

View File

@@ -51,12 +51,12 @@ class ZenGuardTool(BaseTool):
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
)
args_schema = ZenGuardInput
return_direct: bool = True
return_direct = True
zenguard_api_key: Optional[str] = Field(default=None)
_ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/"
_ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
def set_api_key(cls, v: str) -> str:

View File

@@ -4,7 +4,7 @@ import logging
import os
import pathlib
import platform
from typing import List, Optional, Tuple
from typing import Optional, Tuple
from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment
@@ -20,7 +20,6 @@ PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
LOADER_DOC_URL = "/v1/loader/doc"
APP_DISCOVER_URL = "/v1/app/discover"
BATCH_SIZE_BYTES = 100 * 1024 # 100 KB
# Supported loaders for Pebblo safe data loading
file_loader = [
@@ -302,43 +301,3 @@ def get_ip() -> str:
except Exception:
public_ip = socket.gethostbyname("localhost")
return public_ip
def generate_size_based_batches(
docs: List[Document], max_batch_size: int = 100 * 1024
) -> List[List[Document]]:
"""
Generate batches of documents based on page_content size.
Args:
docs: List of documents to be batched.
max_batch_size: Maximum size of each batch in bytes. Defaults to 100*1024(100KB)
Returns:
List[List[Document]]: List of batches of documents
"""
batches: List[List[Document]] = []
current_batch: List[Document] = []
current_batch_size: int = 0
for doc in docs:
# Calculate the size of the document in bytes
doc_size: int = len(doc.page_content.encode("utf-8"))
if doc_size > max_batch_size:
# If a single document exceeds the max batch size, send it as a single batch
batches.append([doc])
else:
if current_batch_size + doc_size > max_batch_size:
# If adding this document exceeds the max batch size, start a new batch
batches.append(current_batch)
current_batch = []
current_batch_size = 0
# Add document to the current batch
current_batch.append(doc)
current_batch_size += doc_size
# Add the last batch if it has documents
if current_batch:
batches.append(current_batch)
return batches

View File

@@ -11,7 +11,7 @@ class Portkey:
Default: "https://api.portkey.ai/v1/proxy"
"""
base: str = "https://api.portkey.ai/v1/proxy"
base = "https://api.portkey.ai/v1/proxy"
@staticmethod
def Config(

View File

@@ -28,7 +28,7 @@ class TokenEscaper:
# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS: str = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:

View File

@@ -29,7 +29,7 @@ class AtlasDB(VectorStore):
vectorstore = AtlasDB("my_project", embeddings.embed_query)
"""
_ATLAS_DEFAULT_ID_FIELD: str = "atlas_id"
_ATLAS_DEFAULT_ID_FIELD = "atlas_id"
def __init__(
self,

View File

@@ -21,7 +21,7 @@ DEFAULT_TOPN = 4
class AwaDB(VectorStore):
"""`AwaDB` vector store."""
_DEFAULT_TABLE_NAME: str = "langchain_awadb"
_DEFAULT_TABLE_NAME = "langchain_awadb"
def __init__(
self,

View File

@@ -53,7 +53,7 @@ class Bagel(VectorStore):
vectorstore = Bagel(cluster_name="langchain_store")
"""
_LANGCHAIN_DEFAULT_CLUSTER_NAME: str = "langchain"
_LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"
def __init__(
self,

View File

@@ -66,7 +66,7 @@ class Chroma(VectorStore):
vectorstore = Chroma("langchain_store", embeddings)
"""
_LANGCHAIN_DEFAULT_COLLECTION_NAME: str = "langchain"
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
def __init__(
self,

View File

@@ -60,10 +60,10 @@ class CouchbaseVectorStore(VectorStore):
"""
# Default batch size
DEFAULT_BATCH_SIZE: int = 100
_metadata_key: str = "metadata"
_default_text_key: str = "text"
_default_embedding_key: str = "embedding"
DEFAULT_BATCH_SIZE = 100
_metadata_key = "metadata"
_default_text_key = "text"
_default_embedding_key = "embedding"
def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster"""

View File

@@ -51,7 +51,7 @@ class DeepLake(VectorStore):
vectorstore = DeepLake("langchain_store", embeddings.embed_query)
"""
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/"
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/"
_valid_search_kwargs = ["lambda_mult"]
def __init__(

View File

@@ -45,9 +45,9 @@ class Epsilla(VectorStore):
epsilla = Epsilla(client, embeddings, db_path, db_name)
"""
_LANGCHAIN_DEFAULT_DB_NAME: str = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH: str = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME: str = "langchain_collection"
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
def __init__(
self,

View File

@@ -13,7 +13,6 @@ from typing import (
Iterable,
List,
Optional,
Pattern,
Tuple,
Type,
)
@@ -224,7 +223,7 @@ class HanaDB(VectorStore):
return embedding
# Compile pattern only once, for better performance
_compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
@staticmethod
def _sanitize_metadata_keys(metadata: dict) -> dict:

View File

@@ -48,7 +48,7 @@ class ManticoreSearchSettings(BaseSettings):
hnsw_m: int = 16 # The default is 16.
# An optional setting that defines a construction time/accuracy trade-off.
hnsw_ef_construction: int = 100
hnsw_ef_construction = 100
def get_connection_string(self) -> str:
return self.proto + "://" + self.host + ":" + str(self.port)

View File

@@ -85,8 +85,8 @@ class Qdrant(VectorStore):
qdrant = Qdrant(client, collection_name, embedding_function)
"""
CONTENT_KEY: str = "page_content"
METADATA_KEY: str = "metadata"
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
VECTOR_NAME = None
def __init__(

View File

@@ -25,7 +25,7 @@ class SemaDB(VectorStore):
"""
HOST: str = "semadb.p.rapidapi.com"
HOST = "semadb.p.rapidapi.com"
BASE_URL = "https://" + HOST
def __init__(

View File

@@ -53,7 +53,7 @@ fi
# Forbid vanilla usage of @root_validator
# This prevents the code from using either @root_validator or @root_validator()
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\)|@root_validator\(pre=False\))' -- '*.py')
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py')
# Check if any matching lines were found
if [ -n "$result" ]; then

View File

@@ -1,10 +1,8 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_models.hunyuan import ChatHunyuan
@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan() -> None:
chat = ChatHunyuan()
message = HumanMessage(content="Hello")
@@ -13,7 +11,6 @@ def test_chat_hunyuan() -> None:
assert isinstance(response.content, str)
@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan_with_temperature() -> None:
chat = ChatHunyuan(temperature=0.6)
message = HumanMessage(content="Hello")
@@ -22,24 +19,6 @@ def test_chat_hunyuan_with_temperature() -> None:
assert isinstance(response.content, str)
@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan_with_model_name() -> None:
chat = ChatHunyuan(model="hunyuan-standard")
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan_with_stream() -> None:
chat = ChatHunyuan(streaming=True)
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_extra_kwargs() -> None:
chat = ChatHunyuan(temperature=0.88, top_p=0.7)
assert chat.temperature == 0.88

View File

@@ -23,9 +23,9 @@ def mock_quip(): # type: ignore
@pytest.mark.requires("quip_api")
class TestQuipLoader:
API_URL: str = "https://example-api.quip.com"
API_URL = "https://example-api.quip.com"
DOC_URL_PREFIX = ("https://example.quip.com",)
ACCESS_TOKEN: str = "api_token"
ACCESS_TOKEN = "api_token"
MOCK_FOLDER_IDS = ["ABC"]
MOCK_THREAD_IDS = ["ABC", "DEF"]

View File

@@ -59,8 +59,8 @@ def test_custom_formatter() -> None:
"""Test ability to create a custom content formatter."""
class CustomFormatter(ContentFormatterBase):
content_type: str = "application/json"
accepts: str = "application/json"
content_type = "application/json"
accepts = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(
@@ -101,8 +101,8 @@ def test_invalid_request_format() -> None:
"""Test invalid request format."""
class CustomContentFormatter(ContentFormatterBase):
content_type: str = "application/json"
accepts: str = "application/json"
content_type = "application/json"
accepts = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(

View File

@@ -13,19 +13,19 @@ README_PATH = Path(__file__).parents[4] / "README.md"
class FakeUploadResponse:
status_code: int = 200
text: str = "fake_uuid"
status_code = 200
text = "fake_uuid"
class FakePushResponse:
status_code: int = 200
status_code = 200
def json(self) -> Any:
return {"uuid": "fake_uuid"}
class FakePullResponse:
status_code: int = 200
status_code = 200
def json(self) -> Any:
return {

View File

@@ -8,25 +8,27 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from langchain_community.chat_models.hunyuan import (
_convert_delta_to_message_chunk,
_convert_dict_to_message,
_convert_message_to_dict,
_signature,
)
def test__convert_message_to_dict_human() -> None:
message = HumanMessage(content="foo")
result = _convert_message_to_dict(message)
expected_output = {"Role": "user", "Content": "foo"}
expected_output = {"role": "user", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_ai() -> None:
message = AIMessage(content="foo")
result = _convert_message_to_dict(message)
expected_output = {"Role": "assistant", "Content": "foo"}
expected_output = {"role": "assistant", "content": "foo"}
assert result == expected_output
@@ -45,35 +47,68 @@ def test__convert_message_to_dict_function() -> None:
def test__convert_dict_to_message_human() -> None:
message_dict = {"Role": "user", "Content": "foo"}
message_dict = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message_dict = {"Role": "assistant", "Content": "foo"}
message_dict = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = AIMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_other_role() -> None:
message_dict = {"Role": "system", "Content": "foo"}
message_dict = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = ChatMessage(role="system", content="foo")
assert result == expected_output
def test__convert_delta_to_message_assistant() -> None:
delta = {"Role": "assistant", "Content": "foo"}
delta = {"role": "assistant", "content": "foo"}
result = _convert_delta_to_message_chunk(delta, AIMessageChunk)
expected_output = AIMessageChunk(content="foo")
assert result == expected_output
def test__convert_delta_to_message_human() -> None:
delta = {"Role": "user", "Content": "foo"}
delta = {"role": "user", "content": "foo"}
result = _convert_delta_to_message_chunk(delta, HumanMessageChunk)
expected_output = HumanMessageChunk(content="foo")
assert result == expected_output
def test__signature() -> None:
secret_key = SecretStr("YOUR_SECRET_KEY")
url = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
result = _signature(
secret_key=secret_key,
url=url,
payload={
"app_id": "YOUR_APP_ID",
"secret_id": "YOUR_SECRET_ID",
"query_id": "test_query_id_cb5d8156-0ce2-45af-86b4-d02f5c26a142",
"messages": [
{
"role": "user",
"content": "You are a helpful assistant that translates English"
" to French.Translate this sentence from English to"
" French. I love programming.",
}
],
"temperature": 0.0,
"top_p": 0.8,
"stream": 1,
"timestamp": 1697738378,
"expired": 1697824778,
},
)
# The signature was generated by the demo provided by Huanyuan.
# https://hunyuan-sdk-1256237915.cos.ap-guangzhou.myqcloud.com/python.zip
expected_output = "MXBvqNCXyxJWfEyBwk1pYBVnxzo="
assert result == expected_output

View File

@@ -11,7 +11,7 @@ LOG = logging.getLogger(__name__)
class TestChatKinetica:
test_ctx_json: str = """
test_ctx_json = """
{
"payload":{
"context":[

View File

@@ -20,10 +20,10 @@ def mock_confluence(): # type: ignore
@pytest.mark.requires("atlassian", "bs4", "lxml")
class TestConfluenceLoader:
CONFLUENCE_URL: str = "https://example.atlassian.com/wiki"
MOCK_USERNAME: str = "user@gmail.com"
MOCK_API_TOKEN: str = "api_token"
MOCK_SPACE_KEY: str = "spaceId123"
CONFLUENCE_URL = "https://example.atlassian.com/wiki"
MOCK_USERNAME = "user@gmail.com"
MOCK_API_TOKEN = "api_token"
MOCK_SPACE_KEY = "spaceId123"
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
ConfluenceLoader(

View File

@@ -57,12 +57,12 @@ def mock_lakefs_client_no_presign_local() -> Any:
class TestLakeFSLoader(unittest.TestCase):
lakefs_access_key: str = "lakefs_access_key"
lakefs_secret_key: str = "lakefs_secret_key"
endpoint: str = "endpoint"
repo: str = "repo"
ref: str = "ref"
path: str = "path"
lakefs_access_key = "lakefs_access_key"
lakefs_secret_key = "lakefs_secret_key"
endpoint = "endpoint"
repo = "repo"
ref = "ref"
path = "path"
@requests_mock.Mocker()
@pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")

View File

@@ -69,6 +69,7 @@ def test_csv_loader_load_valid_data(mocker: MockerFixture) -> None:
"source": full_file_path,
"row": 0,
"full_path": full_file_path,
"pb_id": "0",
# For UT as here we are not calculating checksum
"pb_checksum": None,
},
@@ -79,6 +80,7 @@ def test_csv_loader_load_valid_data(mocker: MockerFixture) -> None:
"source": full_file_path,
"row": 1,
"full_path": full_file_path,
"pb_id": "1",
# For UT as here we are not calculating checksum
"pb_checksum": None,
},

View File

@@ -21,9 +21,9 @@ def mock_connector_id(): # type: ignore
@pytest.mark.requires("psychicapi")
class TestPsychicLoader:
MOCK_API_KEY: str = "api_key"
MOCK_CONNECTOR_ID: str = "notion"
MOCK_ACCOUNT_ID: str = "account_id"
MOCK_API_KEY = "api_key"
MOCK_CONNECTOR_ID = "notion"
MOCK_ACCOUNT_ID = "account_id"
def test_psychic_loader_initialization(
self, mock_psychic: MagicMock, mock_connector_id: MagicMock

View File

@@ -4,9 +4,9 @@ from langchain_community.document_loaders.rspace import RSpaceLoader
class TestRSpaceLoader(unittest.TestCase):
url: str = "https://community.researchspace.com"
api_key: str = "myapikey"
global_id: str = "SD12345"
url = "https://community.researchspace.com"
api_key = "myapikey"
global_id = "SD12345"
def test_valid_arguments(self) -> None:
loader = RSpaceLoader(

View File

@@ -70,7 +70,7 @@ class MockGradientaiPackage(MagicMock):
"""Mock Gradientai package."""
Gradient = MockGradient
__version__: str = "1.4.0"
__version__ = "1.4.0"
def test_gradient_llm_sync() -> None:

View File

@@ -5,7 +5,7 @@ from langchain_community.embeddings.ollama import OllamaEmbeddings
class MockResponse:
status_code: int = 200
status_code = 200
def json(self) -> dict:
return {"embedding": [1, 2, 3]}

View File

@@ -65,7 +65,7 @@ from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.output_parsers.base import OutputParserLike
@@ -1074,14 +1074,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
def with_structured_output(
self,
schema: Union[Dict, Type],
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
@@ -1089,25 +1089,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class (support added in 0.2.26),
- or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class.
.. versionchanged:: 0.2.26
Added support for TypedDict class.
include_raw:
If False then only the parsed structured output is returned. If
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
@@ -1115,19 +1103,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
A Runnable that takes any ChatModel input and returns as output:
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
an instance of ``schema`` (i.e., a Pydantic object).
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
Example: Pydantic schema (include_raw=False):
If schema is a dict then _DictOrPydantic is a dict.
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel
@@ -1147,7 +1138,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Pydantic schema (include_raw=True):
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel
@@ -1167,7 +1158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# 'parsing_error': None
# }
Example: Dict schema (include_raw=False):
Example: Function-calling, dict schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel
@@ -1203,7 +1194,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast(TypeBaseModel, schema)], first_tool_only=True
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]

View File

@@ -88,7 +88,6 @@ from langchain_core.runnables.config import (
run_in_executor,
)
from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.function_calling import _parse_google_docstring
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
@@ -141,12 +140,50 @@ def _parse_python_function_docstring(
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
return _parse_google_docstring(
docstring,
list(annotations),
error_on_invalid_docstring=error_on_invalid_docstring,
invalid_docstring_error = ValueError(
f"Found invalid Google-Style docstring for {function}."
)
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
if error_on_invalid_docstring:
filtered_annotations = {
arg for arg in annotations if arg not in (*(FILTERED_ARGS), "return")
}
if filtered_annotations and (
len(docstring_blocks) < 2 or not docstring_blocks[1].startswith("Args:")
):
raise (invalid_docstring_error)
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
if error_on_invalid_docstring:
raise (invalid_docstring_error)
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _validate_docstring_args_against_annotations(

View File

@@ -2,10 +2,7 @@
from __future__ import annotations
import collections
import inspect
import logging
import typing
import uuid
from typing import (
TYPE_CHECKING,
@@ -15,26 +12,27 @@ from typing import (
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
from typing_extensions import TypedDict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
PYTHON_TO_JSON_TYPES = {
"str": "string",
"int": "integer",
@@ -190,81 +188,6 @@ def convert_python_function_to_openai_function(
)
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
visited: Dict = {}
model = cast(
Type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
)
return convert_pydantic_to_openai_function(model)
_MAX_TYPED_DICT_RECURSION = 25
def _convert_any_typed_dicts_to_pydantic(
type_: Type,
*,
visited: Dict,
depth: int = 0,
) -> Type:
if type_ in visited:
return visited[type_]
elif depth >= _MAX_TYPED_DICT_RECURSION:
return type_
elif is_typeddict(type_):
typed_dict = type_
docstring = inspect.getdoc(typed_dict)
annotations_ = typed_dict.__annotations__
description, arg_descriptions = _parse_google_docstring(
docstring, list(annotations_)
)
fields: dict = {}
for arg, arg_type in annotations_.items():
if get_origin(arg_type) is Annotated:
annotated_args = get_args(arg_type)
new_arg_type = _convert_any_typed_dicts_to_pydantic(
annotated_args[0], depth=depth + 1, visited=visited
)
field_kwargs = {
k: v for k, v in zip(("default", "description"), annotated_args[1:])
}
if (field_desc := field_kwargs.get("description")) and not isinstance(
field_desc, str
):
raise ValueError(
f"Invalid annotation for field {arg}. Third argument to "
f"Annotated must be a string description, received value of "
f"type {type(field_desc)}."
)
elif arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
else:
pass
fields[arg] = (new_arg_type, Field(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
)
field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field(**field_kwargs))
model = create_model(typed_dict.__name__, **fields)
model.__doc__ = description
visited[typed_dict] = model
return model
elif (origin := get_origin(type_)) and (type_args := get_args(type_)):
subscriptable_origin = _py_38_safe_origin(origin)
type_args = tuple(
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
for arg in type_args
)
return subscriptable_origin[type_args]
else:
return type_
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
@@ -321,22 +244,22 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
def convert_to_openai_function(
function: Union[Dict[str, Any], Type, Callable, BaseTool],
function: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI function.
Args:
function: A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain
Tool object, or a Python function. If a dictionary is passed in, it is
assumed to already be a valid OpenAI function or a JSON schema with
top-level 'title' and 'description' keys specified.
function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
If a dictionary is passed in, it is assumed to already be a valid OpenAI
function or a JSON schema with top-level 'title' and 'description' keys
specified.
Returns:
A dict version of the passed in function which is compatible with the OpenAI
function-calling API.
A dict version of the passed in function which is compatible with the
OpenAI function-calling API.
Raises:
ValueError: If function is not in a supported format.
ValueError: If the function is not in a supported format.
"""
from langchain_core.tools import BaseTool
@@ -357,8 +280,6 @@ def convert_to_openai_function(
}
elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_openai_function(function))
elif is_typeddict(function):
return cast(Dict, _convert_typed_dict_to_openai_function(cast(Type, function)))
elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function))
elif callable(function):
@@ -484,69 +405,3 @@ def tool_example_to_messages(
for output, tool_call_dict in zip(tool_outputs, openai_tool_calls):
messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore
return messages
def _parse_google_docstring(
docstring: Optional[str],
args: List[str],
*,
error_on_invalid_docstring: bool = False,
) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
if docstring:
docstring_blocks = docstring.split("\n\n")
if error_on_invalid_docstring:
filtered_annotations = {
arg for arg in args if arg not in ("run_manager", "callbacks", "return")
}
if filtered_annotations and (
len(docstring_blocks) < 2 or not docstring_blocks[1].startswith("Args:")
):
raise ValueError("Found invalid Google-Style docstring.")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
if error_on_invalid_docstring:
raise ValueError("Found invalid Google-Style docstring.")
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _py_38_safe_origin(origin: Type) -> Type:
origin_map: Dict[Type, Any] = {
dict: Dict,
list: List,
tuple: Tuple,
set: Set,
collections.abc.Iterable: typing.Iterable,
collections.abc.Mapping: typing.Mapping,
collections.abc.Sequence: typing.Sequence,
collections.abc.MutableMapping: typing.MutableMapping,
}
return cast(Type, origin_map.get(origin, origin))

View File

@@ -9,7 +9,10 @@ from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
root_validator,
)
def get_pydantic_major_version() -> int:

2
libs/core/poetry.lock generated
View File

@@ -3048,4 +3048,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "982544b8c6dee38e7b8e7a124794f731881f38556ed1eb6c0c4f4182f55d2ecb"
content-hash = "74434689496616068607618a26e7de242b0bb5c3123463566f4976ab28cc747f"

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-core"
version = "0.2.26"
version = "0.2.25"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -13,7 +13,7 @@ repository = "https://github.com/langchain-ai/langchain"
[tool.mypy]
disallow_untyped_defs = "True"
exclude = [ "notebooks", "examples", "example_data", "langchain_core/pydantic", "tests/unit_tests/utils/test_function_calling.py",]
exclude = [ "notebooks", "examples", "example_data", "langchain_core/pydantic",]
[[tool.mypy.overrides]]
module = [ "numpy", "pytest",]
ignore_missing_imports = true
@@ -29,7 +29,6 @@ tenacity = "^8.1.0,!=8.4.0"
jsonpatch = "^1.33"
PyYAML = ">=5.3"
packaging = ">=23.2,<25"
typing-extensions = ">=4.7"
[[tool.poetry.dependencies.pydantic]]
version = ">=1,<3"
python = "<3.12.4"
@@ -74,17 +73,20 @@ optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = ">=1.10,<1.11"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
types-jinja2 = "^2.11.9"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
setuptools = "^67.6.1"
grandalf = "^0.8"
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
@@ -103,12 +105,15 @@ python = "<3.12"
version = "^1.26.0"
python = ">=3.12"
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.typing.dependencies.langchain-text-splitters]
path = "../text-splitters"
develop = true
[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../standard-tests"
develop = true

View File

@@ -7,13 +7,9 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
AnyStr,
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
)
from tests.unit_tests.stubs import AnyStr
def test_generic_fake_chat_model_invoke() -> None:
@@ -21,11 +17,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -33,11 +29,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == _AnyIdAIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == _AnyIdAIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_stream() -> None:
@@ -50,17 +46,17 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
@@ -145,9 +141,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
@@ -196,9 +192,9 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert tokens == ["hello", " ", "goodbye"]
assert len({chunk.id for chunk in results}) == 1
@@ -207,6 +203,8 @@ async def test_callback_handlers() -> None:
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah")
assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
content="blah", id=AnyStr()
)

View File

@@ -24,7 +24,7 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import AnyStr
@pytest.fixture
@@ -144,10 +144,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")]
assert chunks == [AIMessage(content="hello", id=AnyStr())]
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [_AnyIdAIMessage(content="hello")]
assert chunks == [AIMessage(content="hello", id=AnyStr())]
async def test_astream_implementation_fallback_to_stream() -> None:
@@ -182,15 +182,15 @@ async def test_astream_implementation_fallback_to_stream() -> None:
model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
assert type(model)._astream == BaseChatModel._astream
astream_chunks = [chunk async for chunk in model.astream("anything")]
assert astream_chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in astream_chunks}) == 1
@@ -227,8 +227,8 @@ async def test_astream_implementation_uses_astream() -> None:
model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [
_AnyIdAIMessageChunk(content="a"),
_AnyIdAIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1

View File

@@ -37,6 +37,7 @@ from langchain_core.language_models import (
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
@@ -89,7 +90,7 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import AnyStr
class FakeTracer(BaseTracer):
@@ -1824,7 +1825,7 @@ def test_prompt_with_chat_model(
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo")
) == AIMessage(content="foo", id=AnyStr())
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@@ -1849,8 +1850,8 @@ def test_prompt_with_chat_model(
],
dict(callbacks=[tracer]),
) == [
_AnyIdAIMessage(content="foo"),
_AnyIdAIMessage(content="foo"),
AIMessage(content="foo", id=AnyStr()),
AIMessage(content="foo", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@@ -1890,9 +1891,9 @@ def test_prompt_with_chat_model(
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [
_AnyIdAIMessageChunk(content="f"),
_AnyIdAIMessageChunk(content="o"),
_AnyIdAIMessageChunk(content="o"),
AIMessageChunk(content="f", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -1930,7 +1931,7 @@ async def test_prompt_with_chat_model_async(
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == _AnyIdAIMessage(content="foo")
) == AIMessage(content="foo", id=AnyStr())
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@@ -1955,8 +1956,8 @@ async def test_prompt_with_chat_model_async(
],
dict(callbacks=[tracer]),
) == [
_AnyIdAIMessage(content="foo"),
_AnyIdAIMessage(content="foo"),
AIMessage(content="foo", id=AnyStr()),
AIMessage(content="foo", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@@ -1999,9 +2000,9 @@ async def test_prompt_with_chat_model_async(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == [
_AnyIdAIMessageChunk(content="f"),
_AnyIdAIMessageChunk(content="o"),
_AnyIdAIMessageChunk(content="o"),
AIMessageChunk(content="f", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2668,7 +2669,7 @@ def test_prompt_with_chat_model_and_parser(
HumanMessage(content="What is your name?"),
]
)
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
assert tracer.runs == snapshot
@@ -2803,7 +2804,7 @@ What is your name?"""
),
]
)
assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4
@@ -2849,7 +2850,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
"llm": "i'm a textbot",
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -3059,7 +3060,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": _AnyIdAIMessage(content="i'm a chatbot"),
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
"llm": "i'm a textbot",
"passthrough": ChatPromptValue(
messages=[
@@ -3268,7 +3269,7 @@ async def test_map_astream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": _AnyIdAIMessageChunk(content="i")},
{"chat": AIMessageChunk(content="i", id=AnyStr())},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)

View File

@@ -30,7 +30,7 @@ from langchain_core.runnables import (
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import AnyStr
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@@ -461,7 +461,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@@ -470,7 +470,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@@ -479,7 +479,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@@ -488,7 +488,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"output": _AnyIdAIMessageChunk(content="hello world!")},
"data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
"event": "on_chat_model_end",
"metadata": {"a": "b"},
"name": "my_model",
@@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -535,7 +535,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -544,7 +544,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -560,7 +560,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"),
"message": AIMessage(
content="hello world!", id=AnyStr()
),
"text": "hello world!",
"type": "ChatGeneration",
}
@@ -578,7 +580,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
@@ -587,7 +589,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
@@ -625,7 +627,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -634,7 +636,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -643,7 +645,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -659,7 +661,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": _AnyIdAIMessage(content="hello world!"),
"message": AIMessage(
content="hello world!", id=AnyStr()
),
"text": "hello world!",
"type": "ChatGeneration",
}
@@ -677,7 +681,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "ai_dont_stream",
@@ -686,7 +690,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "ai_dont_stream",

View File

@@ -50,7 +50,7 @@ from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
from tests.unit_tests.stubs import AnyStr
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@@ -512,7 +512,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -521,7 +521,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -530,7 +530,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -540,7 +540,7 @@ async def test_astream_events_from_model() -> None:
},
{
"data": {
"output": _AnyIdAIMessageChunk(content="hello world!"),
"output": AIMessageChunk(content="hello world!", id=AnyStr()),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -596,7 +596,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -605,7 +605,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -614,7 +614,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -625,7 +625,7 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {
"input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"),
"output": AIMessage(content="hello world!", id=AnyStr()),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -635,7 +635,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
@@ -644,7 +644,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
@@ -682,7 +682,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -691,7 +691,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -700,7 +700,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
"name": "my_model",
@@ -711,7 +711,7 @@ async def test_astream_with_model_in_chain() -> None:
{
"data": {
"input": {"messages": [[HumanMessage(content="hello")]]},
"output": _AnyIdAIMessage(content="hello world!"),
"output": AIMessage(content="hello world!", id=AnyStr()),
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
@@ -721,7 +721,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": _AnyIdAIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "ai_dont_stream",
@@ -730,7 +730,7 @@ async def test_astream_with_model_in_chain() -> None:
"tags": [],
},
{
"data": {"output": _AnyIdAIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "ai_dont_stream",

View File

@@ -1,44 +1,6 @@
from typing import Any
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)
# The code below creates version of pydantic models
# that will work in unit tests with AnyStr as id field
# Please note that the `id` field is assigned AFTER the model is created
# to workaround an issue with pydantic ignoring the __eq__ method on
# subclassed strings.
def _AnyIdDocument(**kwargs: Any) -> Document:
"""Create a document with an id field."""
message = Document(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
"""Create a human with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()
return message

View File

@@ -1,44 +1,16 @@
# mypy: disable-error-code="annotation-unchecked"
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
MutableMapping,
MutableSet,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import TypedDict as TypingTypedDict
from typing import Any, Callable, Dict, List, Literal, Optional, Type
import pytest
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
from pydantic import Field as FieldV2Maybe # pydantic: ignore
from typing_extensions import (
Annotated as ExtensionsAnnotated,
)
from typing_extensions import (
TypedDict as ExtensionsTypedDict,
)
try:
from typing import Annotated as TypingAnnotated # type: ignore[attr-defined]
except ImportError:
TypingAnnotated = ExtensionsAnnotated
from typing_extensions import Annotated, TypedDict
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langchain_core.utils.function_calling import (
_convert_typed_dict_to_openai_function,
convert_to_openai_function,
tool_example_to_messages,
)
@@ -56,10 +28,10 @@ def pydantic() -> Type[BaseModel]:
@pytest.fixture()
def Annotated_function() -> Callable:
def annotated_function() -> Callable:
def dummy_function(
arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
arg1: Annotated[int, "foo"],
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
) -> None:
"""dummy function"""
pass
@@ -83,9 +55,9 @@ def function() -> Callable:
@pytest.fixture()
def runnable() -> Runnable:
class Args(ExtensionsTypedDict):
arg1: ExtensionsAnnotated[int, "foo"]
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"]
class Args(TypedDict):
arg1: Annotated[int, "foo"]
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"]
def dummy_function(input_dict: Args) -> None:
pass
@@ -134,60 +106,6 @@ def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]:
return dummy_function
@pytest.fixture()
def dummy_typing_typed_dict() -> Type:
class dummy_function(TypingTypedDict):
"""dummy function"""
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] # noqa: F722
return dummy_function
@pytest.fixture()
def dummy_typing_typed_dict_docstring() -> Type:
class dummy_function(TypingTypedDict):
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
arg1: int
arg2: Literal["bar", "baz"]
return dummy_function
@pytest.fixture()
def dummy_extensions_typed_dict() -> Type:
class dummy_function(ExtensionsTypedDict):
"""dummy function"""
arg1: ExtensionsAnnotated[int, ..., "foo"]
arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"]
return dummy_function
@pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> Type:
class dummy_function(ExtensionsTypedDict):
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
arg1: int
arg2: Literal["bar", "baz"]
return dummy_function
@pytest.fixture()
def json_schema() -> Dict:
return {
@@ -234,13 +152,9 @@ def test_convert_to_openai_function(
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
Annotated_function: Callable,
annotated_function: Callable,
dummy_pydantic: Type[BaseModel],
runnable: Runnable,
dummy_typing_typed_dict: Type,
dummy_typing_typed_dict_docstring: Type,
dummy_extensions_typed_dict: Type,
dummy_extensions_typed_dict_docstring: Type,
) -> None:
expected = {
"name": "dummy_function",
@@ -267,12 +181,8 @@ def test_convert_to_openai_function(
expected,
Dummy.dummy_function,
DummyWithClassMethod.dummy_function,
Annotated_function,
annotated_function,
dummy_pydantic,
dummy_typing_typed_dict,
dummy_typing_typed_dict_docstring,
dummy_extensions_typed_dict,
dummy_extensions_typed_dict_docstring,
):
actual = convert_to_openai_function(fn) # type: ignore
assert actual == expected
@@ -446,259 +356,3 @@ def test_tool_outputs() -> None:
},
]
assert messages[2].content == "Output1"
@pytest.mark.parametrize("use_extension_typed_dict", [True, False])
@pytest.mark.parametrize("use_extension_annotated", [True, False])
def test__convert_typed_dict_to_openai_function(
use_extension_typed_dict: bool, use_extension_annotated: bool
) -> None:
if use_extension_typed_dict:
TypedDict = ExtensionsTypedDict
else:
TypedDict = TypingTypedDict
if use_extension_annotated:
Annotated = TypingAnnotated
else:
Annotated = TypingAnnotated
class SubTool(TypedDict):
"""Subtool docstring"""
args: Annotated[Dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
class Tool(TypedDict):
"""Docstring
Args:
arg1: foo
"""
arg1: str
arg2: Union[int, str, bool]
arg3: Optional[List[SubTool]]
arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722
arg5: Annotated[Optional[float], None]
arg6: Annotated[
Optional[Sequence[Mapping[str, Tuple[Iterable[Any], SubTool]]]], []
]
arg7: Annotated[List[SubTool], ...]
arg8: Annotated[Tuple[SubTool], ...]
arg9: Annotated[Sequence[SubTool], ...]
arg10: Annotated[Iterable[SubTool], ...]
arg11: Annotated[Set[SubTool], ...]
arg12: Annotated[Dict[str, SubTool], ...]
arg13: Annotated[Mapping[str, SubTool], ...]
arg14: Annotated[MutableMapping[str, SubTool], ...]
arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore
expected = {
"name": "Tool",
"description": "Docstring",
"parameters": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "string"},
"arg2": {
"anyOf": [
{"type": "integer"},
{"type": "string"},
{"type": "boolean"},
]
},
"arg3": {
"type": "array",
"items": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg4": {
"description": "this does foo",
"enum": ["bar", "baz"],
"type": "string",
},
"arg5": {"type": "number"},
"arg6": {
"default": [],
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"type": "array",
"minItems": 2,
"maxItems": 2,
"items": [
{"type": "array", "items": {}},
{
"title": "SubTool",
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"title": "Args",
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
],
},
},
},
"arg7": {
"type": "array",
"items": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg8": {
"type": "array",
"minItems": 1,
"maxItems": 1,
"items": [
{
"title": "SubTool",
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"title": "Args",
"description": "this does bar",
"default": {},
"type": "object",
}
},
}
],
},
"arg9": {
"type": "array",
"items": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg10": {
"type": "array",
"items": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg11": {
"type": "array",
"items": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
"uniqueItems": True,
},
"arg12": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg13": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg14": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"type": "object",
"properties": {
"args": {
"description": "this does bar",
"default": {},
"type": "object",
}
},
},
},
"arg15": {"description": "flag", "default": False, "type": "boolean"},
},
"required": [
"arg1",
"arg2",
"arg3",
"arg4",
"arg7",
"arg8",
"arg9",
"arg10",
"arg11",
"arg12",
"arg13",
"arg14",
],
},
}
actual = _convert_typed_dict_to_openai_function(Tool)
assert actual == expected
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None:
class Tool(typed_dict):
arg1: MutableSet # Pydantic doesn't support
with pytest.raises(TypeError):
_convert_typed_dict_to_openai_function(Tool)

View File

@@ -10,7 +10,7 @@ from langchain_standard_tests.integration_tests.vectorstores import (
from langchain_core.documents import Document
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore
from tests.unit_tests.stubs import AnyStr, _AnyIdDocument
from tests.unit_tests.stubs import AnyStr
class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
# Check sync version
output = store.similarity_search("foo", k=1)
assert output == [_AnyIdDocument(page_content="foo")]
assert output == [Document(page_content="foo", id=AnyStr())]
# Check async version
output = await store.asimilarity_search("bar", k=2)
assert output == [
_AnyIdDocument(page_content="bar"),
_AnyIdDocument(page_content="baz"),
Document(page_content="bar", id=AnyStr()),
Document(page_content="baz", id=AnyStr()),
]
@@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
# make sure we can k > docstore size
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy")
assert output[0] == Document(page_content="foo", id=AnyStr())
assert output[1] == Document(page_content="foy", id=AnyStr())
# Check async version
output = await docsearch.amax_marginal_relevance_search(
"foo", k=10, lambda_mult=0.1
)
assert len(output) == len(texts)
assert output[0] == _AnyIdDocument(page_content="foo")
assert output[1] == _AnyIdDocument(page_content="foy")
assert output[0] == Document(page_content="foo", id=AnyStr())
assert output[1] == Document(page_content="foy", id=AnyStr())
async def test_inmemory_dump_load(tmp_path: Path) -> None:

View File

@@ -1,6 +1,6 @@
import json
import re
from typing import Pattern, Union
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -17,7 +17,7 @@ class ChatOutputParser(AgentOutputParser):
format_instructions: str = FORMAT_INSTRUCTIONS
"""Default formatting instructions"""
pattern: Pattern = re.compile(r"^.*?`{3}(?:json)?\n(.*?)`{3}.*?$", re.DOTALL)
pattern = re.compile(r"^.*?`{3}(?:json)?\n(.*?)`{3}.*?$", re.DOTALL)
"""Regex pattern to parse the output."""
def get_format_instructions(self) -> str:

View File

@@ -1,6 +1,6 @@
import json
import re
from typing import Pattern, Union
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -42,7 +42,7 @@ class ReActJsonSingleInputOutputParser(AgentOutputParser):
"""
pattern: Pattern = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL)
pattern = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL)
"""Regex pattern to parse the output."""
def get_format_instructions(self) -> str:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import re
from typing import Optional, Pattern, Union
from typing import Optional, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -23,7 +23,7 @@ class StructuredChatOutputParser(AgentOutputParser):
format_instructions: str = FORMAT_INSTRUCTIONS
"""Default formatting instructions"""
pattern: Pattern = re.compile(r"```(?:json\s+)?(\W.*?)```", re.DOTALL)
pattern = re.compile(r"```(?:json\s+)?(\W.*?)```", re.DOTALL)
"""Regex pattern to parse the output."""
def get_format_instructions(self) -> str:

View File

@@ -10,8 +10,7 @@ from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import Field, root_validator
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@@ -69,7 +68,7 @@ class _EmbeddingDistanceChainMixin(Chain):
embeddings: Embeddings = Field(default_factory=_embedding_factory)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@pre_init
@root_validator(pre=False)
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that the TikTok library is installed.

View File

@@ -8,8 +8,7 @@ from langchain_core.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import Field, root_validator
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@@ -64,7 +63,7 @@ class _RapidFuzzChainMixin(Chain):
"""Whether to normalize the score to a value between 0 and 1.
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
@pre_init
@root_validator
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate that the rapidfuzz library is installed.

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import root_validator
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key
@@ -82,7 +82,7 @@ class ConversationStringBufferMemory(BaseMemory):
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@pre_init
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that return messages is not True."""
if values.get("return_messages", False):

View File

@@ -6,8 +6,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
@@ -86,7 +85,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
buffer = self.buffer
return {self.memory_key: buffer}
@pre_init
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Union
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import root_validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
@@ -64,7 +64,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
)
return {self.memory_key: final_buffer}
@pre_init
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import root_validator
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
@@ -15,7 +15,7 @@ class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
def is_lc_serializable(cls) -> bool:
return True
@pre_init
@root_validator()
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the parsers."""
parsers = values["parsers"]

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import root_validator
class EnumOutputParser(BaseOutputParser[Enum]):
@@ -12,7 +12,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
enum: Type[Enum]
"""The enum to parse. Its values must be strings."""
@pre_init
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
enum = values["enum"]
if not all(isinstance(e.value, str) for e in enum):

View File

@@ -4,8 +4,7 @@ import numpy as np
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain_core.pydantic_v1 import Field, root_validator
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
@@ -46,7 +45,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
arbitrary_types_allowed = True
@pre_init
@root_validator()
def validate_params(cls, values: Dict) -> Dict:
"""Validate similarity parameters."""
if values["k"] is None and values["similarity_threshold"] is None:

Some files were not shown because too many files have changed in this diff Show More