Compare commits

..

17 Commits

Author SHA1 Message Date
Bagatur
96d064e305 bump 260 (#9002) 2023-08-09 13:40:49 -07:00
Michael Shen
c2f46b2cdb Fixed wrong paper reference (#8970)
The ReAct reference pointed to the MRKL paper. Corrected so that it
points to the actual ReAct paper (#8964).
2023-08-09 16:17:46 -04:00
Nuno Campos
808248049d Implement a router for openai functions (#8589) 2023-08-09 21:17:04 +01:00
Eugene Yurtsev
a6e6e9bb86 Fix airbyte loader (#8998)
Fix airbyte loader

https://github.com/langchain-ai/langchain/issues/8996
2023-08-09 16:13:06 -04:00
William FH
90579021f8 Update Key Check (#8948)
In the eval loop, it needn't be done unless you are creating the
corresponding evaluators.
2023-08-09 12:33:00 -07:00
Jerzy Czopek
539672a7fd Feature/fix azureopenai model mappings (#8621)
This pull request aims to ensure that the `OpenAICallbackHandler` can
properly calculate the total cost for Azure OpenAI chat models. The
following changes have resolved this issue:

- The `model_name` has been added to the `ChatResult` `llm_output`. Without
this, the default `gpt-35-turbo` cost values were applied, which caused
the total cost for Azure OpenAI's GPT-4 to be significantly inaccurate.
- A new parameter `model_version` has been added to `AzureChatOpenAI`.
Azure does not include the model version in the response. With the
addition of `model_name`, this is not a significant issue for GPT-4
models, but it's an issue for GPT-3.5-Turbo. Version 0301 (default) of
GPT-3.5-Turbo on Azure has a flat rate of 0.002 per 1k tokens for both
prompt and completion. However, version 0613 introduced a split in
pricing for prompt and completion tokens.
- The `OpenAICallbackHandler` implementation has been updated with the
proper model names, versions, and cost per 1k tokens.

Unit tests have been added to ensure the functionality works as
expected; the Azure ChatOpenAI notebook has been updated with examples.
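
For reference, a minimal sketch of the new usage (endpoint, key, and deployment values below are placeholders; see the updated Azure ChatOpenAI notebook for the full example):

```python
from langchain.callbacks import get_openai_callback
from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage

# Placeholder values - substitute your own Azure OpenAI endpoint, key, and deployment.
model = AzureChatOpenAI(
    openai_api_base="https://<your-endpoint>.openai.azure.com",
    openai_api_version="2023-05-15",
    deployment_name="gpt-35-turbo",
    openai_api_key="...",
    openai_api_type="azure",
    model_version="0613",  # appended to the model name so the right per-1k-token cost is used
)
with get_openai_callback() as cb:
    model([HumanMessage(content="Translate this sentence from English to French. I love programming.")])
    print(f"Total Cost (USD): ${cb.total_cost:.6f}")
```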

Maintainers: @hwchase17, @baskaryan

Twitter handle: @jjczopek

---------

Co-authored-by: Jerzy Czopek <jerzy.czopek@avanade.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-08-09 10:56:15 -07:00
Bagatur
269f85b7b7 scheduled gha fix (#8977) 2023-08-09 09:44:25 -07:00
shibuiwilliam
3adb1e12ca make trajectory eval chain stricter and add unit tests (#8909)
- update trajectory eval logic to be stricter
- add tests to trajectory eval chain
2023-08-09 10:57:18 -04:00
Nuno Campos
b8df15cd64 Adds transform support for runnables (#8762)
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-08-09 12:34:23 +01:00
Harrison Chase
4d72288487 async output parser (#8894)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
2023-08-09 08:25:38 +01:00
Bagatur
3c6eccd701 bump 259 (#8951) 2023-08-09 00:07:47 -07:00
Harrison Chase
7de6a1b78e parent document retriever (#8941) 2023-08-08 22:39:08 -07:00
arjunbansal
a2681f950d add instructions on integrating Log10 (#8938)
- Description: Instructions for integrating with Log10: an [open
source](https://github.com/log10-io/log10) proxiless LLM data management
and application development platform that lets you log, debug, and tag
your Langchain calls
  - Tag maintainer: @baskaryan
  - Twitter handle: @log10io @coffeephoenix

Several examples showing the integration are included
[here](https://github.com/log10-io/log10/tree/main/examples/logging) and
in the PR.
2023-08-08 19:15:31 -07:00
Aarav Borthakur
3f64b8a761 Integrate Rockset as a chat history store (#8940)
Description: Adds Rockset as a chat history store
Dependencies: no changes
Tag maintainer: @hwchase17

This PR passes linting and testing. 

I added a test for the integration and an example notebook showing its
use.
2023-08-08 18:54:07 -07:00
Bagatur
0a1be1d501 document lcel fallbacks (#8942) 2023-08-08 18:49:33 -07:00
William FH
e3056340da Add id in error in tracer (#8944) 2023-08-08 18:25:27 -07:00
Molly Cantillon
99b5a7226c Weaviate: adding auth example + fixing spelling in ReadME (#8939)
Added basic auth example to Weaviate notebook @baskaryan
2023-08-08 16:24:17 -07:00
46 changed files with 2831 additions and 154 deletions

View File

@@ -1,7 +1,7 @@
name: Scheduled tests
on:
scheduled:
schedule:
- cron: '0 13 * * *'
env:

View File

@@ -12,7 +12,7 @@ Here are the agents available in LangChain.
### [Zero-shot ReAct](/docs/modules/agents/agent_types/react.html)
This agent uses the [ReAct](https://arxiv.org/pdf/2205.00445.pdf) framework to determine which tool to use
This agent uses the [ReAct](https://arxiv.org/pdf/2210.03629) framework to determine which tool to use
based solely on the tool's description. Any number of tools can be provided.
This agent requires that a description is provided for each tool.

View File

@@ -556,6 +556,14 @@
"source": "/docs/integrations/llamacpp",
"destination": "/docs/integrations/providers/llamacpp"
},
{
"source": "/en/latest/integrations/log10.html",
"destination": "/docs/integrations/providers/log10"
},
{
"source": "/docs/integrations/log10",
"destination": "/docs/integrations/providers/log10"
},
{
"source": "/en/latest/integrations/mediawikidump.html",
"destination": "/docs/integrations/providers/mediawikidump"

View File

@@ -1648,6 +1648,186 @@
"source": [
"## Conversational Retrieval With Memory"
]
},
{
"cell_type": "markdown",
"id": "92c87dd8-bb6f-4f32-a30d-8f5459ce6265",
"metadata": {},
"source": [
"## Fallbacks\n",
"\n",
"With LCEL you can easily introduce fallbacks for any Runnable component, like an LLM."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1b1cb744-31fc-4261-ab25-65fe1fcad559",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='To get to the other side.', additional_kwargs={}, example=False)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"bad_llm = ChatOpenAI(model_name=\"gpt-fake\")\n",
"good_llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\")\n",
"llm = bad_llm.with_fallbacks([good_llm])\n",
"\n",
"llm.invoke(\"Why did the the chicken cross the road?\")"
]
},
{
"cell_type": "markdown",
"id": "b8cf3982-03f6-49b3-8ff5-7cd12444f19c",
"metadata": {},
"source": [
"Looking at the trace, we can see that the first model failed but the second succeeded, so we still got an output: https://smith.langchain.com/public/dfaf0bf6-d86d-43e9-b084-dd16a56df15c/r\n",
"\n",
"We can add an arbitrary sequence of fallbacks, which will be executed in order:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "31819be0-7f40-4e67-b5ab-61340027b948",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='To get to the other side.', additional_kwargs={}, example=False)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm = bad_llm.with_fallbacks([bad_llm, bad_llm, good_llm])\n",
"\n",
"llm.invoke(\"Why did the the chicken cross the road?\")"
]
},
{
"cell_type": "markdown",
"id": "acad6e88-8046-450e-b005-db7e50f33b80",
"metadata": {},
"source": [
"Trace: https://smith.langchain.com/public/c09efd01-3184-4369-a225-c9da8efcaf47/r\n",
"\n",
"We can continue to use our Runnable with fallbacks the same way we use any Runnable, mean we can include it in sequences:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bab114a1-bb93-4b7e-a639-e7e00f21aebc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='To show off its incredible jumping skills! Kangaroos are truly amazing creatures.', additional_kwargs={}, example=False)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You're a nice assistant who always includes a compliment in your response\"),\n",
" (\"human\", \"Why did the {animal} cross the road\"),\n",
" ]\n",
")\n",
"chain = prompt | llm\n",
"chain.invoke({\"animal\": \"kangaroo\"})"
]
},
{
"cell_type": "markdown",
"id": "58340afa-8187-4ffe-9bd2-7912fb733a15",
"metadata": {},
"source": [
"Trace: https://smith.langchain.com/public/ba03895f-f8bd-4c70-81b7-8b930353eabd/r\n",
"\n",
"Note, since every sequence of Runnables is itself a Runnable, we can create fallbacks for whole Sequences. We can also continue using the full interface, including asynchronous calls, batched calls, and streams:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "45aa3170-b2e6-430d-887b-bd879048060a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[\"\\n\\nAnswer: The rabbit crossed the road to get to the other side. That's quite clever of him!\",\n",
" '\\n\\nAnswer: The turtle crossed the road to get to the other side. You must be pretty clever to come up with that riddle!']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"chat_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You're a nice assistant who always includes a compliment in your response\"),\n",
" (\"human\", \"Why did the {animal} cross the road\"),\n",
" ]\n",
")\n",
"chat_model = ChatOpenAI(model_name=\"gpt-fake\")\n",
"\n",
"prompt_template = \"\"\"Instructions: You should always include a compliment in your response.\n",
"\n",
"Question: Why did the {animal} cross the road?\"\"\"\n",
"prompt = PromptTemplate.from_template(prompt_template)\n",
"llm = OpenAI()\n",
"\n",
"bad_chain = chat_prompt | chat_model\n",
"good_chain = prompt | llm\n",
"chain = bad_chain.with_fallbacks([good_chain])\n",
"await chain.abatch([{\"animal\": \"rabbit\"}, {\"animal\": \"turtle\"}])"
]
},
{
"cell_type": "markdown",
"id": "af6731c6-0c73-4b1d-a433-6e8f6ecce2bb",
"metadata": {},
"source": [
"Traces: \n",
"1. https://smith.langchain.com/public/ccd73236-9ae5-48a6-94b5-41210be18a46/r\n",
"2. https://smith.langchain.com/public/f43f608e-075c-45c7-bf73-b64e4d3f3082/r"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d2fe1fe-506b-4ee5-8056-8b9df801765f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -1666,7 +1846,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -74,6 +74,124 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "f27fa24d",
"metadata": {},
"source": [
"## Model Version\n",
"Azure OpenAI responses contain `model` property, which is name of the model used to generate the response. However unlike native OpenAI responses, it does not contain the version of the model, which is set on the deplyoment in Azure. This makes it tricky to know which version of the model was used to generate the response, which as result can lead to e.g. wrong total cost calculation with `OpenAICallbackHandler`.\n",
"\n",
"To solve this problem, you can pass `model_version` parameter to `AzureChatOpenAI` class, which will be added to the model name in the llm output. This way you can easily distinguish between different versions of the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0531798a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks import get_openai_callback"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3fd97dfc",
"metadata": {},
"outputs": [],
"source": [
"BASE_URL = \"https://{endpoint}.openai.azure.com\"\n",
"API_KEY = \"...\"\n",
"DEPLOYMENT_NAME = \"gpt-35-turbo\" # in Azure, this deployment has version 0613 - input and output tokens are counted separately"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "aceddb72",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Cost (USD): $0.000054\n"
]
}
],
"source": [
"model = AzureChatOpenAI(\n",
" openai_api_base=BASE_URL,\n",
" openai_api_version=\"2023-05-15\",\n",
" deployment_name=DEPLOYMENT_NAME,\n",
" openai_api_key=API_KEY,\n",
" openai_api_type=\"azure\",\n",
")\n",
"with get_openai_callback() as cb:\n",
" model(\n",
" [\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" )\n",
" ]\n",
" )\n",
" print(f\"Total Cost (USD): ${format(cb.total_cost, '.6f')}\") # without specifying the model version, flat-rate 0.002 USD per 1k input and output tokens is used\n"
]
},
{
"cell_type": "markdown",
"id": "2e61eefd",
"metadata": {},
"source": [
"We can provide the model version to `AzureChatOpenAI` constructor. It will get appended to the model name returned by Azure OpenAI and cost will be counted correctly."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8d5e54e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Cost (USD): $0.000044\n"
]
}
],
"source": [
"model0613 = AzureChatOpenAI(\n",
" openai_api_base=BASE_URL,\n",
" openai_api_version=\"2023-05-15\",\n",
" deployment_name=DEPLOYMENT_NAME,\n",
" openai_api_key=API_KEY,\n",
" openai_api_type=\"azure\",\n",
" model_version=\"0613\"\n",
")\n",
"with get_openai_callback() as cb:\n",
" model0613(\n",
" [\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" )\n",
" ]\n",
" )\n",
" print(f\"Total Cost (USD): ${format(cb.total_cost, '.6f')}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99682534",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -92,7 +210,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.10"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,67 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rockset Chat Message History\n",
"\n",
"This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history. \n",
"\n",
"To begin, with get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region for the Rockset [API reference](https://rockset.com/docs/rest-api#introduction)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain.memory.chat_message_histories import RocksetChatMessageHistory\n",
"from rockset import RocksetClient, Regions\n",
"\n",
"history = RocksetChatMessageHistory(\n",
" session_id=\"MySession\",\n",
" client=RocksetClient(\n",
" api_key=\"YOUR API KEY\", \n",
" host=Regions.usw2a1 # us-west-2 Oregon\n",
" ),\n",
" collection=\"langchain_demo\",\n",
" sync=True\n",
")\n",
"history.add_user_message(\"hi!\")\n",
"history.add_ai_message(\"whats up?\")\n",
"print(history.messages)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The output should be something like:\n",
"\n",
"```python\n",
"[\n",
" HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False), \n",
" AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)\n",
"]\n",
"\n",
"```"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,104 @@
# Log10
This page covers how to use [Log10](https://log10.io) within LangChain.
## What is Log10?
Log10 is an [open source](https://github.com/log10-io/log10) proxiless LLM data management and application development platform that lets you log, debug and tag your Langchain calls.
## Quick start
1. Create your free account at [log10.io](https://log10.io)
2. Add your `LOG10_TOKEN` and `LOG10_ORG_ID` from the Settings and Organization tabs respectively as environment variables.
3. Also add `LOG10_URL=https://log10.io` and your usual LLM API key (e.g. `OPENAI_API_KEY` or `ANTHROPIC_API_KEY`) to your environment
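For example, assuming the OpenAI integration, the environment can be configured from Python before running the snippets below (all values are placeholders):
```python
import os

# Placeholder values - copy your real token and org ID from the Log10 Settings and Organization tabs.
os.environ["LOG10_TOKEN"] = "<your-log10-token>"
os.environ["LOG10_ORG_ID"] = "<your-log10-org-id>"
os.environ["LOG10_URL"] = "https://log10.io"
os.environ["OPENAI_API_KEY"] = "<your-openai-api-key>"
```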
## How to enable Log10 data management for Langchain
Integrating with Log10 is a simple one-line `log10_callback` setup, as shown below:
```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from log10.langchain import Log10Callback
from log10.llm import Log10Config
log10_callback = Log10Callback(log10_config=Log10Config())
messages = [
HumanMessage(content="You are a ping pong machine"),
HumanMessage(content="Ping?"),
]
llm = ChatOpenAI(model_name="gpt-3.5-turbo", callbacks=[log10_callback])
```
[Log10 + Langchain + Logs docs](https://github.com/log10-io/log10/blob/main/logging.md#langchain-logger)
[More details + screenshots](https://log10.io/docs/logs) including instructions for self-hosting logs
## How to use tags with Log10
```python
from langchain import OpenAI
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from log10.langchain import Log10Callback
from log10.llm import Log10Config
log10_callback = Log10Callback(log10_config=Log10Config())
messages = [
HumanMessage(content="You are a ping pong machine"),
HumanMessage(content="Ping?"),
]
llm = ChatOpenAI(model_name="gpt-3.5-turbo", callbacks=[log10_callback], temperature=0.5, tags=["test"])
completion = llm.predict_messages(messages, tags=["foobar"])
print(completion)
llm = ChatAnthropic(model="claude-2", callbacks=[log10_callback], temperature=0.7, tags=["baz"])
llm.predict_messages(messages)
print(completion)
llm = OpenAI(model_name="text-davinci-003", callbacks=[log10_callback], temperature=0.5)
completion = llm.predict("You are a ping pong machine.\nPing?\n")
print(completion)
```
You can also intermix direct OpenAI calls and Langchain LLM calls:
```python
import os
from log10.load import log10, log10_session
import openai
from langchain import OpenAI
log10(openai)
with log10_session(tags=["foo", "bar"]):
# Log a direct OpenAI call
response = openai.Completion.create(
model="text-ada-001",
prompt="Where is the Eiffel Tower?",
temperature=0,
max_tokens=1024,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
print(response)
# Log a call via Langchain
llm = OpenAI(model_name="text-ada-001", temperature=0.5)
response = llm.predict("You are a ping pong machine.\nPing?\n")
print(response)
```
## How to debug Langchain calls
[Example of debugging](https://log10.io/docs/prompt_chain_debugging)
[More Langchain examples](https://github.com/log10-io/log10/tree/main/examples#langchain)

View File

@@ -23,4 +23,11 @@ from langchain.vectorstores import Rockset
See a [usage example](/docs/integrations/document_loaders/rockset).
```python
from langchain.document_loaders import RocksetLoader
```
## Chat Message History
See a [usage example](/docs/integrations/memory/rockset_chat_message_history).
```python
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
```

View File

@@ -81,17 +81,18 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 21,
"id": "53b7ce2d-3c09-4d1c-b66b-5769ce6746ae",
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"WEAVIATE_API_KEY\"] = getpass.getpass(\"WEAVIATE_API_KEY:\")"
"os.environ[\"WEAVIATE_API_KEY\"] = getpass.getpass(\"WEAVIATE_API_KEY:\")\n",
"WEAVIATE_API_KEY = os.environ[\"WEAVIATE_API_KEY\"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "aac9563e",
"metadata": {
"tags": []
@@ -106,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"id": "a3c3999a",
"metadata": {},
"outputs": [],
@@ -123,7 +124,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"id": "21e9e528",
"metadata": {},
"outputs": [],
@@ -133,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 15,
"id": "b4170176",
"metadata": {},
"outputs": [],
@@ -144,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"id": "ecf3b890",
"metadata": {},
"outputs": [
@@ -166,6 +167,53 @@
"print(docs[0].page_content)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7826d0ea",
"metadata": {},
"source": [
"## Authentication"
]
},
{
"cell_type": "markdown",
"id": "13989a7c",
"metadata": {},
"source": [
"Weaviate instances have authentication enabled by default. You can use either a username/password combination or API key. "
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f6604f1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<langchain.vectorstores.weaviate.Weaviate object at 0x107f46550>\n"
]
}
],
"source": [
"import weaviate\n",
"\n",
"client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))\n",
"\n",
"# client = weaviate.Client(\n",
"# url=WEAVIATE_URL,\n",
"# auth_client_secret=weaviate.AuthClientPassword(\n",
"# username = \"WCS_USERNAME\", # Replace w/ your WCS username\n",
"# password = \"WCS_PASSWORD\", # Replace w/ your WCS password\n",
"# ),\n",
"# )\n",
"\n",
"vectorstore = Weaviate.from_documents(documents, embeddings, client=client, by_text=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -187,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"id": "102105a1",
"metadata": {},
"outputs": [
@@ -213,7 +261,7 @@
"id": "8fc3487b",
"metadata": {},
"source": [
"# Persistance"
"# Persistence"
]
},
{
@@ -249,7 +297,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "8b7df7ae",
"metadata": {},
"outputs": [
@@ -287,7 +335,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "5e824f3b",
"metadata": {},
"outputs": [],
@@ -298,7 +346,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "61209cc3",
"metadata": {},
"outputs": [],
@@ -311,7 +359,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "4abc3d37",
"metadata": {},
"outputs": [],
@@ -327,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "c7062393",
"metadata": {},
"outputs": [],
@@ -339,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "7e41b773",
"metadata": {},
"outputs": [

View File

@@ -0,0 +1,440 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "34883374",
"metadata": {},
"source": [
"# Parent Document Retriever\n",
"\n",
"When splitting documents for retrieval, there are often conflicting desires:\n",
"\n",
"1. You may want to have small documents, so that their embeddings can most\n",
" accurately reflect their meaning. If too long, then the embeddings can\n",
" lose meaning.\n",
"2. You want to have long enough documents that the context of each chunk is\n",
" retained.\n",
"\n",
"The ParentDocumentRetriever strikes that balance by splitting and storing\n",
"small chunks of data. During retrieval, it first fetches the small chunks\n",
"but then looks up the parent ids for those chunks and returns those larger\n",
"documents.\n",
"\n",
"Note that \"parent document\" refers to the document that a small chunk\n",
"originated from. This can either be the whole raw document OR a larger\n",
"chunk."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8b6e74b2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import ParentDocumentRetriever"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1d17af96",
"metadata": {},
"outputs": [],
"source": [
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.document_loaders import TextLoader"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "604ff981",
"metadata": {},
"outputs": [],
"source": [
"loaders = [\n",
" TextLoader('../../paul_graham_essay.txt'),\n",
" TextLoader('../../state_of_the_union.txt'),\n",
"]\n",
"docs = []\n",
"for l in loaders:\n",
" docs.extend(l.load())"
]
},
{
"cell_type": "markdown",
"id": "d3943f72",
"metadata": {},
"source": [
"## Retrieving Full Documents\n",
"\n",
"In this mode, we want to retrieve the full documents. Therefor, we only specify a child splitter."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1a8b2e5f",
"metadata": {},
"outputs": [],
"source": [
"# This text splitter is used to create the child documents\n",
"\n",
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(\n",
" collection_name=\"full_documents\",\n",
" embedding_function=OpenAIEmbeddings()\n",
")\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
"retriever = ParentDocumentRetriever(\n",
" vectorstore=vectorstore, \n",
" docstore=store, \n",
" child_splitter=child_splitter,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2b107935",
"metadata": {},
"outputs": [],
"source": [
"retriever.add_documents(docs)"
]
},
{
"cell_type": "markdown",
"id": "d05b97b7",
"metadata": {},
"source": [
"This should yield two keys, because we added two documents."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "30e3812b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['05fe8d8a-bf60-4f87-b576-4351b23df266',\n",
" '571cc9e5-9ef7-4f6c-b800-835c83a1858b']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(store.yield_keys())"
]
},
{
"cell_type": "markdown",
"id": "f895d62b",
"metadata": {},
"source": [
"Let's now call the vectorstore search functionality - we should see that it returns small chunks (since we're storing the small chunks"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b261c02c",
"metadata": {},
"outputs": [],
"source": [
"sub_docs = vectorstore.similarity_search(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5108222f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.\n"
]
}
],
"source": [
"print(sub_docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "bda8ed5a",
"metadata": {},
"source": [
"Let's now retrieve from the overall retriever. This should return large documents - since it returns the documents where the smaller chunks are located."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "419a91c4",
"metadata": {},
"outputs": [],
"source": [
"retrieved_docs = retriever.get_relevant_documents(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "cf10d250",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"38539"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(retrieved_docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "14f813a5",
"metadata": {},
"source": [
"## Retrieving Larger Chunks\n",
"\n",
"Sometimes, the full documents can be too big to want to retrieve them as is. In that case, what we really want to do is to first split the raw documents into larger chunks, and then split it into smaller chunks. We then index the smaller chunks, but on retrieval we retrieve the larger chunks (but still not the full documents)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b6f9a4f0",
"metadata": {},
"outputs": [],
"source": [
"# This text splitter is used to create the parent documents\n",
"parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)\n",
"# This text splitter is used to create the child documents\n",
"# It should create documents smaller than the parent\n",
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(collection_name=\"split_parents\", embedding_function=OpenAIEmbeddings())\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "19478ff3",
"metadata": {},
"outputs": [],
"source": [
"retriever = ParentDocumentRetriever(\n",
" vectorstore=vectorstore, \n",
" docstore=store, \n",
" child_splitter=child_splitter,\n",
" parent_splitter=parent_splitter,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fe16e620",
"metadata": {},
"outputs": [],
"source": [
"retriever.add_documents(docs)"
]
},
{
"cell_type": "markdown",
"id": "64ad3c8c",
"metadata": {},
"source": [
"We can see that there are much more than two documents now - these are the larger chunks"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "24d81886",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"66"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(list(store.yield_keys()))"
]
},
{
"cell_type": "markdown",
"id": "baaef673",
"metadata": {},
"source": [
"Let's make sure the underlying vectorstore still retrieves the small chunks."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b1c859de",
"metadata": {},
"outputs": [],
"source": [
"sub_docs = vectorstore.similarity_search(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6fffa2eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.\n"
]
}
],
"source": [
"print(sub_docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3a3202df",
"metadata": {},
"outputs": [],
"source": [
"retrieved_docs = retriever.get_relevant_documents(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "684fdb2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1849"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(retrieved_docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "9f17f662",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
"\n",
"We cannot let this happen. \n",
"\n",
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence. \n",
"\n",
"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since shes been nominated, shes received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
"\n",
"And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n",
"\n",
"We can do both. At our border, weve installed new technology like cutting-edge scanners to better detect drug smuggling. \n",
"\n",
"Weve set up joint patrols with Mexico and Guatemala to catch more human traffickers. \n",
"\n",
"Were putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster. \n",
"\n",
"Were securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.\n"
]
}
],
"source": [
"print(retrieved_docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "facfdacb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -475,7 +475,8 @@ class Agent(BaseSingleActionAgent):
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
agent_output = await self.output_parser.aparse(full_output)
return agent_output
def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any

View File

@@ -31,8 +31,19 @@ MODEL_COST_PER_1K_TOKENS = {
"gpt-3.5-turbo-0613-completion": 0.002,
"gpt-3.5-turbo-16k-completion": 0.004,
"gpt-3.5-turbo-16k-0613-completion": 0.004,
# Azure GPT-35 input
"gpt-35-turbo": 0.0015, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0301": 0.0015, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613": 0.0015,
"gpt-35-turbo-16k": 0.003,
"gpt-35-turbo-16k-0613": 0.003,
# Azure GPT-35 output
"gpt-35-turbo-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0301-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613-completion": 0.002,
"gpt-35-turbo-16k-completion": 0.004,
"gpt-35-turbo-16k-0613-completion": 0.004,
# Others
"gpt-35-turbo": 0.002, # Azure OpenAI version of ChatGPT
"text-ada-001": 0.0004,
"ada": 0.0004,
"text-babbage-001": 0.0005,
@@ -69,7 +80,9 @@ def standardize_model_name(
if "ft-" in model_name:
return model_name.split(":")[0] + "-finetuned"
elif is_completion and (
model_name.startswith("gpt-4") or model_name.startswith("gpt-3.5")
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
):
return model_name + "-completion"
else:

View File

@@ -47,6 +47,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run = self.run_map[str(run.parent_run_id)]
if parent_run:
self._add_child_run(parent_run, run)
parent_run.child_execution_order = max(
parent_run.child_execution_order, run.child_execution_order
)
else:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
self.run_map[str(run.id)] = run
@@ -131,7 +134,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run.events.append(
{
"name": "new_token",
@@ -183,7 +186,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
@@ -211,7 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time})
@@ -254,18 +257,25 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._on_chain_start(chain_run)
def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""End a trace for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException("No chain Run found to be traced")
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_end(chain_run)
@@ -273,6 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self,
error: Union[Exception, KeyboardInterrupt],
*,
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> None:
@@ -281,11 +292,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException("No chain Run found to be traced")
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_error(chain_run)
@@ -329,7 +342,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.utcnow()
@@ -349,7 +362,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
@@ -404,7 +417,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
@@ -420,7 +433,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})

View File

@@ -1,9 +1,11 @@
"""Base interface that all chains should implement."""
import asyncio
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -55,18 +57,40 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
"""
def invoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self(input, **(config or {}))
config = config or {}
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await super().ainvoke(input, config)
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
return await self.acall(input, **(config or {}))
config = config or {}
return await self.acall(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

View File

@@ -3,6 +3,8 @@ import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -27,9 +29,11 @@ class TransformChain(Chain):
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
"""The keys returned by the transform's output dictionary."""
transform: Callable[[Dict[str, str]], Dict[str, str]]
transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
"""The transform function."""
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
atransform_cb: Optional[
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
] = Field(None, alias="atransform")
"""The async coroutine transform function."""
@staticmethod
@@ -62,18 +66,18 @@ class TransformChain(Chain):
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return self.transform(inputs)
return self.transform_cb(inputs)
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self.atransform is not None:
return await self.atransform(inputs)
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"
)
return self.transform(inputs)
return self.transform_cb(inputs)

View File

@@ -40,11 +40,20 @@ class AzureChatOpenAI(ChatOpenAI):
Be aware the API version may change.
You can also specify the version of the model using the ``model_version`` constructor
parameter, as Azure OpenAI doesn't return the model version with the response.
Default is empty. When you specify the version, it will be appended to the
model name in the response. Setting the correct version will help you calculate the
cost properly. The model version is not validated, so make sure you set it correctly
to get the correct cost.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
"""
deployment_name: str = ""
model_version: str = ""
openai_api_type: str = ""
openai_api_base: str = ""
openai_api_version: str = ""
@@ -137,7 +146,19 @@ class AzureChatOpenAI(ChatOpenAI):
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
raise ValueError(
"Azure has not provided the response due to a content"
" filter being triggered"
"Azure has not provided the response due to a content filter "
"being triggered"
)
return super()._create_chat_result(response)
chat_result = super()._create_chat_result(response)
if "model" in response:
model = response["model"]
if self.model_version:
model = f"{model}-{self.model_version}"
if chat_result.llm_output is not None and isinstance(
chat_result.llm_output, dict
):
chat_result.llm_output["model_name"] = model
return chat_result

View File

@@ -103,12 +103,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
config = config or {}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
).generations[0][0],
).message,
)
@@ -127,8 +133,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@@ -1,9 +1,13 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGenerationChunk
class FakeListChatModel(SimpleChatModel):
@@ -31,6 +35,36 @@ class FakeListChatModel(SimpleChatModel):
self.i = 0
return response
def _stream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"responses": self.responses}

View File

@@ -1,10 +1,9 @@
"""Loads local airbyte json files."""
from typing import Any, Callable, Iterator, List, Mapping, Optional
from libs.langchain.langchain.utils.utils import guard_import
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.utils.utils import guard_import
RecordHandler = Callable[[Any, Optional[str]], Document]

View File

@@ -5,6 +5,7 @@ the sequence of actions taken and their outcomes. It uses a language model
chain (LLMChain) to generate the reasoning and scores.
"""
import re
from typing import (
Any,
Dict,
@@ -74,15 +75,24 @@ class TrajectoryOutputParser(BaseOutputParser):
reasoning, score_str = reasoning.strip(), score_str.strip()
score_str = next(
(char for char in score_str if char.isdigit()), "0"
) # Scan for first digit
if not 1 <= int(score_str) <= 5:
# Use regex to extract the score.
# This will get the number in the string, even if it is a float or more than 10.
# E.g. "Score: 1" will return 1, "Score: 3.5" will return 3.5, and
# "Score: 10" will return 10.
# The score should be an integer digit in the range 1-5.
_score = re.search(r"(\d+(\.\d+)?)", score_str)
# If the score is not found or is a float, raise an exception.
if _score is None or "." in _score.group(1):
raise OutputParserException(
f"Score is not an integer digit in the range 1-5: {text}"
)
score = int(_score.group(1))
# If the score is not in the range 1-5, raise an exception.
if not 1 <= score <= 5:
raise OutputParserException(
f"Score is not a digit in the range 1-5: {text}"
)
normalized_score = (int(score_str) - 1) / 4
normalized_score = (score - 1) / 4
return TrajectoryEval(score=normalized_score, reasoning=reasoning)

View File

@@ -1,6 +1,5 @@
"""Interfaces to be implemented by general evaluators."""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
@@ -169,13 +168,9 @@ class StringEvaluator(_EvalArgsMixin, ABC):
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
None,
self._evaluate_strings,
prediction=prediction,
reference=reference,
input=input,
**kwargs,
raise NotImplementedError(
f"{self.__class__.__name__} hasn't implemented an async "
"aevaluate_strings method."
)
def evaluate_strings(
@@ -270,14 +265,9 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
None,
self._evaluate_string_pairs,
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
raise NotImplementedError(
f"{self.__class__.__name__} hasn't implemented an async "
"aevaluate_string_pairs method."
)
def evaluate_string_pairs(
@@ -391,14 +381,9 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
raise asyncio.get_running_loop().run_in_executor(
None,
self._evaluate_agent_trajectory,
prediction=prediction,
agent_trajectory=agent_trajectory,
input=input,
reference=reference,
**kwargs,
raise NotImplementedError(
f"{self.__class__.__name__} hasn't implemented an async "
"aevaluate_agent_trajectory method."
)
def evaluate_agent_trajectory(

View File

@@ -219,9 +219,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
config = config or {}
return (
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
.generations[0][0]
.text
@@ -241,8 +247,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return llm_result.generations[0][0].text

View File

@@ -1,10 +1,12 @@
from typing import Any, List, Mapping, Optional
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.runnable import RunnableConfig
class FakeListLLM(LLM):
@@ -51,3 +53,29 @@ class FakeListLLM(LLM):
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"responses": self.responses}
class FakeStreamingListLLM(FakeListLLM):
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
result = self.invoke(input, config)
for c in result:
yield c
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
result = await self.ainvoke(input, config)
for c in result:
yield c

View File

@@ -12,6 +12,7 @@ from langchain.memory.chat_message_histories.momento import MomentoChatMessageHi
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
@@ -29,6 +30,7 @@ __all__ = [
"MongoDBChatMessageHistory",
"PostgresChatMessageHistory",
"RedisChatMessageHistory",
"RocksetChatMessageHistory",
"SQLChatMessageHistory",
"StreamlitChatMessageHistory",
"ZepChatMessageHistory",

View File

@@ -0,0 +1,258 @@
from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Union
from uuid import uuid4
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
class RocksetChatMessageHistory(BaseChatMessageHistory):
"""Uses Rockset to store chat messages.
To use, ensure that the `rockset` python package is installed.
Example:
.. code-block:: python
from langchain.memory.chat_message_histories import (
RocksetChatMessageHistory
)
from rockset import RocksetClient
history = RocksetChatMessageHistory(
session_id="MySession",
client=RocksetClient(),
collection="langchain_demo",
sync=True
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
"""
# You should set these values based on your VI.
# These values are configured for the typical
# free VI. Read more about VIs here:
# https://rockset.com/docs/instances
SLEEP_INTERVAL_MS = 5
ADD_TIMEOUT_MS = 5000
CREATE_TIMEOUT_MS = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into
meth.
"""
start = datetime.now()
while not method(**method_params):
curr = datetime.now()
if (curr - start).total_seconds() * 1000 > timeout:
raise TimeoutError(f"{method} timed out at {timeout} ms")
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
def _query(self, query: str, **query_params: Any) -> List[Any]:
"""Executes an SQL statement and returns the result
Args:
- query: The SQL string
- **query_params: Parameters to pass into the query
"""
return self.client.sql(query, params=query_params).results
def _create_collection(self) -> None:
"""Creates a collection for this message history"""
self.client.Collections.create_s3_collection(
name=self.collection, workspace=self.workspace
)
def _collection_exists(self) -> bool:
"""Checks whether a collection exists for this message history"""
try:
self.client.Collections.get(collection=self.collection)
except self.rockset.exceptions.NotFoundException:
return False
return True
def _collection_is_ready(self) -> bool:
"""Checks whether the collection for this message history is ready
to be queried
"""
return (
self.client.Collections.get(collection=self.collection).data.status
== "READY"
)
def _document_exists(self) -> bool:
return (
len(
self._query(
f"""
SELECT 1
FROM {self.location}
WHERE _id=:session_id
LIMIT 1
""",
session_id=self.session_id,
)
)
!= 0
)
def _wait_until_collection_created(self) -> None:
"""Sleeps until the collection for this message history is ready
to be queried
"""
self._wait_until(
lambda: self._collection_is_ready(),
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
)
def _wait_until_message_added(self, message_id: str) -> None:
"""Sleeps until a message is added to the messages list"""
self._wait_until(
lambda message_id: len(
self._query(
f"""
SELECT *
FROM UNNEST((
SELECT {self.messages_key}
FROM {self.location}
WHERE _id = :session_id
)) AS message
WHERE message.data.additional_kwargs.id = :message_id
LIMIT 1
""",
session_id=self.session_id,
message_id=message_id,
),
)
!= 0,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
message_id=message_id,
)
def _create_empty_doc(self) -> None:
"""Creates or replaces a document for this message history with no
messages"""
self.client.Documents.add_documents(
collection=self.collection,
workspace=self.workspace,
data=[{"_id": self.session_id, self.messages_key: []}],
)
def __init__(
self,
session_id: str,
client: Any,
collection: str,
workspace: str = "commons",
messages_key: str = "messages",
sync: bool = False,
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
) -> None:
"""Constructs a new RocksetChatMessageHistory.
Args:
- session_id: The ID of the chat session
- client: The RocksetClient object to use to query
- collection: The name of the collection to use to store chat
messages. If a collection with the given name
does not exist in the workspace, it is created.
- workspace: The workspace containing `collection`. Defaults
to `"commons"`
- messages_key: The DB column containing message history.
Defaults to `"messages"`
- sync: Whether to wait for messages to be added. Defaults
to `False`. NOTE: setting this to `True` will slow
down performance.
- message_uuid_method: The method that generates message IDs.
If set, all messages will have an `id` field within the
`additional_kwargs` property. If this param is not set
and `sync` is `False`, message IDs will not be created.
If this param is not set and `sync` is `True`, the
`uuid.uuid4` method will be used to create message IDs.
"""
try:
import rockset
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
if not isinstance(client, rockset.RocksetClient):
raise ValueError(
f"client should be an instance of rockset.RocksetClient, "
f"got {type(client)}"
)
self.session_id = session_id
self.client = client
self.collection = collection
self.workspace = workspace
self.location = f'"{self.workspace}"."{self.collection}"'
self.rockset = rockset
self.messages_key = messages_key
self.message_uuid_method = message_uuid_method
self.sync = sync
if not self._collection_exists():
self._create_collection()
self._wait_until_collection_created()
self._create_empty_doc()
elif not self._document_exists():
self._create_empty_doc()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Messages in this chat history."""
return messages_from_dict(
self._query(
f"""
SELECT *
FROM UNNEST ((
SELECT "{self.messages_key}"
FROM {self.location}
WHERE _id = :session_id
))
""",
session_id=self.session_id,
)
)
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the history.
Args:
message: A BaseMessage object to store.
"""
if self.sync and "id" not in message.additional_kwargs:
message.additional_kwargs["id"] = self.message_uuid_method()
self.client.Documents.patch_documents(
collection=self.collection,
workspace=self.workspace,
data=[
self.rockset.model.patch_document.PatchDocument(
id=self.session_id,
patch=[
self.rockset.model.patch_operation.PatchOperation(
op="ADD",
path=f"/{self.messages_key}/-",
value=_message_to_dict(message),
)
],
)
],
)
if self.sync:
self._wait_until_message_added(message.additional_kwargs["id"])
def clear(self) -> None:
"""Removes all messages from the chat history"""
self._create_empty_doc()
if self.sync:
self._wait_until(
lambda: not self.messages,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
)
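
For context, a minimal usage sketch (assumptions: the `rockset` package is installed and the client picks up valid credentials from your own configuration) that ties this history into `ConversationBufferMemory`, mirroring the integration test further down in this diff:

from rockset import RocksetClient
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory

# RocksetClient() with no arguments follows the class docstring above; the API
# key and host are assumed to come from your environment or client config.
history = RocksetChatMessageHistory(
    session_id="MySession",
    client=RocksetClient(),
    collection="langchain_demo",
    sync=True,  # wait until each message is queryable before returning
)
memory = ConversationBufferMemory(
    memory_key="messages", chat_memory=history, return_messages=True
)
memory.chat_memory.add_user_message("hi!")
memory.chat_memory.add_ai_message("whats up?")
print(memory.chat_memory.messages)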

View File

@@ -40,6 +40,7 @@ from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.retrievers.re_phraser import RePhraseQueryRetriever
@@ -90,4 +91,5 @@ __all__ = [
"RePhraseQueryRetriever",
"WebResearchRetriever",
"EnsembleRetriever",
"ParentDocumentRetriever",
]

View File

@@ -0,0 +1,139 @@
import uuid
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import Callbacks
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.schema.storage import BaseStore
from langchain.text_splitter import TextSplitter
from langchain.vectorstores.base import VectorStore
class ParentDocumentRetriever(BaseRetriever):
"""Fetches small chunks, then fetches their parent documents.
When splitting documents for retrieval, there are often conflicting desires:
1. You may want to have small documents, so that their embeddings can most
accurately reflect their meaning. If too long, then the embeddings can
lose meaning.
2. You want to have long enough documents that the context of each chunk is
retained.
The ParentDocumentRetriever strikes that balance by splitting and storing
small chunks of data. During retrieval, it first fetches the small chunks
but then looks up the parent ids for those chunks and returns those larger
documents.
Note that "parent document" refers to the document that a small chunk
originated from. This can either be the whole raw document OR a larger
chunk.
Examples:
.. code-block:: python
# Imports
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.storage import InMemoryStore
# This text splitter is used to create the parent documents
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
# This text splitter is used to create the child documents
# It should create documents smaller than the parent
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# The vectorstore to use to index the child chunks
vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
# The storage layer for the parent documents
store = InMemoryStore()
# Initialize the retriever
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
"""
vectorstore: VectorStore
"""The underlying vectorstore to use to store small chunks
and their embedding vectors"""
docstore: BaseStore[str, Document]
"""The storage layer for the parent documents"""
child_splitter: TextSplitter
"""The text splitter to use to create child documents."""
id_key: str = "doc_id"
"""The key to use to track the parent id. This will be stored in the
metadata of child documents."""
parent_splitter: Optional[TextSplitter] = None
"""The text splitter to use to create parent documents.
If none, then the parent documents will be the raw documents passed in."""
def get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
sub_docs = self.vectorstore.similarity_search(query)
# We do this to maintain the order of the ids that are returned
ids = []
for d in sub_docs:
if d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]],
add_to_docstore: bool = True,
) -> None:
"""Adds documents to the docstore and vectorstores.
Args:
documents: List of documents to add
ids: Optional list of ids for documents. If provided should be the same
length as the list of documents. Can be provided if parent documents
are already in the document store and you don't want to re-add
to the docstore. If not provided, random UUIDs will be used as
ids.
add_to_docstore: Boolean of whether to add documents to docstore.
This can be false if and only if `ids` are provided. You may want
to set this to False if the documents are already in the docstore
and you don't want to re-add them.
"""
if self.parent_splitter is not None:
documents = self.parent_splitter.split_documents(documents)
if ids is None:
doc_ids = [str(uuid.uuid4()) for _ in documents]
if not add_to_docstore:
raise ValueError(
"If ids are not passed in, `add_to_docstore` MUST be True"
)
else:
if len(documents) != len(ids):
raise ValueError(
"Got uneven list of documents and ids. "
"If `ids` is provided, should be same length as `documents`."
)
doc_ids = ids
docs = []
full_docs = []
for i, doc in enumerate(documents):
_id = doc_ids[i]
sub_docs = self.child_splitter.split_documents([doc])
for _doc in sub_docs:
_doc.metadata[self.id_key] = _id
docs.extend(sub_docs)
full_docs.append((_id, doc))
self.vectorstore.add_documents(docs)
if add_to_docstore:
self.docstore.mset(full_docs)
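
A short end-to-end sketch of the retriever (assumptions: the `chromadb` package and an OpenAI API key are available; the document content is a placeholder), restating the docstring setup above:

from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

retriever = ParentDocumentRetriever(
    vectorstore=Chroma(embedding_function=OpenAIEmbeddings()),
    docstore=InMemoryStore(),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
)

docs = [Document(page_content="LangChain is a framework for building LLM applications.")]
# Parent documents go to the docstore; child chunks (tagged with doc_id) go to the vectorstore.
retriever.add_documents(docs, ids=None)
# Small chunks are matched first; the parent documents they came from are returned.
print(retriever.get_relevant_documents("What is LangChain?"))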

View File

@@ -0,0 +1,46 @@
from operator import itemgetter
from typing import Any, Callable, List, Mapping, Optional, Union
from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
class OpenAIFunction(TypedDict):
"""A function description for ChatOpenAI"""
name: str
"""The name of the function."""
description: str
"""The description of the function."""
parameters: dict
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
"""A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]]
def __init__(
self,
runnables: Mapping[
str,
Union[
Runnable[dict, Any],
Callable[[dict], Any],
],
],
functions: Optional[List[OpenAIFunction]] = None,
):
if functions is not None:
assert len(functions) == len(runnables)
assert all(func["name"] in runnables for func in functions)
router = (
JsonOutputFunctionsParser(args_only=False)
| {"key": itemgetter("name"), "input": itemgetter("arguments")}
| RouterRunnable(runnables)
)
super().__init__(bound=router, kwargs={}, functions=functions)
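
A condensed usage sketch (adapted from the unit test later in this diff; assumptions: `OPENAI_API_KEY` is set, and ChatOpenAI stands in for any chat model that emits an OpenAI `function_call`):

from langchain.chat_models import ChatOpenAI
from langchain.runnables.openai_functions import OpenAIFunctionsRouter

router = OpenAIFunctionsRouter(
    {
        "accept": lambda kw: f"Accepted draft: {kw['draft']}!",
        "revise": lambda kw: f"Revised draft: no more {kw['notes']}!",
    },
    functions=[
        {
            "name": "accept",
            "description": "Accepts the draft.",
            "parameters": {
                "type": "object",
                "properties": {
                    "draft": {"type": "string", "description": "The draft to accept."}
                },
            },
        },
        {
            "name": "revise",
            "description": "Sends the draft for revision.",
            "parameters": {
                "type": "object",
                "properties": {
                    "notes": {"type": "string", "description": "The editor's notes."}
                },
            },
        },
    ],
)

model = ChatOpenAI()  # assumption: any function-calling chat model works here
chain = model.bind(functions=router.functions) | router
# The model picks a function; its parsed arguments are routed to the matching handler.
print(chain.invoke("Write a short draft about turtles."))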

View File

@@ -1,7 +1,18 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
@@ -27,12 +38,26 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
@@ -51,6 +76,26 @@ class BaseGenerationOutputParser(
run_type="parser",
)
async def ainvoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
"""Base class to parse the output of an LLM call.
@@ -80,7 +125,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
""" # noqa: E501
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
@@ -99,6 +144,26 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
run_type="parser",
)
async def ainvoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
@@ -125,6 +190,32 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await self.aparse(result[0].text)
async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of a language model.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
@@ -161,8 +252,47 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
return output_parser_dict
class StrOutputParser(BaseOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string.."""
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@property
def lc_serializable(self) -> bool:

View File
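Because StrOutputParser now implements transform(), it can pass chunks through as they are produced instead of waiting for the whole string. A minimal sketch, condensed from test_deep_stream later in this diff and using the test fakes rather than a real model:

from langchain.llms.fake import FakeStreamingListLLM
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.schema.output_parser import StrOutputParser

prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()

# The fake LLM streams one character at a time; the parser forwards each chunk.
chunks = list(chain.stream({"question": "What up"}))
assert "".join(chunks) == "foo-lish"
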

@@ -107,7 +107,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
return self.get_relevant_documents(input, **(config or {}))
config = config or {}
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +122,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
return await self.aget_relevant_documents(input, **(config or {}))
config = config or {}
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
@abstractmethod
def _get_relevant_documents(

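In practice this means the entries of a RunnableConfig now reach the retriever run as callbacks, tags, and metadata instead of being splatted as arbitrary keyword arguments. A minimal sketch with a toy retriever (the class and values here are illustrative, not part of the diff):

from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class ToyRetriever(BaseRetriever):
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=query)]


# tags and metadata from the config are attached to the retriever run.
docs = ToyRetriever().invoke(
    "hello", config={"tags": ["demo"], "metadata": {"source": "sketch"}}
)
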
View File

@@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from itertools import tee
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@@ -28,6 +30,7 @@ from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.utils.aiter import atee, py_anext
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@@ -91,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
""" --- Public API --- """
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
...
@@ -98,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
"""
Default implementation of ainvoke, which calls invoke in a thread pool.
Subclasses should override this method if they can run asynchronously.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.invoke, input, config
)
@@ -109,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
"""
Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor
@@ -125,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
"""
Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
@@ -133,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
"""
Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output.
"""
yield self.invoke(input, config)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
"""
Default implementation of astream, which calls ainvoke.
Subclasses should override this method if they support streaming output.
"""
yield await self.ainvoke(input, config)
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
"""
Default implementation of transform, which buffers input and then calls stream.
Subclasses should override this method if they can start producing output while
input is still being generated.
"""
final: Union[Input, None] = None
for chunk in input:
if final is None:
final = chunk
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
if final:
yield from self.stream(final, config)
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
"""
Default implementation of atransform, which buffers input and calls astream.
Subclasses should override this method if they can start producing output while
input is still being generated.
"""
final: Union[Input, None] = None
async for chunk in input:
if final is None:
final = chunk
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
if final:
async for output in self.astream(final, config):
yield output
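To make the default buffering concrete, here is a small sketch (not part of the diff): RunnableLambda does not override transform(), so incoming chunks are gathered with `+` and only then handed to stream()/invoke():

from langchain.schema.runnable import RunnableLambda

upper = RunnableLambda(lambda s: s.upper())
# "foo", "-", "lish" are buffered into "foo-lish" before the lambda runs once.
assert list(upper.transform(iter(["foo", "-", "lish"]))) == ["FOO-LISH"]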
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""
Bind arguments to a Runnable, returning a new Runnable.
"""
return RunnableBinding(bound=self, kwargs=kwargs)
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
)
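A quick sketch of bind() using the test fake from this diff (the bound kwarg is simply recorded and forwarded on every call; FakeListLLM ignores the stop token):

from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["foo"])
bound = llm.bind(stop=["Observation:"])
# Equivalent to llm.invoke("Question: what up", stop=["Observation:"])
assert bound.invoke("Question: what up") == "foo"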
""" --- Helper methods for Subclasses --- """
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
@@ -168,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
from langchain.callbacks.manager import CallbackManager
config = config or {}
@@ -192,20 +279,187 @@ class Runnable(Generic[Input, Output], ABC):
)
return output
def with_fallbacks(
async def _acall_with_config(
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
func: Callable[[Input], Awaitable[Output]],
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
run_type=run_type,
)
try:
output = await func(input)
except Exception as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
def _transform_stream_with_config(
self,
input: Iterator[Input],
transformer: Callable[[Iterator[Input]], Iterator[Output]],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
from langchain.callbacks.manager import CallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
final_input: Optional[Input] = next(input_for_tracing, None)
final_input_supported = True
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
run_type=run_type,
)
try:
for chunk in transformer(input_for_transform):
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
except TypeError:
final_output = None
final_output_supported = False
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
raise
else:
run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
async def _atransform_stream_with_config(
self,
input: AsyncIterator[Input],
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
final_input_supported = True
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
run_type=run_type,
)
try:
async for chunk in transformer(input_for_transform):
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
except TypeError:
final_output = None
final_output_supported = False
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
await run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
raise
else:
await run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
@@ -435,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class RunnableSequence(Serializable, Runnable[Input, Output]):
"""
A sequence of runnables, where the output of each is the input of the next.
"""
first: Runnable[Input, Any]
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
last: Runnable[Any, Output]
@@ -706,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
streaming_start_index = i - 1
else:
break
# invoke the first steps
try:
for step in [self.first] + self.middle:
for step in steps[0:streaming_start_index]:
input = step.invoke(
input,
# mark each step as a child run
@@ -718,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager.on_chain_error(e)
raise
# stream the last step
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
for output in self.last.stream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream(
input, _patch_config(config, run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
final_pipeline, _patch_config(config, run_manager.get_child())
)
for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
@@ -769,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
streaming_start_index = i - 1
else:
break
# invoke the first steps
try:
for step in [self.first] + self.middle:
for step in steps[0:streaming_start_index]:
input = await step.ainvoke(
input,
# mark each step as a child run
@@ -781,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
await run_manager.on_chain_error(e)
raise
# stream the last step
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
async for output in self.last.astream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream(
input, _patch_config(config, run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
final_pipeline, _patch_config(config, run_manager.get_child())
)
async for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
@@ -813,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs.
"""
steps: Mapping[str, Runnable[Input, Any]]
def __init__(
@@ -925,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
"""
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
self.func = func
@@ -945,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
"""
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
return True
@@ -954,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
"""
bound: Runnable[Input, Output]
kwargs: Mapping[str, Any]
@@ -1009,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
async for item in self.bound.astream(input, config, **self.kwargs):
yield item
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
yield from self.bound.transform(input, config, **self.kwargs)
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(input, config, **self.kwargs):
yield item
class RouterInput(TypedDict):
key: str
@@ -1018,10 +1332,22 @@ class RouterInput(TypedDict):
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
super().__init__(runnables=runnables)
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
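A small sketch of the router contract, mirroring the router unit test below (plain callables are coerced to runnables by the new __init__; the "key" field selects the target and "input" is forwarded to it):

from langchain.schema.runnable import RouterRunnable

router = RouterRunnable(
    {
        "math": lambda d: f"{d['question']} = 4",
        "english": lambda d: d["question"].upper(),
    }
)
assert router.invoke({"key": "math", "input": {"question": "2 + 2"}}) == "2 + 2 = 4"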

View File

@@ -502,6 +502,18 @@ def _construct_run_evaluator(
return run_evaluator
def _get_keys(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_outputs: Optional[List[str]],
example_outputs: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
return input_key, prediction_key, reference_key
def _load_run_evaluators(
config: RunEvalConfig,
run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
"""
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
run_evaluators = []
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
input_key, prediction_key, reference_key = None, None, None
if config.evaluators or any(
[isinstance(e, EvaluatorType) for e in config.evaluators]
):
input_key, prediction_key, reference_key = _get_keys(
config, run_inputs, run_outputs, example_outputs
)
for eval_config in config.evaluators:
run_evaluator = _construct_run_evaluator(
eval_config,
@@ -1074,15 +1090,15 @@ def _run_on_examples(
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
tracer = LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
evalution_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
@@ -1091,7 +1107,7 @@ def _run_on_examples(
for i, example in enumerate(examples):
result = _run_llm_or_chain(
example,
llm_or_chain_factory,
wrapped_model,
num_repetitions,
tags=tags,
callbacks=callbacks,
@@ -1114,8 +1130,8 @@ def _prepare_eval_run(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, wrapped_model)
try:
project = client.create_project(project_name)
except ValueError as e:
@@ -1130,7 +1146,7 @@ def _prepare_eval_run(
)
dataset = client.read_dataset(dataset_name=dataset_name)
examples = client.list_examples(dataset_id=str(dataset.id))
return llm_or_chain_factory, project_name, dataset, examples
return wrapped_model, project_name, dataset, examples
async def arun_on_dataset(
@@ -1256,13 +1272,13 @@ async def arun_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
results = await _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
@@ -1423,14 +1439,14 @@ def run_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
if concurrency_level in (0, 1):
results = _run_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
@@ -1444,7 +1460,7 @@ def run_on_dataset(
coro = _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,

View File

@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
**kwargs: Any,
) -> Any:
config = config or {}
return self.run(input, **config, **kwargs)
return self.run(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
return super().ainvoke(input, config, **kwargs)
config = config or {}
return await self.arun(input, **config, **kwargs)
return await self.arun(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
# --- Tool ---

View File

@@ -0,0 +1,191 @@
"""
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""
from collections import deque
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generic,
Iterator,
List,
Tuple,
TypeVar,
Union,
cast,
overload,
)
T = TypeVar("T")
_no_default = object()
# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
"""Pure-Python implementation of anext() for testing purposes.
Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""
try:
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
)
except AttributeError:
raise TypeError(f"{iterator!r} is not an async iterator")
if default is _no_default:
return __anext__(iterator)
async def anext_impl() -> Union[T, Any]:
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default
return anext_impl()
async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done, remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "aclose"):
await iterator.aclose()
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.
.. code-block:: python3
async def derivative(sensor_data):
previous, current = a.tee(sensor_data, n=2)
await a.anext(previous) # advance one iterator
return a.map(operator.sub, previous, current)
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterator differ
by most data, using a :py:class:`list` is more efficient (but not lazy).
If the underlying iterable is concurrency safe (``anext`` may be awaited
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``anext``, provide a ``lock``
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
and access is automatically synchronised.
"""
def __init__(
self,
iterable: AsyncIterator[T],
n: int = 2,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> AsyncIterator[T]:
...
@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]:
yield from self._children
async def __aenter__(self) -> "Tee[T]":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
await self.aclose()
return False
async def aclose(self) -> None:
for child in self._children:
await child.aclose()
atee = Tee
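
A self-contained sketch of atee (assumption: run under asyncio): both children see the same items in the same order, with items buffered until the slower child has consumed them.

import asyncio

from langchain.utils.aiter import atee


async def numbers():
    for i in range(3):
        yield i


async def main() -> None:
    first, second = atee(numbers(), 2)
    assert [x async for x in first] == [0, 1, 2]
    # items already yielded by `first` were buffered for `second`
    assert [x async for x in second] == [0, 1, 2]


asyncio.run(main())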

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.258"
version = "0.0.260"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@@ -0,0 +1,63 @@
"""Tests RocksetChatMessageHistory by creating a collection
for message history, adding to it, and clearing it.
To run these tests, make sure you have the ROCKSET_API_KEY
and ROCKSET_REGION environment variables set.
"""
import json
import os
from rockset import DevRegions, Regions, RocksetClient
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
from langchain.schema.messages import _message_to_dict
collection_name = "langchain_demo"
session_id = "MySession"
class TestRockset:
memory: RocksetChatMessageHistory
@classmethod
def setup_class(cls) -> None:
assert os.environ.get("ROCKSET_API_KEY") is not None
assert os.environ.get("ROCKSET_REGION") is not None
api_key = os.environ.get("ROCKSET_API_KEY")
region = os.environ.get("ROCKSET_REGION")
if region == "use1a1":
host = Regions.use1a1
elif region == "usw2a1" or not region:
host = Regions.usw2a1
elif region == "euc1a1":
host = Regions.euc1a1
elif region == "dev":
host = DevRegions.usw2a1
else:
host = region
client = RocksetClient(host, api_key)
cls.memory = RocksetChatMessageHistory(
session_id, client, collection_name, sync=True
)
def test_memory_with_message_store(self) -> None:
memory = ConversationBufferMemory(
memory_key="messages", chat_memory=self.memory, return_messages=True
)
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -60,3 +60,67 @@ def test_on_llm_end_finetuned_model(handler: OpenAICallbackHandler) -> None:
)
handler.on_llm_end(response)
assert handler.total_cost > 0
@pytest.mark.parametrize(
"model_name,expected_cost",
[
("gpt-35-turbo", 0.0035),
("gpt-35-turbo-0301", 0.0035),
(
"gpt-35-turbo-0613",
0.0035,
),
(
"gpt-35-turbo-16k-0613",
0.007,
),
(
"gpt-35-turbo-16k",
0.007,
),
("gpt-4", 0.09),
("gpt-4-0314", 0.09),
("gpt-4-0613", 0.09),
("gpt-4-32k", 0.18),
("gpt-4-32k-0314", 0.18),
("gpt-4-32k-0613", 0.18),
],
)
def test_on_llm_end_azure_openai(
handler: OpenAICallbackHandler, model_name: str, expected_cost: float
) -> None:
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 1000,
"completion_tokens": 1000,
"total_tokens": 2000,
},
"model_name": model_name,
},
)
handler.on_llm_end(response)
assert handler.total_cost == expected_cost
@pytest.mark.parametrize(
"model_name", ["gpt-35-turbo-16k-0301", "gpt-4-0301", "gpt-4-32k-0301"]
)
def test_on_llm_end_no_cost_invalid_model(
handler: OpenAICallbackHandler, model_name: str
) -> None:
response = LLMResult(
generations=[],
llm_output={
"token_usage": {
"prompt_tokens": 1000,
"completion_tokens": 1000,
"total_tokens": 2000,
},
"model_name": model_name,
},
)
handler.on_llm_end(response)
assert handler.total_cost == 0

View File

@@ -15,7 +15,7 @@ def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
return outputs
def test_tranform_chain() -> None:
def test_transform_chain() -> None:
"""Test basic transform chain."""
transform_chain = TransformChain(
input_variables=["first_name", "last_name"],

View File

@@ -0,0 +1,52 @@
import json
import os
from typing import Any, Mapping, cast
import pytest
from langchain.chat_models.azure_openai import AzureChatOpenAI
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2023-05-01"
@pytest.mark.requires("openai")
@pytest.mark.parametrize(
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
)
def test_model_name_set_on_chat_result_when_present_in_response(
model_name: str,
) -> None:
sample_response_text = f"""
{{
"id": "chatcmpl-7ryweq7yc8463fas879t9hdkkdf",
"object": "chat.completion",
"created": 1690381189,
"model": "{model_name}",
"choices": [
{{
"index": 0,
"finish_reason": "stop",
"message": {{
"role": "assistant",
"content": "I'm an AI assistant that can help you."
}}
}}
],
"usage": {{
"completion_tokens": 28,
"prompt_tokens": 15,
"total_tokens": 43
}}
}}
"""
# convert sample_response_text to instance of Mapping[str, Any]
sample_response = json.loads(sample_response_text)
mock_response = cast(Mapping[str, Any], sample_response)
mock_chat = AzureChatOpenAI()
chat_result = mock_chat._create_chat_result(mock_response)
assert (
chat_result.llm_output is not None
and chat_result.llm_output["model_name"] == model_name
)

View File

@@ -0,0 +1,9 @@
"""Test the airbyte document loader.
Light test to ensure that the airbyte document loader can be imported.
"""
def test_airbyte_import() -> None:
"""Test that the airbyte document loader can be imported."""
from langchain.document_loaders import airbyte # noqa

View File

@@ -6,8 +6,12 @@ import pytest
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.schema import AgentAction, BaseMessage
from langchain.evaluation.agents.trajectory_eval_chain import (
TrajectoryEval,
TrajectoryEvalChain,
TrajectoryOutputParser,
)
from langchain.schema import AgentAction, BaseMessage, OutputParserException
from langchain.tools.base import tool
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
@@ -53,6 +57,61 @@ class _FakeTrajectoryChatModel(FakeChatModel):
return self.queries[prompt]
def test_trajectory_output_parser_parse() -> None:
trajectory_output_parser = TrajectoryOutputParser()
text = """Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.
Score: 2"""
got = trajectory_output_parser.parse(text)
want = TrajectoryEval(
score=0.25,
reasoning="""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.""",
)
assert got["score"] == want["score"]
assert got["reasoning"] == want["reasoning"]
with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2."""
)
with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.
Score: 9"""
)
with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.
Score: 10"""
)
with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.
Score: 0.1"""
)
with pytest.raises(OutputParserException):
trajectory_output_parser.parse(
"""Judgment: Given the good reasoning in the final answer
but otherwise poor performance, we give the model a score of 2.
Score: One"""
)
def test_trajectory_eval_chain(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:

View File

@@ -0,0 +1,31 @@
# serializer version: 1
# name: test_openai_functions_router
list([
dict({
'description': 'Sends the draft for revision.',
'name': 'revise',
'parameters': dict({
'properties': dict({
'notes': dict({
'description': "The editor's notes to guide the revision.",
'type': 'string',
}),
}),
'type': 'object',
}),
}),
dict({
'description': 'Accepts the draft.',
'name': 'accept',
'parameters': dict({
'properties': dict({
'draft': dict({
'description': 'The draft to accept.',
'type': 'string',
}),
}),
'type': 'object',
}),
}),
])
# ---

View File

@@ -0,0 +1,95 @@
from typing import Any, List, Optional
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration
class FakeChatOpenAI(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "accept",
"arguments": '{\n "draft": "turtles"\n}',
}
},
)
)
]
)
def test_openai_functions_router(
snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
revise = mocker.Mock(
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
)
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
router = OpenAIFunctionsRouter(
{
"revise": revise,
"accept": accept,
},
functions=[
{
"name": "revise",
"description": "Sends the draft for revision.",
"parameters": {
"type": "object",
"properties": {
"notes": {
"type": "string",
"description": "The editor's notes to guide the revision.",
},
},
},
},
{
"name": "accept",
"description": "Accepts the draft.",
"parameters": {
"type": "object",
"properties": {
"draft": {
"type": "string",
"description": "The draft to accept.",
},
},
},
},
],
)
model = FakeChatOpenAI()
chain = model.bind(functions=router.functions) | router
assert router.functions == snapshot
assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
revise.assert_not_called()
accept.assert_called_once_with({"draft": "turtles"})

File diff suppressed because one or more lines are too long

View File

@@ -11,7 +11,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
@@ -22,6 +22,7 @@ from langchain.prompts.chat import (
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
@@ -61,6 +62,8 @@ class FakeTracer(BaseTracer):
if run.parent_run_id
else None,
"child_runs": [self._copy_run(child) for child in run.child_runs],
"execution_order": None,
"child_execution_order": None,
}
)
@@ -302,7 +305,7 @@ async def test_prompt_with_chat_model(
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [AIMessage(content="foo")]
] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@@ -678,7 +681,12 @@ async def test_router_runnable(
"key": "math",
"input": {"question": "2 + 2"},
}
assert tracer.runs == snapshot
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableSequence" # TODO: should be RunnableRouter
assert len(router_run.child_runs) == 2
@freeze_time("2023-01-01")
@@ -758,6 +766,45 @@ def test_bind_bind() -> None:
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
def test_deep_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()
stream = chain.stream({"question": "What up"})
chunks = []
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.mark.asyncio
async def test_deep_astream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()
stream = chain.astream({"question": "What up"})
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)