Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-05 16:50:03 +00:00)
Compare commits
17 Commits
wfh/async_ ... v0.0.260
| Author | SHA1 | Date |
|---|---|---|
| | 96d064e305 | |
| | c2f46b2cdb | |
| | 808248049d | |
| | a6e6e9bb86 | |
| | 90579021f8 | |
| | 539672a7fd | |
| | 269f85b7b7 | |
| | 3adb1e12ca | |
| | b8df15cd64 | |
| | 4d72288487 | |
| | 3c6eccd701 | |
| | 7de6a1b78e | |
| | a2681f950d | |
| | 3f64b8a761 | |
| | 0a1be1d501 | |
| | e3056340da | |
| | 99b5a7226c | |
2  .github/workflows/scheduled_test.yml (vendored)
```diff
@@ -1,7 +1,7 @@
 name: Scheduled tests
 
 on:
-  scheduled:
+  schedule:
     - cron: '0 13 * * *'
 
 env:
```
```diff
@@ -12,7 +12,7 @@ Here are the agents available in LangChain.
 
 ### [Zero-shot ReAct](/docs/modules/agents/agent_types/react.html)
 
-This agent uses the [ReAct](https://arxiv.org/pdf/2205.00445.pdf) framework to determine which tool to use
+This agent uses the [ReAct](https://arxiv.org/pdf/2210.03629) framework to determine which tool to use
 based solely on the tool's description. Any number of tools can be provided.
 This agent requires that a description is provided for each tool.
 
```
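For context, a zero-shot ReAct agent of this era is typically constructed as in the short sketch below. This is an illustrative example only; the tools and model chosen here are assumptions, not part of the diff above.

```python
# Minimal sketch of a zero-shot ReAct agent (tool choice and model are illustrative).
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
tools = load_tools(["llm-math"], llm=llm)  # any tools with good descriptions work

agent = initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run("What is 7 raised to the 0.5 power?")
```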
```diff
@@ -556,6 +556,14 @@
     {
       "source": "/docs/integrations/llamacpp",
       "destination": "/docs/integrations/providers/llamacpp"
     },
+    {
+      "source": "/en/latest/integrations/log10.html",
+      "destination": "/docs/integrations/providers/log10"
+    },
+    {
+      "source": "/docs/integrations/log10",
+      "destination": "/docs/integrations/providers/log10"
+    },
     {
       "source": "/en/latest/integrations/mediawikidump.html",
       "destination": "/docs/integrations/providers/mediawikidump"
```
@@ -1648,6 +1648,186 @@ (notebook source: adds a "Fallbacks" section after the "Conversational Retrieval With Memory" cell)

## Fallbacks

With LCEL you can easily introduce fallbacks for any Runnable component, like an LLM.

```python
from langchain.chat_models import ChatOpenAI

bad_llm = ChatOpenAI(model_name="gpt-fake")
good_llm = ChatOpenAI(model_name="gpt-3.5-turbo")
llm = bad_llm.with_fallbacks([good_llm])

llm.invoke("Why did the chicken cross the road?")
```

```text
AIMessage(content='To get to the other side.', additional_kwargs={}, example=False)
```

Looking at the trace, we can see that the first model failed but the second succeeded, so we still got an output: https://smith.langchain.com/public/dfaf0bf6-d86d-43e9-b084-dd16a56df15c/r

We can add an arbitrary sequence of fallbacks, which will be executed in order:

```python
llm = bad_llm.with_fallbacks([bad_llm, bad_llm, good_llm])

llm.invoke("Why did the chicken cross the road?")
```

```text
AIMessage(content='To get to the other side.', additional_kwargs={}, example=False)
```

Trace: https://smith.langchain.com/public/c09efd01-3184-4369-a225-c9da8efcaf47/r

We can continue to use our Runnable with fallbacks the same way we use any other Runnable, meaning we can include it in sequences:

```python
from langchain.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You're a nice assistant who always includes a compliment in your response"),
        ("human", "Why did the {animal} cross the road"),
    ]
)
chain = prompt | llm
chain.invoke({"animal": "kangaroo"})
```

```text
AIMessage(content='To show off its incredible jumping skills! Kangaroos are truly amazing creatures.', additional_kwargs={}, example=False)
```

Trace: https://smith.langchain.com/public/ba03895f-f8bd-4c70-81b7-8b930353eabd/r

Note that since every sequence of Runnables is itself a Runnable, we can create fallbacks for whole sequences. We can also continue using the full interface, including asynchronous calls, batched calls, and streams:

```python
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You're a nice assistant who always includes a compliment in your response"),
        ("human", "Why did the {animal} cross the road"),
    ]
)
chat_model = ChatOpenAI(model_name="gpt-fake")

prompt_template = """Instructions: You should always include a compliment in your response.

Question: Why did the {animal} cross the road?"""
prompt = PromptTemplate.from_template(prompt_template)
llm = OpenAI()

bad_chain = chat_prompt | chat_model
good_chain = prompt | llm
chain = bad_chain.with_fallbacks([good_chain])
await chain.abatch([{"animal": "rabbit"}, {"animal": "turtle"}])
```

```text
["\n\nAnswer: The rabbit crossed the road to get to the other side. That's quite clever of him!",
 '\n\nAnswer: The turtle crossed the road to get to the other side. You must be pretty clever to come up with that riddle!']
```

Traces:
1. https://smith.langchain.com/public/ccd73236-9ae5-48a6-94b5-41210be18a46/r
2. https://smith.langchain.com/public/f43f608e-075c-45c7-bf73-b64e4d3f3082/r

@@ -1666,7 +1846,7 @@ (notebook metadata)

```diff
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.1"
+   "version": "3.9.1"
   }
  },
 "nbformat": 4,
```
@@ -74,6 +74,124 @@ (notebook source: adds a "Model Version" section)

## Model Version

Azure OpenAI responses contain a `model` property, which is the name of the model used to generate the response. However, unlike native OpenAI responses, it does not contain the version of the model, which is set on the deployment in Azure. This makes it tricky to know which version of the model was used to generate the response, which as a result can lead to e.g. a wrong total cost calculation with `OpenAICallbackHandler`.

To solve this problem, you can pass the `model_version` parameter to the `AzureChatOpenAI` class, which will be added to the model name in the LLM output. This way you can easily distinguish between different versions of the model.

```python
from langchain.callbacks import get_openai_callback
```

```python
BASE_URL = "https://{endpoint}.openai.azure.com"
API_KEY = "..."
DEPLOYMENT_NAME = "gpt-35-turbo"  # in Azure, this deployment has version 0613 - input and output tokens are counted separately
```

```python
model = AzureChatOpenAI(
    openai_api_base=BASE_URL,
    openai_api_version="2023-05-15",
    deployment_name=DEPLOYMENT_NAME,
    openai_api_key=API_KEY,
    openai_api_type="azure",
)
with get_openai_callback() as cb:
    model(
        [
            HumanMessage(
                content="Translate this sentence from English to French. I love programming."
            )
        ]
    )
    # without specifying the model version, a flat rate of 0.002 USD per 1k input and output tokens is used
    print(f"Total Cost (USD): ${format(cb.total_cost, '.6f')}")
```

```text
Total Cost (USD): $0.000054
```

We can provide the model version to the `AzureChatOpenAI` constructor. It will get appended to the model name returned by Azure OpenAI, and the cost will be counted correctly.

```python
model0613 = AzureChatOpenAI(
    openai_api_base=BASE_URL,
    openai_api_version="2023-05-15",
    deployment_name=DEPLOYMENT_NAME,
    openai_api_key=API_KEY,
    openai_api_type="azure",
    model_version="0613"
)
with get_openai_callback() as cb:
    model0613(
        [
            HumanMessage(
                content="Translate this sentence from English to French. I love programming."
            )
        ]
    )
    print(f"Total Cost (USD): ${format(cb.total_cost, '.6f')}")
```

```text
Total Cost (USD): $0.000044
```

@@ -92,7 +210,7 @@ (notebook metadata)

```diff
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.8.10"
   }
  },
 "nbformat": 4,
```
@@ -0,0 +1,67 @@ (new notebook)

# Rockset Chat Message History

This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history.

To begin, get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region in the Rockset [API reference](https://rockset.com/docs/rest-api#introduction).

```python
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
from rockset import RocksetClient, Regions

history = RocksetChatMessageHistory(
    session_id="MySession",
    client=RocksetClient(
        api_key="YOUR API KEY",
        host=Regions.usw2a1  # us-west-2 Oregon
    ),
    collection="langchain_demo",
    sync=True
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
```

The output should be something like:

```python
[
    HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False),
    AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)
]
```
104  docs/extras/integrations/providers/log10.mdx (new file)

@@ -0,0 +1,104 @@
# Log10

This page covers how to use [Log10](https://log10.io) within LangChain.

## What is Log10?

Log10 is an [open source](https://github.com/log10-io/log10) proxiless LLM data management and application development platform that lets you log, debug and tag your Langchain calls.

## Quick start

1. Create your free account at [log10.io](https://log10.io).
2. Add your `LOG10_TOKEN` and `LOG10_ORG_ID` from the Settings and Organization tabs respectively as environment variables.
3. Also add `LOG10_URL=https://log10.io` and your usual LLM API key (e.g. `OPENAI_API_KEY` or `ANTHROPIC_API_KEY`) to your environment.

## How to enable Log10 data management for Langchain

Integration with Log10 is a simple one-line `log10_callback` setup, as shown below:

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

from log10.langchain import Log10Callback
from log10.llm import Log10Config

log10_callback = Log10Callback(log10_config=Log10Config())

messages = [
    HumanMessage(content="You are a ping pong machine"),
    HumanMessage(content="Ping?"),
]

llm = ChatOpenAI(model_name="gpt-3.5-turbo", callbacks=[log10_callback])
```

[Log10 + Langchain + Logs docs](https://github.com/log10-io/log10/blob/main/logging.md#langchain-logger)

[More details + screenshots](https://log10.io/docs/logs), including instructions for self-hosting logs

## How to use tags with Log10

```python
from langchain import OpenAI
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

from log10.langchain import Log10Callback
from log10.llm import Log10Config

log10_callback = Log10Callback(log10_config=Log10Config())

messages = [
    HumanMessage(content="You are a ping pong machine"),
    HumanMessage(content="Ping?"),
]

llm = ChatOpenAI(model_name="gpt-3.5-turbo", callbacks=[log10_callback], temperature=0.5, tags=["test"])
completion = llm.predict_messages(messages, tags=["foobar"])
print(completion)

llm = ChatAnthropic(model="claude-2", callbacks=[log10_callback], temperature=0.7, tags=["baz"])
completion = llm.predict_messages(messages)
print(completion)

llm = OpenAI(model_name="text-davinci-003", callbacks=[log10_callback], temperature=0.5)
completion = llm.predict("You are a ping pong machine.\nPing?\n")
print(completion)
```

You can also intermix direct OpenAI calls and Langchain LLM calls:

```python
import os
from log10.load import log10, log10_session
import openai
from langchain import OpenAI

log10(openai)

with log10_session(tags=["foo", "bar"]):
    # Log a direct OpenAI call
    response = openai.Completion.create(
        model="text-ada-001",
        prompt="Where is the Eiffel Tower?",
        temperature=0,
        max_tokens=1024,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    print(response)

    # Log a call via Langchain
    llm = OpenAI(model_name="text-ada-001", temperature=0.5)
    response = llm.predict("You are a ping pong machine.\nPing?\n")
    print(response)
```

## How to debug Langchain calls

[Example of debugging](https://log10.io/docs/prompt_chain_debugging)

[More Langchain examples](https://github.com/log10-io/log10/tree/main/examples#langchain)
@@ -23,4 +23,11 @@ from langchain.vectorstores import Rockset

See a [usage example](/docs/integrations/document_loaders/rockset).
```python
from langchain.document_loaders import RocksetLoader
```

## Chat Message History

See a [usage example](/docs/integrations/memory/rockset_chat_message_history).
```python
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
```
@@ -81,17 +81,18 @@ (execution counts refreshed; the API-key cell now also keeps the key in a Python variable)

```python
os.environ["WEAVIATE_API_KEY"] = getpass.getpass("WEAVIATE_API_KEY:")
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]
```

@@ -106,7 +107,7 @@
@@ -123,7 +124,7 @@
@@ -133,7 +134,7 @@
@@ -144,7 +145,7 @@ (execution counts refreshed in these cells)

@@ -166,6 +167,53 @@ (adds an "Authentication" section after the `print(docs[0].page_content)` cell)

## Authentication

Weaviate instances have authentication enabled by default. You can use either a username/password combination or API key.

```python
import weaviate

client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))

# client = weaviate.Client(
#     url=WEAVIATE_URL,
#     auth_client_secret=weaviate.AuthClientPassword(
#         username = "WCS_USERNAME",  # Replace w/ your WCS username
#         password = "WCS_PASSWORD",  # Replace w/ your WCS password
#     ),
# )

vectorstore = Weaviate.from_documents(documents, embeddings, client=client, by_text=False)
```

```text
<langchain.vectorstores.weaviate.Weaviate object at 0x107f46550>
```

@@ -187,7 +235,7 @@ (execution count refreshed)

@@ -213,7 +261,7 @@ (heading typo fixed)

```diff
-    "# Persistance"
+    "# Persistence"
```

@@ -249,7 +297,7 @@
@@ -287,7 +335,7 @@
@@ -298,7 +346,7 @@
@@ -311,7 +359,7 @@
@@ -327,7 +375,7 @@
@@ -339,7 +387,7 @@ (execution counts cleared to null in these cells)
@@ -0,0 +1,440 @@ (new notebook)

# Parent Document Retriever

When splitting documents for retrieval, there are often conflicting desires:

1. You may want to have small documents, so that their embeddings can most accurately reflect their meaning. If too long, the embeddings can lose meaning.
2. You want to have long enough documents that the context of each chunk is retained.

The ParentDocumentRetriever strikes that balance by splitting and storing small chunks of data. During retrieval, it first fetches the small chunks but then looks up the parent ids for those chunks and returns those larger documents.

Note that "parent document" refers to the document that a small chunk originated from. This can either be the whole raw document OR a larger chunk.

```python
from langchain.retrievers import ParentDocumentRetriever
```

```python
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.storage import InMemoryStore
from langchain.document_loaders import TextLoader
```

```python
loaders = [
    TextLoader('../../paul_graham_essay.txt'),
    TextLoader('../../state_of_the_union.txt'),
]
docs = []
for l in loaders:
    docs.extend(l.load())
```

## Retrieving Full Documents

In this mode, we want to retrieve the full documents. Therefore, we only specify a child splitter.

```python
# This text splitter is used to create the child documents
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    collection_name="full_documents",
    embedding_function=OpenAIEmbeddings()
)
# The storage layer for the parent documents
store = InMemoryStore()
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)
```

```python
retriever.add_documents(docs)
```

This should yield two keys, because we added two documents.

```python
list(store.yield_keys())
```

```text
['05fe8d8a-bf60-4f87-b576-4351b23df266',
 '571cc9e5-9ef7-4f6c-b800-835c83a1858b']
```

Let's now call the vectorstore search functionality - we should see that it returns small chunks (since we're storing the small chunks).

```python
sub_docs = vectorstore.similarity_search("justice breyer")
```

```python
print(sub_docs[0].page_content)
```

```text
Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.
```

Let's now retrieve from the overall retriever. This should return large documents, since it returns the documents where the smaller chunks are located.

```python
retrieved_docs = retriever.get_relevant_documents("justice breyer")
```

```python
len(retrieved_docs[0].page_content)
```

```text
38539
```

## Retrieving Larger Chunks

Sometimes, the full documents can be too big to retrieve as is. In that case, what we really want to do is first split the raw documents into larger chunks, and then split those into smaller chunks. We then index the smaller chunks, but on retrieval we retrieve the larger chunks (but still not the full documents).

```python
# This text splitter is used to create the parent documents
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
# This text splitter is used to create the child documents
# It should create documents smaller than the parent
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# The vectorstore to use to index the child chunks
vectorstore = Chroma(collection_name="split_parents", embedding_function=OpenAIEmbeddings())
# The storage layer for the parent documents
store = InMemoryStore()
```

```python
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
```

```python
retriever.add_documents(docs)
```

We can see that there are many more than two documents now - these are the larger chunks.

```python
len(list(store.yield_keys()))
```

```text
66
```

Let's make sure the underlying vectorstore still retrieves the small chunks.

```python
sub_docs = vectorstore.similarity_search("justice breyer")
```

```python
print(sub_docs[0].page_content)
```

```text
Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.
```

```python
retrieved_docs = retriever.get_relevant_documents("justice breyer")
```

```python
len(retrieved_docs[0].page_content)
```

```text
1849
```

```python
print(retrieved_docs[0].page_content)
```

```text
In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections.

We cannot let this happen.

Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.

A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.

And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.

We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling.

We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers.

We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.

We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.
```
```diff
@@ -475,7 +475,8 @@ class Agent(BaseSingleActionAgent):
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
         full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
-        return self.output_parser.parse(full_output)
+        agent_output = await self.output_parser.aparse(full_output)
+        return agent_output
 
     def get_full_inputs(
         self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
```
```diff
@@ -31,8 +31,19 @@ MODEL_COST_PER_1K_TOKENS = {
     "gpt-3.5-turbo-0613-completion": 0.002,
     "gpt-3.5-turbo-16k-completion": 0.004,
     "gpt-3.5-turbo-16k-0613-completion": 0.004,
+    # Azure GPT-35 input
+    "gpt-35-turbo": 0.0015,  # Azure OpenAI version of ChatGPT
+    "gpt-35-turbo-0301": 0.0015,  # Azure OpenAI version of ChatGPT
+    "gpt-35-turbo-0613": 0.0015,
+    "gpt-35-turbo-16k": 0.003,
+    "gpt-35-turbo-16k-0613": 0.003,
+    # Azure GPT-35 output
+    "gpt-35-turbo-completion": 0.002,  # Azure OpenAI version of ChatGPT
+    "gpt-35-turbo-0301-completion": 0.002,  # Azure OpenAI version of ChatGPT
+    "gpt-35-turbo-0613-completion": 0.002,
+    "gpt-35-turbo-16k-completion": 0.004,
+    "gpt-35-turbo-16k-0613-completion": 0.004,
     # Others
-    "gpt-35-turbo": 0.002,  # Azure OpenAI version of ChatGPT
     "text-ada-001": 0.0004,
     "ada": 0.0004,
     "text-babbage-001": 0.0005,
@@ -69,7 +80,9 @@ def standardize_model_name(
     if "ft-" in model_name:
         return model_name.split(":")[0] + "-finetuned"
     elif is_completion and (
-        model_name.startswith("gpt-4") or model_name.startswith("gpt-3.5")
+        model_name.startswith("gpt-4")
+        or model_name.startswith("gpt-3.5")
+        or model_name.startswith("gpt-35")
     ):
        return model_name + "-completion"
     else:
```
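As a quick illustration of how the new Azure entries are consumed, here is a small sketch using the two names shown in this diff (the token count is made up for the example):

```python
# Illustrative only: look up the cost of an Azure "gpt-35" completion.
from langchain.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS, standardize_model_name

name = standardize_model_name("gpt-35-turbo-0613", is_completion=True)
print(name)  # gpt-35-turbo-0613-completion
print(MODEL_COST_PER_1K_TOKENS[name] * 500 / 1000)  # cost of 500 completion tokens: 0.001 USD
```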
```diff
@@ -47,6 +47,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
             parent_run = self.run_map[str(run.parent_run_id)]
             if parent_run:
                 self._add_child_run(parent_run, run)
+                parent_run.child_execution_order = max(
+                    parent_run.child_execution_order, run.child_execution_order
+                )
             else:
                 logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
         self.run_map[str(run.id)] = run
@@ -131,7 +134,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
-            raise TracerException("No LLM Run found to be traced")
+            raise TracerException(f"No LLM Run found to be traced for {run_id}")
         llm_run.events.append(
             {
                 "name": "new_token",
@@ -183,7 +186,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
-            raise TracerException("No LLM Run found to be traced")
+            raise TracerException(f"No LLM Run found to be traced for {run_id}")
         llm_run.outputs = response.dict()
         for i, generations in enumerate(response.generations):
             for j, generation in enumerate(generations):
@@ -211,7 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
-            raise TracerException("No LLM Run found to be traced")
+            raise TracerException(f"No LLM Run found to be traced for {run_id}")
         llm_run.error = repr(error)
         llm_run.end_time = datetime.utcnow()
         llm_run.events.append({"name": "error", "time": llm_run.end_time})
@@ -254,18 +257,25 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self._on_chain_start(chain_run)
 
     def on_chain_end(
-        self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
+        self,
+        outputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        inputs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> None:
         """End a trace for a chain run."""
         if not run_id:
             raise TracerException("No run_id provided for on_chain_end callback.")
         chain_run = self.run_map.get(str(run_id))
         if chain_run is None:
-            raise TracerException("No chain Run found to be traced")
+            raise TracerException(f"No chain Run found to be traced for {run_id}")
 
         chain_run.outputs = outputs
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "end", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)
 
@@ -273,6 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         error: Union[Exception, KeyboardInterrupt],
         *,
+        inputs: Optional[Dict[str, Any]] = None,
         run_id: UUID,
         **kwargs: Any,
     ) -> None:
@@ -281,11 +292,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
             raise TracerException("No run_id provided for on_chain_error callback.")
         chain_run = self.run_map.get(str(run_id))
         if chain_run is None:
-            raise TracerException("No chain Run found to be traced")
+            raise TracerException(f"No chain Run found to be traced for {run_id}")
 
         chain_run.error = repr(error)
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "error", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
 
@@ -329,7 +342,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             raise TracerException("No run_id provided for on_tool_end callback.")
         tool_run = self.run_map.get(str(run_id))
         if tool_run is None or tool_run.run_type != "tool":
-            raise TracerException("No tool Run found to be traced")
+            raise TracerException(f"No tool Run found to be traced for {run_id}")
 
         tool_run.outputs = {"output": output}
         tool_run.end_time = datetime.utcnow()
@@ -349,7 +362,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             raise TracerException("No run_id provided for on_tool_error callback.")
         tool_run = self.run_map.get(str(run_id))
         if tool_run is None or tool_run.run_type != "tool":
-            raise TracerException("No tool Run found to be traced")
+            raise TracerException(f"No tool Run found to be traced for {run_id}")
 
         tool_run.error = repr(error)
         tool_run.end_time = datetime.utcnow()
@@ -404,7 +417,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             raise TracerException("No run_id provided for on_retriever_error callback.")
         retrieval_run = self.run_map.get(str(run_id))
         if retrieval_run is None or retrieval_run.run_type != "retriever":
-            raise TracerException("No retriever Run found to be traced")
+            raise TracerException(f"No retriever Run found to be traced for {run_id}")
 
         retrieval_run.error = repr(error)
         retrieval_run.end_time = datetime.utcnow()
@@ -420,7 +433,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             raise TracerException("No run_id provided for on_retriever_end callback.")
         retrieval_run = self.run_map.get(str(run_id))
         if retrieval_run is None or retrieval_run.run_type != "retriever":
-            raise TracerException("No retriever Run found to be traced")
+            raise TracerException(f"No retriever Run found to be traced for {run_id}")
         retrieval_run.outputs = {"documents": documents}
         retrieval_run.end_time = datetime.utcnow()
         retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
```
```diff
@@ -1,9 +1,11 @@
 """Base interface that all chains should implement."""
+import asyncio
 import inspect
 import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -55,18 +57,40 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
         """
 
     def invoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
-        return self(input, **(config or {}))
+        config = config or {}
+        return self(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         if type(self)._acall == Chain._acall:
             # If the chain does not implement async, fall back to default implementation
-            return await super().ainvoke(input, config)
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self.invoke, input, config, **kwargs)
+            )
 
-        return await self.acall(input, **(config or {}))
+        config = config or {}
+        return await self.acall(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
```
```diff
@@ -3,6 +3,8 @@ import functools
 import logging
 from typing import Any, Awaitable, Callable, Dict, List, Optional
 
+from pydantic import Field
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -27,9 +29,11 @@ class TransformChain(Chain):
     """The keys expected by the transform's input dictionary."""
     output_variables: List[str]
     """The keys returned by the transform's output dictionary."""
-    transform: Callable[[Dict[str, str]], Dict[str, str]]
+    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
     """The transform function."""
-    atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
+    atransform_cb: Optional[
+        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
+    ] = Field(None, alias="atransform")
     """The async coroutine transform function."""
 
     @staticmethod
@@ -62,18 +66,18 @@
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
-        return self.transform(inputs)
+        return self.transform_cb(inputs)
 
     async def _acall(
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self.atransform is not None:
-            return await self.atransform(inputs)
+        if self.atransform_cb is not None:
+            return await self.atransform_cb(inputs)
         else:
             self._log_once(
                 "TransformChain's atransform is not provided, falling"
                 " back to synchronous transform"
             )
-            return self.transform(inputs)
+            return self.transform_cb(inputs)
```
```diff
@@ -40,11 +40,20 @@ class AzureChatOpenAI(ChatOpenAI):
 
     Be aware the API version may change.
 
+    You can also specify the version of the model using ``model_version`` constructor
+    parameter, as Azure OpenAI doesn't return model version with the response.
+
+    Default is empty. When you specify the version, it will be appended to the
+    model name in the response. Setting correct version will help you to calculate the
+    cost properly. Model version is not validated, so make sure you set it correctly
+    to get the correct cost.
+
     Any parameters that are valid to be passed to the openai.create call can be passed
     in, even if not explicitly saved on this class.
     """
 
     deployment_name: str = ""
+    model_version: str = ""
     openai_api_type: str = ""
     openai_api_base: str = ""
     openai_api_version: str = ""
@@ -137,7 +146,19 @@
         for res in response["choices"]:
             if res.get("finish_reason", None) == "content_filter":
                 raise ValueError(
-                    "Azure has not provided the response due to a content"
-                    " filter being triggered"
+                    "Azure has not provided the response due to a content filter "
+                    "being triggered"
                 )
-        return super()._create_chat_result(response)
+        chat_result = super()._create_chat_result(response)
+
+        if "model" in response:
+            model = response["model"]
+            if self.model_version:
+                model = f"{model}-{self.model_version}"
+
+            if chat_result.llm_output is not None and isinstance(
+                chat_result.llm_output, dict
+            ):
+                chat_result.llm_output["model_name"] = model
+
+        return chat_result
```
```diff
@@ -103,12 +103,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
+        config = config or {}
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                    [self._convert_input(input)],
+                    stop=stop,
+                    callbacks=config.get("callbacks"),
+                    tags=config.get("tags"),
+                    metadata=config.get("metadata"),
+                    **kwargs,
                 ).generations[0][0],
             ).message,
         )
@@ -127,8 +133,14 @@
             None, partial(self.invoke, input, config, stop=stop, **kwargs)
         )
 
+        config = config or {}
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
```
```diff
@@ -1,9 +1,13 @@
 """Fake ChatModel for testing purposes."""
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import SimpleChatModel
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import AIMessageChunk, BaseMessage
+from langchain.schema.output import ChatGenerationChunk
 
 
 class FakeListChatModel(SimpleChatModel):
@@ -31,6 +35,36 @@ class FakeListChatModel(SimpleChatModel):
             self.i = 0
         return response
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[CallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return {"responses": self.responses}
```
```diff
@@ -1,10 +1,9 @@
 """Loads local airbyte json files."""
 from typing import Any, Callable, Iterator, List, Mapping, Optional
 
-from libs.langchain.langchain.utils.utils import guard_import
-
 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
+from langchain.utils.utils import guard_import
 
 RecordHandler = Callable[[Any, Optional[str]], Document]
```
```diff
@@ -5,6 +5,7 @@ the sequence of actions taken and their outcomes. It uses a language model
 chain (LLMChain) to generate the reasoning and scores.
 """
 
+import re
 from typing import (
     Any,
     Dict,
@@ -74,15 +75,24 @@ class TrajectoryOutputParser(BaseOutputParser):
 
         reasoning, score_str = reasoning.strip(), score_str.strip()
 
-        score_str = next(
-            (char for char in score_str if char.isdigit()), "0"
-        )  # Scan for first digit
-
-        if not 1 <= int(score_str) <= 5:
+        # Use regex to extract the score.
+        # This will get the number in the string, even if it is a float or more than 10.
+        # E.g. "Score: 1" will return 1, "Score: 3.5" will return 3.5, and
+        # "Score: 10" will return 10.
+        # The score should be an integer digit in the range 1-5.
+        _score = re.search(r"(\d+(\.\d+)?)", score_str)
+        # If the score is not found or is a float, raise an exception.
+        if _score is None or "." in _score.group(1):
+            raise OutputParserException(
+                f"Score is not an integer digit in the range 1-5: {text}"
+            )
+        score = int(_score.group(1))
+        # If the score is not in the range 1-5, raise an exception.
+        if not 1 <= score <= 5:
             raise OutputParserException(
                 f"Score is not a digit in the range 1-5: {text}"
             )
-        normalized_score = (int(score_str) - 1) / 4
+        normalized_score = (score - 1) / 4
         return TrajectoryEval(score=normalized_score, reasoning=reasoning)
```
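A standalone illustration of the new score extraction follows; this snippet just exercises the regex on a few example strings and is not the evaluator itself:

```python
import re

for text in ["Score: 1", "Score: 3.5", "Score: 10"]:
    match = re.search(r"(\d+(\.\d+)?)", text)
    print(text, "->", match.group(1))
# 1, 3.5 and 10 are all captured; the parser then rejects floats and values outside 1-5.
```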
@@ -1,6 +1,5 @@
"""Interfaces to be implemented by general evaluators."""
from __future__ import annotations
import asyncio

import logging
from abc import ABC, abstractmethod
@@ -169,13 +168,9 @@ class StringEvaluator(_EvalArgsMixin, ABC):
            - value: the string value of the evaluation, if applicable.
            - reasoning: the reasoning for the evaluation, if applicable.
        """  # noqa: E501
        return await asyncio.get_running_loop().run_in_executor(
            None,
            self._evaluate_strings,
            prediction=prediction,
            reference=reference,
            input=input,
            **kwargs,
        raise NotImplementedError(
            f"{self.__class__.__name__} hasn't implemented an async "
            "aevaluate_strings method."
        )

    def evaluate_strings(
@@ -270,14 +265,9 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
        Returns:
            dict: A dictionary containing the preference, scores, and/or other information.
        """  # noqa: E501
        return await asyncio.get_running_loop().run_in_executor(
            None,
            self._evaluate_string_pairs,
            prediction=prediction,
            prediction_b=prediction_b,
            reference=reference,
            input=input,
            **kwargs,
        raise NotImplementedError(
            f"{self.__class__.__name__} hasn't implemented an async "
            "aevaluate_string_pairs method."
        )

    def evaluate_string_pairs(
@@ -391,14 +381,9 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
        Returns:
            dict: The evaluation result.
        """
        raise asyncio.get_running_loop().run_in_executor(
            None,
            self._evaluate_agent_trajectory,
            prediction=prediction,
            agent_trajectory=agent_trajectory,
            input=input,
            reference=reference,
            **kwargs,
        raise NotImplementedError(
            f"{self.__class__.__name__} hasn't implemented an async "
            "aevaluate_agent_trajectory method."
        )

    def evaluate_agent_trajectory(
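With the executor-based defaults removed, the async evaluator entry points now raise NotImplementedError unless a subclass provides them. A minimal sketch of a custom string evaluator under that assumption; the class below is hypothetical and not part of this diff:

from typing import Any, Optional

from langchain.evaluation.schema import StringEvaluator


class ExactMatchEvaluator(StringEvaluator):
    """Scores 1.0 when the prediction equals the reference, else 0.0."""

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        return {"score": float(prediction == reference)}

    # Without an explicit async override, the async evaluation path changed in this
    # hunk raises NotImplementedError rather than silently running in an executor.


evaluator = ExactMatchEvaluator()
print(evaluator.evaluate_strings(prediction="4", reference="4"))  # {'score': 1.0}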
@@ -219,9 +219,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = config or {}
        return (
            self.generate_prompt(
                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                **kwargs,
            )
            .generations[0][0]
            .text
@@ -241,8 +247,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
            None, partial(self.invoke, input, config, stop=stop, **kwargs)
        )

        config = config or {}
        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            **kwargs,
        )
        return llm_result.generations[0][0].text
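The two hunks above stop splatting the whole RunnableConfig into generate_prompt and instead forward only callbacks, tags, and metadata. In practice, per-call tracing settings ride along with invoke; a hedged sketch, with the model choice and handler as assumptions:

from langchain.callbacks import StdOutCallbackHandler
from langchain.llms import OpenAI  # any BaseLLM subclass behaves the same way

llm = OpenAI()

# callbacks, tags and metadata set on the config are forwarded to generate_prompt
text = llm.invoke(
    "Say hello",
    config={
        "callbacks": [StdOutCallbackHandler()],
        "tags": ["invoke-demo"],
        "metadata": {"source": "docs-example"},
    },
)
print(text)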
@@ -1,10 +1,12 @@
from typing import Any, List, Mapping, Optional
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.runnable import RunnableConfig


class FakeListLLM(LLM):
@@ -51,3 +53,29 @@ class FakeListLLM(LLM):
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"responses": self.responses}


class FakeStreamingListLLM(FakeListLLM):
    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        result = self.invoke(input, config)
        for c in result:
            yield c

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        result = await self.ainvoke(input, config)
        for c in result:
            yield c
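FakeStreamingListLLM computes the full canned response first and then re-emits it one character at a time, which is enough to exercise downstream streaming logic in tests. A short sketch, assuming the responses field inherited from FakeListLLM:

fake = FakeStreamingListLLM(responses=["streamed!"])

chunks = list(fake.stream("ignored prompt"))
assert chunks == list("streamed!")
assert "".join(chunks) == "streamed!"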
@@ -12,6 +12,7 @@ from langchain.memory.chat_message_histories.momento import MomentoChatMessageHi
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import (
    StreamlitChatMessageHistory,
@@ -29,6 +30,7 @@ __all__ = [
    "MongoDBChatMessageHistory",
    "PostgresChatMessageHistory",
    "RedisChatMessageHistory",
    "RocksetChatMessageHistory",
    "SQLChatMessageHistory",
    "StreamlitChatMessageHistory",
    "ZepChatMessageHistory",
@@ -0,0 +1,258 @@
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Any, Callable, List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
|
||||
class RocksetChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Uses Rockset to store chat messages.
|
||||
|
||||
To use, ensure that the `rockset` python package is installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.memory.chat_message_histories import (
|
||||
RocksetChatMessageHistory
|
||||
)
|
||||
from rockset import RocksetClient
|
||||
|
||||
history = RocksetChatMessageHistory(
|
||||
session_id="MySession",
|
||||
client=RocksetClient(),
|
||||
collection="langchain_demo",
|
||||
sync=True
|
||||
)
|
||||
|
||||
history.add_user_message("hi!")
|
||||
history.add_ai_message("whats up?")
|
||||
|
||||
print(history.messages)
|
||||
"""
|
||||
|
||||
# You should set these values based on your VI.
|
||||
# These values are configured for the typical
|
||||
# free VI. Read more about VIs here:
|
||||
# https://rockset.com/docs/instances
|
||||
SLEEP_INTERVAL_MS = 5
|
||||
ADD_TIMEOUT_MS = 5000
|
||||
CREATE_TIMEOUT_MS = 20000
|
||||
|
||||
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
|
||||
"""Sleeps until meth() evaluates to true. Passes kwargs into
|
||||
meth.
|
||||
"""
|
||||
start = datetime.now()
|
||||
while not method(**method_params):
|
||||
curr = datetime.now()
|
||||
if (curr - start).total_seconds() * 1000 > timeout:
|
||||
raise TimeoutError(f"{method} timed out at {timeout} ms")
|
||||
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
|
||||
|
||||
def _query(self, query: str, **query_params: Any) -> List[Any]:
|
||||
"""Executes an SQL statement and returns the result
|
||||
Args:
|
||||
- query: The SQL string
|
||||
- **query_params: Parameters to pass into the query
|
||||
"""
|
||||
return self.client.sql(query, params=query_params).results
|
||||
|
||||
def _create_collection(self) -> None:
|
||||
"""Creates a collection for this message history"""
|
||||
self.client.Collections.create_s3_collection(
|
||||
name=self.collection, workspace=self.workspace
|
||||
)
|
||||
|
||||
def _collection_exists(self) -> bool:
|
||||
"""Checks whether a collection exists for this message history"""
|
||||
try:
|
||||
self.client.Collections.get(collection=self.collection)
|
||||
except self.rockset.exceptions.NotFoundException:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _collection_is_ready(self) -> bool:
|
||||
"""Checks whether the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
return (
|
||||
self.client.Collections.get(collection=self.collection).data.status
|
||||
== "READY"
|
||||
)
|
||||
|
||||
def _document_exists(self) -> bool:
|
||||
return (
|
||||
len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT 1
|
||||
FROM {self.location}
|
||||
WHERE _id=:session_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
!= 0
|
||||
)
|
||||
|
||||
def _wait_until_collection_created(self) -> None:
|
||||
"""Sleeps until the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
self._wait_until(
|
||||
lambda: self._collection_is_ready(),
|
||||
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
|
||||
)
|
||||
|
||||
def _wait_until_message_added(self, message_id: str) -> None:
|
||||
"""Sleeps until a message is added to the messages list"""
|
||||
self._wait_until(
|
||||
lambda message_id: len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST((
|
||||
SELECT {self.messages_key}
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
)) AS message
|
||||
WHERE message.data.additional_kwargs.id = :message_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
message_id=message_id,
|
||||
),
|
||||
)
|
||||
!= 0,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def _create_empty_doc(self) -> None:
|
||||
"""Creates or replaces a document for this message history with no
|
||||
messages"""
|
||||
self.client.Documents.add_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[{"_id": self.session_id, self.messages_key: []}],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client: Any,
|
||||
collection: str,
|
||||
workspace: str = "commons",
|
||||
messages_key: str = "messages",
|
||||
sync: bool = False,
|
||||
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
|
||||
) -> None:
|
||||
"""Constructs a new RocksetChatMessageHistory.
|
||||
|
||||
Args:
|
||||
- session_id: The ID of the chat session
|
||||
- client: The RocksetClient object to use to query
|
||||
- collection: The name of the collection to use to store chat
|
||||
messages. If a collection with the given name
|
||||
does not exist in the workspace, it is created.
|
||||
- workspace: The workspace containing `collection`. Defaults
|
||||
to `"commons"`
|
||||
- messages_key: The DB column containing message history.
|
||||
Defaults to `"messages"`
|
||||
- sync: Whether to wait for messages to be added. Defaults
|
||||
to `False`. NOTE: setting this to `True` will slow
|
||||
down performance.
|
||||
- message_uuid_method: The method that generates message IDs.
|
||||
If set, all messages will have an `id` field within the
|
||||
`additional_kwargs` property. If this param is not set
|
||||
and `sync` is `False`, message IDs will not be created.
|
||||
If this param is not set and `sync` is `True`, the
|
||||
`uuid.uuid4` method will be used to create message IDs.
|
||||
"""
|
||||
try:
|
||||
import rockset
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rockset client python package. "
|
||||
"Please install it with `pip install rockset`."
|
||||
)
|
||||
|
||||
if not isinstance(client, rockset.RocksetClient):
|
||||
raise ValueError(
|
||||
f"client should be an instance of rockset.RocksetClient, "
|
||||
f"got {type(client)}"
|
||||
)
|
||||
|
||||
self.session_id = session_id
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
self.workspace = workspace
|
||||
self.location = f'"{self.workspace}"."{self.collection}"'
|
||||
self.rockset = rockset
|
||||
self.messages_key = messages_key
|
||||
self.message_uuid_method = message_uuid_method
|
||||
self.sync = sync
|
||||
|
||||
if not self._collection_exists():
|
||||
self._create_collection()
|
||||
self._wait_until_collection_created()
|
||||
self._create_empty_doc()
|
||||
elif not self._document_exists():
|
||||
self._create_empty_doc()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Messages in this chat history."""
|
||||
return messages_from_dict(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST ((
|
||||
SELECT "{self.messages_key}"
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
))
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the history.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
if self.sync and "id" not in message.additional_kwargs:
|
||||
message.additional_kwargs["id"] = self.message_uuid_method()
|
||||
self.client.Documents.patch_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[
|
||||
self.rockset.model.patch_document.PatchDocument(
|
||||
id=self.session_id,
|
||||
patch=[
|
||||
self.rockset.model.patch_operation.PatchOperation(
|
||||
op="ADD",
|
||||
path=f"/{self.messages_key}/-",
|
||||
value=_message_to_dict(message),
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
if self.sync:
|
||||
self._wait_until_message_added(message.additional_kwargs["id"])
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Removes all messages from the chat history"""
|
||||
self._create_empty_doc()
|
||||
if self.sync:
|
||||
self._wait_until(
|
||||
lambda: not self.messages,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
)
|
||||
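RocksetChatMessageHistory slots into the existing memory classes the same way as the other chat-message-history backends. A hedged sketch; a reachable Rockset collection and client credentials are assumed:

from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
from rockset import RocksetClient

history = RocksetChatMessageHistory(
    session_id="user-42",
    client=RocksetClient(),  # construct with your api key / host as needed
    collection="langchain_demo",
    sync=True,  # wait for writes so messages are immediately queryable
)

# Use the history as the backing store for a standard memory object
memory = ConversationBufferMemory(chat_memory=history, return_messages=True)
memory.chat_memory.add_user_message("hi!")
memory.chat_memory.add_ai_message("whats up?")
print(memory.load_memory_variables({}))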
@@ -40,6 +40,7 @@ from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.retrievers.re_phraser import RePhraseQueryRetriever
@@ -90,4 +91,5 @@ __all__ = [
    "RePhraseQueryRetriever",
    "WebResearchRetriever",
    "EnsembleRetriever",
    "ParentDocumentRetriever",
]

139
libs/langchain/langchain/retrievers/parent_document_retriever.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.storage import BaseStore
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
class ParentDocumentRetriever(BaseRetriever):
|
||||
"""Fetches small chunks, then fetches their parent documents.
|
||||
|
||||
When splitting documents for retrieval, there are often conflicting desires:
|
||||
|
||||
1. You may want to have small documents, so that their embeddings can most
|
||||
accurately reflect their meaning. If too long, then the embeddings can
|
||||
lose meaning.
|
||||
2. You want to have long enough documents that the context of each chunk is
|
||||
retained.
|
||||
|
||||
The ParentDocumentRetriever strikes that balance by splitting and storing
|
||||
small chunks of data. During retrieval, it first fetches the small chunks
|
||||
but then looks up the parent ids for those chunks and returns those larger
|
||||
documents.
|
||||
|
||||
Note that "parent document" refers to the document that a small chunk
|
||||
originated from. This can either be the whole raw document OR a larger
|
||||
chunk.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# Imports
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.storage import InMemoryStore
|
||||
|
||||
# This text splitter is used to create the parent documents
|
||||
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
|
||||
# This text splitter is used to create the child documents
|
||||
# It should create documents smaller than the parent
|
||||
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||
# The vectorstore to use to index the child chunks
|
||||
vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
|
||||
# The storage layer for the parent documents
|
||||
store = InMemoryStore()
|
||||
|
||||
# Initialize the retriever
|
||||
retriever = ParentDocumentRetriever(
|
||||
vectorstore=vectorstore,
|
||||
docstore=store,
|
||||
child_splitter=child_splitter,
|
||||
parent_splitter=parent_splitter,
|
||||
)
|
||||
"""
|
||||
|
||||
vectorstore: VectorStore
|
||||
"""The underlying vectorstore to use to store small chunks
|
||||
and their embedding vectors"""
|
||||
docstore: BaseStore[str, Document]
|
||||
"""The storage layer for the parent documents"""
|
||||
child_splitter: TextSplitter
|
||||
"""The text splitter to use to create child documents."""
|
||||
id_key: str = "doc_id"
|
||||
"""The key to use to track the parent id. This will be stored in the
|
||||
metadata of child documents."""
|
||||
parent_splitter: Optional[TextSplitter] = None
|
||||
"""The text splitter to use to create parent documents.
|
||||
If none, then the parent documents will be the raw documents passed in."""
|
||||
|
||||
def get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
sub_docs = self.vectorstore.similarity_search(query)
|
||||
# We do this to maintain the order of the ids that are returned
|
||||
ids = []
|
||||
for d in sub_docs:
|
||||
if d.metadata[self.id_key] not in ids:
|
||||
ids.append(d.metadata[self.id_key])
|
||||
docs = self.docstore.mget(ids)
|
||||
return [d for d in docs if d is not None]
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
documents: List[Document],
|
||||
ids: Optional[List[str]],
|
||||
add_to_docstore: bool = True,
|
||||
) -> None:
|
||||
"""Adds documents to the docstore and vectorstores.
|
||||
|
||||
Args:
|
||||
documents: List of documents to add
|
||||
ids: Optional list of ids for documents. If provided, should be the same
length as the list of documents. Can be provided if parent documents
are already in the document store and you don't want to re-add
to the docstore. If not provided, random UUIDs will be used as
ids.
|
||||
add_to_docstore: Boolean of whether to add documents to docstore.
|
||||
This can be false if and only if `ids` are provided. You may want
|
||||
to set this to False if the documents are already in the docstore
|
||||
and you don't want to re-add them.
|
||||
"""
|
||||
if self.parent_splitter is not None:
|
||||
documents = self.parent_splitter.split_documents(documents)
|
||||
if ids is None:
|
||||
doc_ids = [str(uuid.uuid4()) for _ in documents]
|
||||
if not add_to_docstore:
|
||||
raise ValueError(
|
||||
"If ids are not passed in, `add_to_docstore` MUST be True"
|
||||
)
|
||||
else:
|
||||
if len(documents) != len(ids):
|
||||
raise ValueError(
|
||||
"Got uneven list of documents and ids. "
|
||||
"If `ids` is provided, should be same length as `documents`."
|
||||
)
|
||||
doc_ids = ids
|
||||
|
||||
docs = []
|
||||
full_docs = []
|
||||
for i, doc in enumerate(documents):
|
||||
_id = doc_ids[i]
|
||||
sub_docs = self.child_splitter.split_documents([doc])
|
||||
for _doc in sub_docs:
|
||||
_doc.metadata[self.id_key] = _id
|
||||
docs.extend(sub_docs)
|
||||
full_docs.append((_id, doc))
|
||||
self.vectorstore.add_documents(docs)
|
||||
if add_to_docstore:
|
||||
self.docstore.mset(full_docs)
|
||||
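Building on the docstring example above, usage is two steps: add_documents splits each input into parent chunks and embedded child chunks, and get_relevant_documents searches the child chunks but returns their parents. A hedged continuation of that example; the input file name is hypothetical:

from langchain.schema import Document

docs = [Document(page_content=open("state_of_the_union.txt").read())]

# Split into parent chunks, embed the smaller child chunks, and link them via doc_id metadata
retriever.add_documents(docs, ids=None)

# Similarity search runs over the small chunks, but the larger parent documents come back
relevant = retriever.get_relevant_documents("what did the president say about the economy")
print(len(relevant), len(relevant[0].page_content))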
0
libs/langchain/langchain/runnables/__init__.py
Normal file
46
libs/langchain/langchain/runnables/openai_functions.py
Normal file
@@ -0,0 +1,46 @@
from operator import itemgetter
from typing import Any, Callable, List, Mapping, Optional, Union

from typing_extensions import TypedDict

from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding


class OpenAIFunction(TypedDict):
    """A function description for ChatOpenAI"""

    name: str
    """The name of the function."""
    description: str
    """The description of the function."""
    parameters: dict
    """The parameters to the function."""


class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
    """A runnable that routes to the selected function."""

    functions: Optional[List[OpenAIFunction]]

    def __init__(
        self,
        runnables: Mapping[
            str,
            Union[
                Runnable[dict, Any],
                Callable[[dict], Any],
            ],
        ],
        functions: Optional[List[OpenAIFunction]] = None,
    ):
        if functions is not None:
            assert len(functions) == len(runnables)
            assert all(func["name"] in runnables for func in functions)
        router = (
            JsonOutputFunctionsParser(args_only=False)
            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
            | RouterRunnable(runnables)
        )
        super().__init__(bound=router, kwargs={}, functions=functions)
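OpenAIFunctionsRouter parses a function call out of the model's generation and dispatches the parsed arguments to the runnable registered under that function's name. A heavily hedged sketch of the intended wiring; the function schemas and the composition with a function-calling chat model are assumptions, not shown in this diff:

from langchain.chat_models import ChatOpenAI
from langchain.runnables.openai_functions import OpenAIFunctionsRouter

router = OpenAIFunctionsRouter(
    {
        "get_weather": lambda args: f"It is sunny in {args['city']}",
        "get_time": lambda args: "It is 12:00",
    },
    functions=[
        {
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
        {
            "name": "get_time",
            "description": "Get the current time",
            "parameters": {"type": "object", "properties": {}},
        },
    ],
)

# Bind the same schemas to a function-calling model, then route its output
model = ChatOpenAI(model="gpt-3.5-turbo-0613").bind(functions=router.functions)
chain = model | router
print(chain.invoke("What's the weather in Paris?"))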
@@ -1,7 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import BaseMessage
|
||||
@@ -27,12 +38,26 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
|
||||
|
||||
class BaseGenerationOutputParser(
|
||||
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
):
|
||||
def invoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
@@ -51,6 +76,26 @@ class BaseGenerationOutputParser(
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
@@ -80,7 +125,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
""" # noqa: E501
|
||||
|
||||
def invoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
@@ -99,6 +144,26 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: str | BaseMessage, config: RunnableConfig | None = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@@ -125,6 +190,32 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await self.aparse(result[0].text)
|
||||
|
||||
async def aparse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of a language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Parse the output of an LLM call with the input prompt for context.
|
||||
@@ -161,8 +252,47 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
class StrOutputParser(BaseOutputParser[str]):
|
||||
"""OutputParser that parses LLMResult into the top likely string.."""
|
||||
class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
else:
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
else:
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Iterator[T]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, run_type="parser"
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, run_type="parser"
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
"""OutputParser that parses LLMResult into the top likely string."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
|
||||
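BaseTransformOutputParser gives output parsers a streaming path: each incoming chunk is parsed as it arrives rather than after the full response, and StrOutputParser is now built on top of it. A small sketch that drives the new transform method directly with message chunks:

from langchain.schema.messages import AIMessageChunk
from langchain.schema.output_parser import StrOutputParser

parser = StrOutputParser()

# transform() consumes a stream of message chunks and yields parsed strings as they arrive,
# so the parser can sit at the end of a streaming chain without buffering the whole reply.
chunks = iter([AIMessageChunk(content="Hel"), AIMessageChunk(content="lo!")])
for piece in parser.transform(chunks):
    print(piece, end="")  # prints "Hel", then "lo!"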
@@ -107,7 +107,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
        return self.get_relevant_documents(input, **(config or {}))
        config = config or {}
        return self.get_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
        )

    async def ainvoke(
        self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +122,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
        # If the retriever doesn't implement async, use default implementation
        return await super().ainvoke(input, config)

        return await self.aget_relevant_documents(input, **(config or {}))
        config = config or {}
        return await self.aget_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
        )

    @abstractmethod
    def _get_relevant_documents(
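As with BaseLLM above, retrievers now unpack only the callback-related keys from the config, so arbitrary config entries no longer leak into get_relevant_documents as keyword arguments. A hedged sketch of the call shape; the retriever instance is assumed, for example one returned by a vectorstore's as_retriever():

from langchain.callbacks import StdOutCallbackHandler

docs = retriever.invoke(
    "what did the president say about the economy",
    config={"callbacks": [StdOutCallbackHandler()], "tags": ["retrieval-demo"]},
)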
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import tee
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
@@ -28,6 +30,7 @@ from pydantic import Field
|
||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
|
||||
|
||||
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
@@ -91,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
...
|
||||
@@ -98,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
"""
|
||||
Default implementation of ainvoke, which calls invoke in a thread pool.
|
||||
Subclasses should override this method if they can run asynchronously.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.invoke, input, config
|
||||
)
|
||||
@@ -109,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
"""
|
||||
Default implementation of batch, which calls invoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
@@ -125,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
"""
|
||||
Default implementation of abatch, which calls ainvoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
coros = map(self.ainvoke, inputs, configs)
|
||||
|
||||
@@ -133,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def stream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
"""
|
||||
Default implementation of stream, which calls invoke.
|
||||
Subclasses should override this method if they support streaming output.
|
||||
"""
|
||||
yield self.invoke(input, config)
|
||||
|
||||
async def astream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
"""
|
||||
Default implementation of astream, which calls ainvoke.
|
||||
Subclasses should override this method if they support streaming output.
|
||||
"""
|
||||
yield await self.ainvoke(input, config)
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
"""
|
||||
Default implementation of transform, which buffers input and then calls stream.
|
||||
Subclasses should override this method if they can start producing output while
|
||||
input is still being generated.
|
||||
"""
|
||||
final: Union[Input, None] = None
|
||||
|
||||
for chunk in input:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
# Make a best effort to gather, for any type that supports `+`
|
||||
# This method should throw an error if gathering fails.
|
||||
final += chunk # type: ignore[operator]
|
||||
if final:
|
||||
yield from self.stream(final, config)
|
||||
|
||||
async def atransform(
|
||||
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
"""
|
||||
Default implementation of atransform, which buffers input and calls astream.
|
||||
Subclasses should override this method if they can start producing output while
|
||||
input is still being generated.
|
||||
"""
|
||||
final: Union[Input, None] = None
|
||||
|
||||
async for chunk in input:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
# Make a best effort to gather, for any type that supports `+`
|
||||
# This method should throw an error if gathering fails.
|
||||
final += chunk # type: ignore[operator]
|
||||
|
||||
if final:
|
||||
async for output in self.astream(final, config):
|
||||
yield output
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind arguments to a Runnable, returning a new Runnable.
|
||||
"""
|
||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||
|
||||
def with_fallbacks(
|
||||
self,
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
*,
|
||||
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
|
||||
) -> RunnableWithFallbacks[Input, Output]:
|
||||
return RunnableWithFallbacks(
|
||||
runnable=self,
|
||||
fallbacks=fallbacks,
|
||||
exceptions_to_handle=exceptions_to_handle,
|
||||
)
|
||||
|
||||
""" --- Helper methods for Subclasses --- """
|
||||
|
||||
def _get_config_list(
|
||||
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
"""
|
||||
Helper method to get a list of configs from a single config or a list of
|
||||
configs, useful for subclasses overriding batch() or abatch().
|
||||
"""
|
||||
if isinstance(config, list) and len(config) != length:
|
||||
raise ValueError(
|
||||
f"config must be a list of the same length as inputs, "
|
||||
@@ -168,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
config = config or {}
|
||||
@@ -192,20 +279,187 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
return output
|
||||
|
||||
def with_fallbacks(
|
||||
async def _acall_with_config(
|
||||
self,
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
*,
|
||||
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
|
||||
) -> RunnableWithFallbacks[Input, Output]:
|
||||
return RunnableWithFallbacks(
|
||||
runnable=self,
|
||||
fallbacks=fallbacks,
|
||||
exceptions_to_handle=exceptions_to_handle,
|
||||
func: Callable[[Input], Awaitable[Output]],
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
run_type=run_type,
|
||||
)
|
||||
try:
|
||||
output = await func(input)
|
||||
except Exception as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(
|
||||
output if isinstance(output, dict) else {"output": output}
|
||||
)
|
||||
return output
|
||||
|
||||
def _transform_stream_with_config(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
transformer: Callable[[Iterator[Input]], Iterator[Output]],
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Iterator[Output]:
|
||||
"""Helper method to transform an Iterator of Input values into an Iterator of
|
||||
Output values, with callbacks.
|
||||
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = tee(input, 2)
|
||||
# Start the input iterator to ensure the input runnable starts before this one
|
||||
final_input: Optional[Input] = next(input_for_tracing, None)
|
||||
final_input_supported = True
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
)
|
||||
try:
|
||||
for chunk in transformer(input_for_transform):
|
||||
yield chunk
|
||||
if final_output_supported:
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
try:
|
||||
final_output += chunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
final_output = None
|
||||
final_output_supported = False
|
||||
for ichunk in input_for_tracing:
|
||||
if final_input_supported:
|
||||
if final_input is None:
|
||||
final_input = ichunk
|
||||
else:
|
||||
try:
|
||||
final_input += ichunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
final_input = None
|
||||
final_input_supported = False
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(
|
||||
e,
|
||||
inputs=final_input
|
||||
if isinstance(final_input, dict)
|
||||
else {"input": final_input},
|
||||
)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(
|
||||
final_output
|
||||
if isinstance(final_output, dict)
|
||||
else {"output": final_output},
|
||||
inputs=final_input
|
||||
if isinstance(final_input, dict)
|
||||
else {"input": final_input},
|
||||
)
|
||||
|
||||
async def _atransform_stream_with_config(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> AsyncIterator[Output]:
|
||||
"""Helper method to transform an Async Iterator of Input values into an Async
|
||||
Iterator of Output values, with callbacks.
|
||||
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = atee(input, 2)
|
||||
# Start the input iterator to ensure the input runnable starts before this one
|
||||
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
|
||||
final_input_supported = True
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
)
|
||||
try:
|
||||
async for chunk in transformer(input_for_transform):
|
||||
yield chunk
|
||||
if final_output_supported:
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
try:
|
||||
final_output += chunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
final_output = None
|
||||
final_output_supported = False
|
||||
async for ichunk in input_for_tracing:
|
||||
if final_input_supported:
|
||||
if final_input is None:
|
||||
final_input = ichunk
|
||||
else:
|
||||
try:
|
||||
final_input += ichunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
final_input = None
|
||||
final_input_supported = False
|
||||
except Exception as e:
|
||||
await run_manager.on_chain_error(
|
||||
e,
|
||||
inputs=final_input
|
||||
if isinstance(final_input, dict)
|
||||
else {"input": final_input},
|
||||
)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(
|
||||
final_output
|
||||
if isinstance(final_output, dict)
|
||||
else {"output": final_output},
|
||||
inputs=final_input
|
||||
if isinstance(final_input, dict)
|
||||
else {"input": final_input},
|
||||
)
|
||||
|
||||
|
||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A Runnable that can fallback to other Runnables if it fails.
|
||||
"""
|
||||
|
||||
runnable: Runnable[Input, Output]
|
||||
fallbacks: Sequence[Runnable[Input, Output]]
|
||||
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
|
||||
@@ -435,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
|
||||
|
||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A sequence of runnables, where the output of each is the input of the next.
|
||||
"""
|
||||
|
||||
first: Runnable[Input, Any]
|
||||
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
|
||||
last: Runnable[Any, Output]
|
||||
@@ -706,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
streaming_start_index = 0
|
||||
|
||||
for i in range(len(steps) - 1, 0, -1):
|
||||
if type(steps[i]).transform != Runnable.transform:
|
||||
streaming_start_index = i - 1
|
||||
else:
|
||||
break
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in [self.first] + self.middle:
|
||||
for step in steps[0:streaming_start_index]:
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
@@ -718,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last step
|
||||
# stream the last steps
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
for output in self.last.stream(
|
||||
input,
|
||||
# mark the last step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
):
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].stream(
|
||||
input, _patch_config(config, run_manager.get_child())
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.transform(
|
||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
||||
)
|
||||
for output in final_pipeline:
|
||||
yield output
|
||||
# Accumulate output if possible, otherwise disable accumulation
|
||||
if final_supported:
|
||||
@@ -769,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
streaming_start_index = len(steps) - 1
|
||||
|
||||
for i in range(len(steps) - 1, 0, -1):
|
||||
if type(steps[i]).transform != Runnable.transform:
|
||||
streaming_start_index = i - 1
|
||||
else:
|
||||
break
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in [self.first] + self.middle:
|
||||
for step in steps[0:streaming_start_index]:
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
@@ -781,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last step
|
||||
# stream the last steps
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
async for output in self.last.astream(
|
||||
input,
|
||||
# mark the last step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
):
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].astream(
|
||||
input, _patch_config(config, run_manager.get_child())
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.atransform(
|
||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
||||
)
|
||||
async for output in final_pipeline:
|
||||
yield output
|
||||
# Accumulate output if possible, otherwise disable accumulation
|
||||
if final_supported:
|
||||
@@ -813,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
|
||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that runs a mapping of runnables in parallel,
|
||||
and returns a mapping of their outputs.
|
||||
"""
|
||||
|
||||
steps: Mapping[str, Runnable[Input, Any]]
|
||||
|
||||
def __init__(
|
||||
@@ -925,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
|
||||
|
||||
class RunnableLambda(Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that runs a callable.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[[Input], Output]) -> None:
|
||||
if callable(func):
|
||||
self.func = func
|
||||
@@ -945,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
"""
|
||||
A runnable that passes through the input.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
@@ -954,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||
"""
|
||||
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
kwargs: Mapping[str, Any]
|
||||
@@ -1009,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
async for item in self.bound.astream(input, config, **self.kwargs):
|
||||
yield item
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(input, config, **self.kwargs)
|
||||
|
||||
async def atransform(
|
||||
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(input, config, **self.kwargs):
|
||||
yield item
|
||||
|
||||
|
||||
class RouterInput(TypedDict):
|
||||
key: str
|
||||
@@ -1018,10 +1332,22 @@ class RouterInput(TypedDict):
|
||||
class RouterRunnable(
|
||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
||||
):
|
||||
"""
|
||||
A runnable that routes to a set of runnables based on Input['key'].
|
||||
Returns the output of the selected runnable.
|
||||
"""
|
||||
|
||||
runnables: Mapping[str, Runnable[Input, Output]]
|
||||
|
||||
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
|
||||
super().__init__(runnables=runnables)
|
||||
def __init__(
|
||||
self,
|
||||
runnables: Mapping[
|
||||
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
||||
],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -502,6 +502,18 @@ def _construct_run_evaluator(
|
||||
return run_evaluator
|
||||
|
||||
|
||||
def _get_keys(
|
||||
config: RunEvalConfig,
|
||||
run_inputs: Optional[List[str]],
|
||||
run_outputs: Optional[List[str]],
|
||||
example_outputs: Optional[List[str]],
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
input_key = _determine_input_key(config, run_inputs)
|
||||
prediction_key = _determine_prediction_key(config, run_outputs)
|
||||
reference_key = _determine_reference_key(config, example_outputs)
|
||||
return input_key, prediction_key, reference_key
|
||||
|
||||
|
||||
def _load_run_evaluators(
|
||||
config: RunEvalConfig,
|
||||
run_type: str,
|
||||
@@ -521,9 +533,13 @@ def _load_run_evaluators(
|
||||
"""
|
||||
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
|
||||
run_evaluators = []
|
||||
input_key = _determine_input_key(config, run_inputs)
|
||||
prediction_key = _determine_prediction_key(config, run_outputs)
|
||||
reference_key = _determine_reference_key(config, example_outputs)
|
||||
input_key, prediction_key, reference_key = None, None, None
|
||||
if config.evaluators or any(
|
||||
[isinstance(e, EvaluatorType) for e in config.evaluators]
|
||||
):
|
||||
input_key, prediction_key, reference_key = _get_keys(
|
||||
config, run_inputs, run_outputs, example_outputs
|
||||
)
|
||||
for eval_config in config.evaluators:
|
||||
run_evaluator = _construct_run_evaluator(
|
||||
eval_config,
|
||||
@@ -1074,15 +1090,15 @@ def _run_on_examples(
|
||||
A dictionary mapping example ids to the model outputs.
|
||||
"""
|
||||
results: Dict[str, Any] = {}
|
||||
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
|
||||
project_name = _get_project_name(project_name, llm_or_chain_factory)
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
|
||||
project_name = _get_project_name(project_name, wrapped_model)
|
||||
tracer = LangChainTracer(
|
||||
project_name=project_name, client=client, use_threading=False
|
||||
)
|
||||
run_evaluators, examples = _setup_evaluation(
|
||||
llm_or_chain_factory, examples, evaluation, data_type
|
||||
wrapped_model, examples, evaluation, data_type
|
||||
)
|
||||
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
|
||||
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
|
||||
evalution_handler = EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [],
|
||||
client=client,
|
||||
@@ -1091,7 +1107,7 @@ def _run_on_examples(
|
||||
for i, example in enumerate(examples):
|
||||
result = _run_llm_or_chain(
|
||||
example,
|
||||
llm_or_chain_factory,
|
||||
wrapped_model,
|
||||
num_repetitions,
|
||||
tags=tags,
|
||||
callbacks=callbacks,
|
||||
@@ -1114,8 +1130,8 @@ def _prepare_eval_run(
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
project_name: Optional[str],
|
||||
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
|
||||
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||
project_name = _get_project_name(project_name, llm_or_chain_factory)
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||
project_name = _get_project_name(project_name, wrapped_model)
|
||||
try:
|
||||
project = client.create_project(project_name)
|
||||
except ValueError as e:
|
||||
@@ -1130,7 +1146,7 @@ def _prepare_eval_run(
|
||||
)
|
||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||
examples = client.list_examples(dataset_id=str(dataset.id))
|
||||
return llm_or_chain_factory, project_name, dataset, examples
|
||||
return wrapped_model, project_name, dataset, examples
|
||||
|
||||
|
||||
async def arun_on_dataset(
|
||||
@@ -1256,13 +1272,13 @@ async def arun_on_dataset(
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
""" # noqa: E501
|
||||
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
|
||||
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||
client, dataset_name, llm_or_chain_factory, project_name
|
||||
)
|
||||
results = await _arun_on_examples(
|
||||
client,
|
||||
examples,
|
||||
llm_or_chain_factory,
|
||||
wrapped_model,
|
||||
concurrency_level=concurrency_level,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
@@ -1423,14 +1439,14 @@ def run_on_dataset(
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
""" # noqa: E501
|
||||
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
|
||||
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||
client, dataset_name, llm_or_chain_factory, project_name
|
||||
)
|
||||
if concurrency_level in (0, 1):
|
||||
results = _run_on_examples(
|
||||
client,
|
||||
examples,
|
||||
llm_or_chain_factory,
|
||||
wrapped_model,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
verbose=verbose,
|
||||
@@ -1444,7 +1460,7 @@ def run_on_dataset(
|
||||
coro = _arun_on_examples(
|
||||
client,
|
||||
examples,
|
||||
llm_or_chain_factory,
|
||||
wrapped_model,
|
||||
concurrency_level=concurrency_level,
|
||||
num_repetitions=num_repetitions,
|
||||
project_name=project_name,
|
||||
|
||||
@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
        **kwargs: Any,
    ) -> Any:
        config = config or {}
        return self.run(input, **config, **kwargs)
        return self.run(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            **kwargs,
        )

    async def ainvoke(
        self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
        return super().ainvoke(input, config, **kwargs)

        config = config or {}
        return await self.arun(input, **config, **kwargs)
        return await self.arun(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            **kwargs,
        )

    # --- Tool ---
191
libs/langchain/langchain/utils/aiter.py
Normal file
@@ -0,0 +1,191 @@
"""
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""

from collections import deque
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generic,
Iterator,
List,
Tuple,
TypeVar,
Union,
cast,
overload,
)

T = TypeVar("T")

_no_default = object()


# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
"""Pure-Python implementation of anext() for testing purposes.

Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""

try:
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
)
except AttributeError:
raise TypeError(f"{iterator!r} is not an async iterator")

if default is _no_default:
return __anext__(iterator)

async def anext_impl() -> Union[T, Any]:
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default

return anext_impl()


async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done – remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "aclose"):
await iterator.aclose()


class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``

This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.

.. code-block:: python3

async def derivative(sensor_data):
previous, current = a.tee(sensor_data, n=2)
await a.anext(previous) # advance one iterator
return a.map(operator.sub, previous, current)

Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.

If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterator differ
by most data, using a :py:class:`list` is more efficient (but not lazy).

If the underlying iterable is concurrency safe (``anext`` may be awaited
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``anext``, provide a ``lock``
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
and access is automatically synchronised.
"""

def __init__(
self,
iterable: AsyncIterator[T],
n: int = 2,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
)
for buffer in self._buffers
)

def __len__(self) -> int:
return len(self._children)

@overload
def __getitem__(self, item: int) -> AsyncIterator[T]:
...

@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
...

def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
return self._children[item]

def __iter__(self) -> Iterator[AsyncIterator[T]]:
yield from self._children

async def __aenter__(self) -> "Tee[T]":
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
await self.aclose()
return False

async def aclose(self) -> None:
for child in self._children:
await child.aclose()


atee = Tee
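The new langchain/utils/aiter.py file above ships an asyncio counterpart of itertools.tee. A short usage sketch, assuming only what the file itself defines (atee is the Tee class aliased at the bottom of the module):

```python
import asyncio

from langchain.utils.aiter import atee


async def numbers():
    for i in range(3):
        yield i


async def main() -> None:
    # Split one async iterator into two; each child sees every item exactly
    # once, with items buffered until the slower child has consumed them.
    left, right = atee(numbers(), 2)
    assert [x async for x in left] == [0, 1, 2]
    assert [x async for x in right] == [0, 1, 2]


asyncio.run(main())
```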
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.258"
version = "0.0.260"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -0,0 +1,63 @@
"""Tests RocksetChatMessageHistory by creating a collection
for message history, adding to it, and clearing it.

To run these tests, make sure you have the ROCKSET_API_KEY
and ROCKSET_REGION environment variables set.
"""

import json
import os

from rockset import DevRegions, Regions, RocksetClient

from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
from langchain.schema.messages import _message_to_dict

collection_name = "langchain_demo"
session_id = "MySession"


class TestRockset:
memory: RocksetChatMessageHistory

@classmethod
def setup_class(cls) -> None:
assert os.environ.get("ROCKSET_API_KEY") is not None
assert os.environ.get("ROCKSET_REGION") is not None

api_key = os.environ.get("ROCKSET_API_KEY")
region = os.environ.get("ROCKSET_REGION")
if region == "use1a1":
host = Regions.use1a1
elif region == "usw2a1" or not region:
host = Regions.usw2a1
elif region == "euc1a1":
host = Regions.euc1a1
elif region == "dev":
host = DevRegions.usw2a1
else:
host = region

client = RocksetClient(host, api_key)
cls.memory = RocksetChatMessageHistory(
session_id, client, collection_name, sync=True
)

def test_memory_with_message_store(self) -> None:
memory = ConversationBufferMemory(
memory_key="messages", chat_memory=self.memory, return_messages=True
)

memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")

messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

memory.chat_memory.clear()

assert memory.chat_memory.messages == []
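The same wiring outside the test class, as a sketch: the constructor arguments mirror the test above (session id, Rockset client, collection name, sync flag); the API key and region here are placeholders, not values from the test:

```python
from rockset import Regions, RocksetClient

from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory

history = RocksetChatMessageHistory(
    "MySession",
    RocksetClient(Regions.usw2a1, "YOUR_ROCKSET_API_KEY"),  # placeholder key
    "langchain_demo",
    sync=True,
)
memory = ConversationBufferMemory(
    memory_key="messages", chat_memory=history, return_messages=True
)
memory.chat_memory.add_user_message("This is me, the human")
```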
@@ -60,3 +60,67 @@ def test_on_llm_end_finetuned_model(handler: OpenAICallbackHandler) -> None:
)
handler.on_llm_end(response)
assert handler.total_cost > 0


@pytest.mark.parametrize(
"model_name,expected_cost",
[
("gpt-35-turbo", 0.0035),
("gpt-35-turbo-0301", 0.0035),
(
"gpt-35-turbo-0613",
0.0035,
),
(
"gpt-35-turbo-16k-0613",
0.007,
),
(
"gpt-35-turbo-16k",
0.007,
),
("gpt-4", 0.09),
("gpt-4-0314", 0.09),
("gpt-4-0613", 0.09),
("gpt-4-32k", 0.18),
("gpt-4-32k-0314", 0.18),
("gpt-4-32k-0613", 0.18),
],
)
def test_on_llm_end_azure_openai(
handler: OpenAICallbackHandler, model_name: str, expected_cost: float
) -> None:
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 1000,
"completion_tokens": 1000,
"total_tokens": 2000,
},
"model_name": model_name,
},
)
handler.on_llm_end(response)
assert handler.total_cost == expected_cost


@pytest.mark.parametrize(
"model_name", ["gpt-35-turbo-16k-0301", "gpt-4-0301", "gpt-4-32k-0301"]
)
def test_on_llm_end_no_cost_invalid_model(
handler: OpenAICallbackHandler, model_name: str
) -> None:
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 1000,
"completion_tokens": 1000,
"total_tokens": 2000,
},
"model_name": model_name,
},
)
handler.on_llm_end(response)
assert handler.total_cost == 0
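The expected costs in the parametrisation follow from 1000 prompt plus 1000 completion tokens at per-1K-token prices; a rough sanity check, with the prices below stated as assumptions about the handler's pricing table rather than taken from the test itself:

```python
# (prompt, completion) USD per 1K tokens, assumed to match the handler's table
PRICES = {
    "gpt-35-turbo": (0.0015, 0.002),     # 1.0 * 0.0015 + 1.0 * 0.002 = 0.0035
    "gpt-35-turbo-16k": (0.003, 0.004),  # = 0.007
    "gpt-4": (0.03, 0.06),               # = 0.09
    "gpt-4-32k": (0.06, 0.12),           # = 0.18
}

for model, (prompt_price, completion_price) in PRICES.items():
    cost = (1000 / 1000) * prompt_price + (1000 / 1000) * completion_price
    print(model, round(cost, 4))
```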
@@ -15,7 +15,7 @@ def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
return outputs


def test_tranform_chain() -> None:
def test_transform_chain() -> None:
"""Test basic transform chain."""
transform_chain = TransformChain(
input_variables=["first_name", "last_name"],
@@ -0,0 +1,52 @@
import json
import os
from typing import Any, Mapping, cast

import pytest

from langchain.chat_models.azure_openai import AzureChatOpenAI

os.environ["OPENAI_API_KEY"] = "test"
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2023-05-01"


@pytest.mark.requires("openai")
@pytest.mark.parametrize(
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
)
def test_model_name_set_on_chat_result_when_present_in_response(
model_name: str,
) -> None:
sample_response_text = f"""
{{
"id": "chatcmpl-7ryweq7yc8463fas879t9hdkkdf",
"object": "chat.completion",
"created": 1690381189,
"model": "{model_name}",
"choices": [
{{
"index": 0,
"finish_reason": "stop",
"message": {{
"role": "assistant",
"content": "I'm an AI assistant that can help you."
}}
}}
],
"usage": {{
"completion_tokens": 28,
"prompt_tokens": 15,
"total_tokens": 43
}}
}}
"""
# convert sample_response_text to instance of Mapping[str, Any]
sample_response = json.loads(sample_response_text)
mock_response = cast(Mapping[str, Any], sample_response)
mock_chat = AzureChatOpenAI()
chat_result = mock_chat._create_chat_result(mock_response)
assert (
chat_result.llm_output is not None
and chat_result.llm_output["model_name"] == model_name
)
@@ -0,0 +1,9 @@
"""Test the airbyte document loader.

Light test to ensure that the airbyte document loader can be imported.
"""


def test_airbyte_import() -> None:
"""Test that the airbyte document loader can be imported."""
from langchain.document_loaders import airbyte # noqa
@@ -6,8 +6,12 @@ import pytest
from pydantic import Field

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.schema import AgentAction, BaseMessage
from langchain.evaluation.agents.trajectory_eval_chain import (
TrajectoryEval,
TrajectoryEvalChain,
TrajectoryOutputParser,
)
from langchain.schema import AgentAction, BaseMessage, OutputParserException
from langchain.tools.base import tool
from tests.unit_tests.llms.fake_chat_model import FakeChatModel

@@ -53,6 +57,61 @@ class _FakeTrajectoryChatModel(FakeChatModel):
return self.queries[prompt]


def test_trajectory_output_parser_parse() -> None:
trajectory_output_parser = TrajectoryOutputParser()
text = """Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.

Score: 2"""
got = trajectory_output_parser.parse(text)
want = TrajectoryEval(
score=0.25,
reasoning="""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.""",
)

assert got["score"] == want["score"]
assert got["reasoning"] == want["reasoning"]

with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2."""
)

with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.

Score: 9"""
)

with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.

Score: 10"""
)

with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.

Score: 0.1"""
)

with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.

Score: One"""
)


def test_trajectory_eval_chain(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:
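The expectations above imply that the parser accepts only an integer "Score:" between 1 and 5 and rescales it to the 0-1 range (2 becomes 0.25); a hypothetical reconstruction of that mapping, not the parser's actual code:

```python
def normalize(raw_score: int) -> float:
    # assumed rule: integer 1-5 mapped linearly onto 0.0-1.0
    if raw_score not in range(1, 6):
        raise ValueError("score must be an integer from 1 to 5")
    return (raw_score - 1) / 4


assert normalize(2) == 0.25  # matches TrajectoryEval(score=0.25, ...) above
assert normalize(5) == 1.0
```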
@@ -0,0 +1,31 @@
# serializer version: 1
# name: test_openai_functions_router
list([
dict({
'description': 'Sends the draft for revision.',
'name': 'revise',
'parameters': dict({
'properties': dict({
'notes': dict({
'description': "The editor's notes to guide the revision.",
'type': 'string',
}),
}),
'type': 'object',
}),
}),
dict({
'description': 'Accepts the draft.',
'name': 'accept',
'parameters': dict({
'properties': dict({
'draft': dict({
'description': 'The draft to accept.',
'type': 'string',
}),
}),
'type': 'object',
}),
}),
])
# ---
@@ -0,0 +1,95 @@
from typing import Any, List, Optional

from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration


class FakeChatOpenAI(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-openai-chat-model"

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "accept",
"arguments": '{\n "draft": "turtles"\n}',
}
},
)
)
]
)


def test_openai_functions_router(
snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
revise = mocker.Mock(
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
)
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')

router = OpenAIFunctionsRouter(
{
"revise": revise,
"accept": accept,
},
functions=[
{
"name": "revise",
"description": "Sends the draft for revision.",
"parameters": {
"type": "object",
"properties": {
"notes": {
"type": "string",
"description": "The editor's notes to guide the revision.",
},
},
},
},
{
"name": "accept",
"description": "Accepts the draft.",
"parameters": {
"type": "object",
"properties": {
"draft": {
"type": "string",
"description": "The draft to accept.",
},
},
},
},
],
)

model = FakeChatOpenAI()

chain = model.bind(functions=router.functions) | router

assert router.functions == snapshot

assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"

revise.assert_not_called()
accept.assert_called_once_with({"draft": "turtles"})
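In isolation, the pattern this test exercises: OpenAIFunctionsRouter (a new, experimental runnable at this point) maps function names to callables, and the chat model is bound to the same function schemas so that its function_call output is parsed and dispatched to the matching callable. A condensed sketch reusing the FakeChatOpenAI and router defined in the test above; with a real model you would substitute ChatOpenAI:

```python
chain = FakeChatOpenAI().bind(functions=router.functions) | router
# The fake model always calls "accept" with {"draft": "turtles"}, so the
# router invokes the "accept" callable with the parsed arguments.
print(chain.invoke("Something about turtles?"))  # -> "Accepted draft: turtles!"
```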
File diff suppressed because one or more lines are too long
@@ -11,7 +11,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
@@ -22,6 +22,7 @@ from langchain.prompts.chat import (
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
@@ -61,6 +62,8 @@ class FakeTracer(BaseTracer):
if run.parent_run_id
else None,
"child_runs": [self._copy_run(child) for child in run.child_runs],
"execution_order": None,
"child_execution_order": None,
}
)

@@ -302,7 +305,7 @@ async def test_prompt_with_chat_model(
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [AIMessage(content="foo")]
] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@@ -678,7 +681,12 @@ async def test_router_runnable(
"key": "math",
"input": {"question": "2 + 2"},
}
assert tracer.runs == snapshot
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableSequence" # TODO: should be RunnableRouter
assert len(router_run.child_runs) == 2


@freeze_time("2023-01-01")
@@ -758,6 +766,45 @@ def test_bind_bind() -> None:
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))


def test_deep_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])

chain = prompt | llm | StrOutputParser()

stream = chain.stream({"question": "What up"})

chunks = []
for chunk in stream:
chunks.append(chunk)

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"


@pytest.mark.asyncio
async def test_deep_astream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])

chain = prompt | llm | StrOutputParser()

stream = chain.astream({"question": "What up"})

chunks = []
async for chunk in stream:
chunks.append(chunk)

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"


@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)