mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-05 16:50:03 +00:00
Compare commits
90 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
910da8518f | ||
|
|
2f27ef92fe | ||
|
|
75149d6d38 | ||
|
|
fab7994b74 | ||
|
|
eb80d6e0e4 | ||
|
|
b5667bed9e | ||
|
|
b3be83c750 | ||
|
|
50626a10ee | ||
|
|
6e1b5b8f7e | ||
|
|
eec9b1b306 | ||
|
|
ea142f6a32 | ||
|
|
12f868b292 | ||
|
|
31f9ecfc19 | ||
|
|
273e9bf296 | ||
|
|
f155d9d3ec | ||
|
|
d3d4503ce2 | ||
|
|
1f93c5cf69 | ||
|
|
15b5a08f4b | ||
|
|
ff4a25b841 | ||
|
|
2212520a6c | ||
|
|
d08f940336 | ||
|
|
2280a2cb2f | ||
|
|
ce5d97bcb3 | ||
|
|
8fa1764c60 | ||
|
|
f299bd1416 | ||
|
|
064be93edf | ||
|
|
86822d1cc2 | ||
|
|
a581bce379 | ||
|
|
2ffc643086 | ||
|
|
2136dc94bb | ||
|
|
a92344f476 | ||
|
|
b706966ebc | ||
|
|
1c22657256 | ||
|
|
6f02286805 | ||
|
|
3674074eb0 | ||
|
|
a7e09d46c5 | ||
|
|
fa2e546b76 | ||
|
|
c592b12043 | ||
|
|
9555bbd5bb | ||
|
|
0ca1641b14 | ||
|
|
d5b4393bb2 | ||
|
|
7b6ff7fe00 | ||
|
|
76c7b1f677 | ||
|
|
5aa8ece211 | ||
|
|
f6d24d5740 | ||
|
|
b1c4480d7c | ||
|
|
b6ba989f2f | ||
|
|
04acda55ec | ||
|
|
8e5c4ac867 | ||
|
|
df8702fead | ||
|
|
d5d50c39e6 | ||
|
|
1f18698b2a | ||
|
|
ef4945af6b | ||
|
|
7de2ada3ea | ||
|
|
262d4cb9a8 | ||
|
|
951c158106 | ||
|
|
85e4dd7fc3 | ||
|
|
b1b4a4065a | ||
|
|
08f23c95d9 | ||
|
|
3cf493b089 | ||
|
|
e635c86145 | ||
|
|
779790167e | ||
|
|
3161ced4bc | ||
|
|
3d6fcb85dc | ||
|
|
3701b2901e | ||
|
|
280cb4160d | ||
|
|
80d8db5f60 | ||
|
|
1a8790d808 | ||
|
|
34840f3aee | ||
|
|
8685d53adc | ||
|
|
2f6833d433 | ||
|
|
dd90fd02d5 | ||
|
|
07766a69f3 | ||
|
|
aa854988bf | ||
|
|
96ebe98dc2 | ||
|
|
45f05fc939 | ||
|
|
cf9c3f54f7 | ||
|
|
fbc0c85b90 | ||
|
|
276940fd9b | ||
|
|
cdff6c8181 | ||
|
|
cd45adbea2 | ||
|
|
aff44d0a98 | ||
|
|
8a95fdaee1 | ||
|
|
5d8dc83ede | ||
|
|
b157e0c1c3 | ||
|
|
40e9488055 | ||
|
|
55efbb8a7e | ||
|
|
d6bbf395af | ||
|
|
606605925d | ||
|
|
f93c011456 |
2
.github/CONTRIBUTING.md
vendored
2
.github/CONTRIBUTING.md
vendored
@@ -73,6 +73,8 @@ poetry install -E all
|
||||
|
||||
This will install all requirements for running the package, examples, linting, formatting, tests, and coverage. Note the `-E all` flag will install all optional dependencies necessary for integration testing.
|
||||
|
||||
❗Note: If you're running Poetry 1.4.1 and receive a `WheelFileValidationError` for `debugpy` during installation, you can try either downgrading to Poetry 1.4.0 or disabling "modern installation" (`poetry config installer.modern-installation false`) and re-install requirements. See [this `debugpy` issue](https://github.com/microsoft/debugpy/issues/1246) for more details.
|
||||
|
||||
Now, you should be able to run the common tasks in the following section.
|
||||
|
||||
## ✅Common Tasks
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -136,5 +136,8 @@ dmypy.json
|
||||
# macOS display setting files
|
||||
.DS_Store
|
||||
|
||||
# Wandb directory
|
||||
wandb/
|
||||
|
||||
# asdf tool versions
|
||||
.tool-versions
|
||||
.tool-versions
|
||||
|
||||
@@ -23,7 +23,7 @@ with open("../pyproject.toml") as f:
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "🦜🔗 LangChain"
|
||||
copyright = "2022, Harrison Chase"
|
||||
copyright = "2023, Harrison Chase"
|
||||
author = "Harrison Chase"
|
||||
|
||||
version = data["tool"]["poetry"]["version"]
|
||||
|
||||
@@ -34,7 +34,8 @@ search = GoogleSerperAPIWrapper()
|
||||
tools = [
|
||||
Tool(
|
||||
name="Intermediate Answer",
|
||||
func=search.run
|
||||
func=search.run,
|
||||
description="useful for when you need to ask with search"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -25,9 +25,25 @@ from langchain.llms import PromptLayerOpenAI
|
||||
llm = PromptLayerOpenAI(pl_tags=["langchain-requests", "chatbot"])
|
||||
```
|
||||
|
||||
To get the PromptLayer request id, use the argument `return_pl_id` when instanializing the LLM
|
||||
```python
|
||||
from langchain.llms import PromptLayerOpenAI
|
||||
llm = PromptLayerOpenAI(return_pl_id=True)
|
||||
```
|
||||
This will add the PromptLayer request ID in the `generation_info` field of the `Generation` returned when using `.generate` or `.agenerate`
|
||||
|
||||
For example:
|
||||
```python
|
||||
llm_results = llm.generate(["hello world"])
|
||||
for res in llm_results.generations:
|
||||
print("pl request id: ", res[0].generation_info["pl_request_id"])
|
||||
```
|
||||
You can use the PromptLayer request ID to add a prompt, score, or other metadata to your request. [Read more about it here](https://magniv.notion.site/Track-4deee1b1f7a34c1680d085f82567dab9).
|
||||
|
||||
This LLM is identical to the [OpenAI LLM](./openai), except that
|
||||
- all your requests will be logged to your PromptLayer account
|
||||
- you can add `pl_tags` when instantializing to tag your requests on PromptLayer
|
||||
- you can add `return_pl_id` when instantializing to return a PromptLayer request id to use [while tracking requests](https://magniv.notion.site/Track-4deee1b1f7a34c1680d085f82567dab9).
|
||||
|
||||
|
||||
PromptLayer also provides native wrappers for [`PromptLayerChatOpenAI`](../modules/chat/examples/promptlayer_chat_openai.ipynb)
|
||||
PromptLayer also provides native wrappers for [`PromptLayerChatOpenAI`](../modules/chat/examples/promptlayer_chat_openai.ipynb) and `PromptLayerOpenAIChat`
|
||||
|
||||
20
docs/ecosystem/qdrant.md
Normal file
20
docs/ecosystem/qdrant.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Qdrant
|
||||
|
||||
This page covers how to use the Qdrant ecosystem within LangChain.
|
||||
It is broken into two parts: installation and setup, and then references to specific Qdrant wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
- Install the Python SDK with `pip install qdrant-client`
|
||||
## Wrappers
|
||||
|
||||
### VectorStore
|
||||
|
||||
There exists a wrapper around Qdrant indexes, allowing you to use it as a vectorstore,
|
||||
whether for semantic search or example selection.
|
||||
|
||||
To import this vectorstore:
|
||||
```python
|
||||
from langchain.vectorstores import Qdrant
|
||||
```
|
||||
|
||||
For a more detailed walkthrough of the Qdrant wrapper, see [this notebook](../modules/indexes/vectorstore_examples/qdrant.ipynb)
|
||||
625
docs/ecosystem/wandb_tracking.ipynb
Normal file
625
docs/ecosystem/wandb_tracking.ipynb
Normal file
@@ -0,0 +1,625 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Weights & Biases\n",
|
||||
"\n",
|
||||
"This notebook goes over how to track your LangChain experiments into one centralized Weights and Biases dashboard. To learn more about prompt engineering and the callback please refer to this Report which explains both alongside the resultant dashboards you can expect to see.\n",
|
||||
"\n",
|
||||
"Run in Colab: https://colab.research.google.com/drive/1DXH4beT4HFaRKy_Vm4PoxhXVDRf7Ym8L?usp=sharing\n",
|
||||
"\n",
|
||||
"View Report: https://wandb.ai/a-sh0ts/langchain_callback_demo/reports/Prompt-Engineering-LLMs-with-LangChain-and-W-B--VmlldzozNjk1NTUw#👋-how-to-build-a-callback-in-langchain-for-better-prompt-engineering"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install wandb\n",
|
||||
"!pip install pandas\n",
|
||||
"!pip install textstat\n",
|
||||
"!pip install spacy\n",
|
||||
"!python -m spacy download en_core_web_sm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "T1bSmKd6V2If"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"WANDB_API_KEY\"] = \"\"\n",
|
||||
"# os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
|
||||
"# os.environ[\"SERPAPI_API_KEY\"] = \"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "8WAGnTWpUUnD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datetime import datetime\n",
|
||||
"from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.llms import OpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"Callback Handler that logs to Weights and Biases.\n",
|
||||
"\n",
|
||||
"Parameters:\n",
|
||||
" job_type (str): The type of job.\n",
|
||||
" project (str): The project to log to.\n",
|
||||
" entity (str): The entity to log to.\n",
|
||||
" tags (list): The tags to log.\n",
|
||||
" group (str): The group to log to.\n",
|
||||
" name (str): The name of the run.\n",
|
||||
" notes (str): The notes to log.\n",
|
||||
" visualize (bool): Whether to visualize the run.\n",
|
||||
" complexity_metrics (bool): Whether to log complexity metrics.\n",
|
||||
" stream_logs (bool): Whether to stream callback actions to W&B\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "cxBFfZR8d9FC"
|
||||
},
|
||||
"source": [
|
||||
"```\n",
|
||||
"Default values for WandbCallbackHandler(...)\n",
|
||||
"\n",
|
||||
"visualize: bool = False,\n",
|
||||
"complexity_metrics: bool = False,\n",
|
||||
"stream_logs: bool = False,\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"NOTE: For beta workflows we have made the default analysis based on textstat and the visualizations based on spacy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"id": "KAz8weWuUeXF"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mharrison-chase\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Tracking run with wandb version 0.14.0"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Run data is saved locally in <code>/Users/harrisonchase/workplace/langchain/docs/ecosystem/wandb/run-20230318_150408-e47j1914</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Syncing run <strong><a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/e47j1914' target=\"_blank\">llm</a></strong> to <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View project at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/e47j1914' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/e47j1914</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The wandb callback is currently in beta and is subject to change based on updates to `langchain`. Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\"\"\"Main function.\n",
|
||||
"\n",
|
||||
"This function is used to try the callback handler.\n",
|
||||
"Scenarios:\n",
|
||||
"1. OpenAI LLM\n",
|
||||
"2. Chain with multiple SubChains on multiple generations\n",
|
||||
"3. Agent with Tools\n",
|
||||
"\"\"\"\n",
|
||||
"session_group = datetime.now().strftime(\"%m.%d.%Y_%H.%M.%S\")\n",
|
||||
"wandb_callback = WandbCallbackHandler(\n",
|
||||
" job_type=\"inference\",\n",
|
||||
" project=\"langchain_callback_demo\",\n",
|
||||
" group=f\"minimal_{session_group}\",\n",
|
||||
" name=\"llm\",\n",
|
||||
" tags=[\"test\"],\n",
|
||||
")\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Q-65jwrDeK6w"
|
||||
},
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"# Defaults for WandbCallbackHandler.flush_tracker(...)\n",
|
||||
"\n",
|
||||
"reset: bool = True,\n",
|
||||
"finish: bool = False,\n",
|
||||
"```\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `flush_tracker` function is used to log LangChain sessions to Weights & Biases. It takes in the LangChain module or agent, and logs at minimum the prompts and generations alongside the serialized form of the LangChain module to the specified Weights & Biases project. By default we reset the session as opposed to concluding the session outright."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "o_VmneyIUyx8"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run <strong style=\"color:#cdcd00\">llm</strong> at: <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/e47j1914' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/e47j1914</a><br/>Synced 5 W&B file(s), 2 media file(s), 5 artifact file(s) and 0 other file(s)"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Find logs at: <code>./wandb/run-20230318_150408-e47j1914/logs</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0d7b4307ccdb450ea631497174fca2d1",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016745895149999985, max=1.0…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Tracking run with wandb version 0.14.0"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Run data is saved locally in <code>/Users/harrisonchase/workplace/langchain/docs/ecosystem/wandb/run-20230318_150534-jyxma7hu</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Syncing run <strong><a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/jyxma7hu' target=\"_blank\">simple_sequential</a></strong> to <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View project at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/jyxma7hu' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/jyxma7hu</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# SCENARIO 1 - LLM\n",
|
||||
"llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\"] * 3)\n",
|
||||
"wandb_callback.flush_tracker(llm, name=\"simple_sequential\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"id": "trxslyb1U28Y"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"id": "uauQk10SUzF6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run <strong style=\"color:#cdcd00\">simple_sequential</strong> at: <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/jyxma7hu' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/jyxma7hu</a><br/>Synced 4 W&B file(s), 2 media file(s), 6 artifact file(s) and 0 other file(s)"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Find logs at: <code>./wandb/run-20230318_150534-jyxma7hu/logs</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dbdbf28fb8ed40a3a60218d2e6d1a987",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016736786816666675, max=1.0…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Tracking run with wandb version 0.14.0"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Run data is saved locally in <code>/Users/harrisonchase/workplace/langchain/docs/ecosystem/wandb/run-20230318_150550-wzy59zjq</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Syncing run <strong><a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/wzy59zjq' target=\"_blank\">agent</a></strong> to <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View project at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run at <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/wzy59zjq' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/wzy59zjq</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# SCENARIO 2 - Chain\n",
|
||||
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
|
||||
"Title: {title}\n",
|
||||
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"test_prompts = [\n",
|
||||
" {\n",
|
||||
" \"title\": \"documentary about good video games that push the boundary of game design\"\n",
|
||||
" },\n",
|
||||
" {\"title\": \"cocaine bear vs heroin wolf\"},\n",
|
||||
" {\"title\": \"the best in class mlops tooling\"},\n",
|
||||
"]\n",
|
||||
"synopsis_chain.apply(test_prompts)\n",
|
||||
"wandb_callback.flush_tracker(synopsis_chain, name=\"agent\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"id": "_jN73xcPVEpI"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import initialize_agent, load_tools"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"id": "Gpq4rk6VT9cu"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mDiCaprio had a steady girlfriend in Camila Morrone. He had been with the model turned actress for nearly five years, as they were first said to be dating at the end of 2017. And the now 26-year-old Morrone is no stranger to Hollywood.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.43 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 26^0.43\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 4.059182145592686\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
|
||||
"Final Answer: Leo DiCaprio's girlfriend is Camila Morrone and her current age raised to the 0.43 power is 4.059182145592686.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
" View run <strong style=\"color:#cdcd00\">agent</strong> at: <a href='https://wandb.ai/harrison-chase/langchain_callback_demo/runs/wzy59zjq' target=\"_blank\">https://wandb.ai/harrison-chase/langchain_callback_demo/runs/wzy59zjq</a><br/>Synced 5 W&B file(s), 2 media file(s), 7 artifact file(s) and 0 other file(s)"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"Find logs at: <code>./wandb/run-20230318_150550-wzy59zjq/logs</code>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# SCENARIO 3 - Agent with Tools\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=\"zero-shot-react-description\",\n",
|
||||
" callback_manager=manager,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"agent.run(\n",
|
||||
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
|
||||
")\n",
|
||||
"wandb_callback.flush_tracker(agent, reset=False, finish=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
@@ -158,14 +158,14 @@ Open Source
|
||||
|
||||
---
|
||||
|
||||
.. link-button:: https://github.com/jerryjliu/gpt_index
|
||||
.. link-button:: https://github.com/jerryjliu/llama_index
|
||||
:type: url
|
||||
:text: GPT Index
|
||||
:text: LlamaIndex
|
||||
:classes: stretched-link btn-lg
|
||||
|
||||
+++
|
||||
|
||||
GPT Index is a project consisting of a set of data structures that are created using GPT-3 and can be traversed using GPT-3 in order to answer queries.
|
||||
LlamaIndex (formerly GPT Index) is a project consisting of a set of data structures that are created using GPT-3 and can be traversed using GPT-3 in order to answer queries.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -97,6 +97,8 @@ The above modules can be used in a variety of ways. LangChain also provides guid
|
||||
|
||||
- `Summarization <./use_cases/summarization.html>`_: Summarizing longer documents into shorter, more condensed chunks of information. A type of Data Augmented Generation.
|
||||
|
||||
- `Querying Tabular Data <./use_cases/tabular.html>`_: If you want to understand how to use LLMs to query data that is stored in a tabular format (csvs, SQL, dataframes, etc) you should read this page.
|
||||
|
||||
- `Evaluation <./use_cases/evaluation.html>`_: Generative models are notoriously hard to evaluate with traditional metrics. One new way of evaluating them is using language models themselves to do the evaluation. LangChain provides some prompts/chains for assisting in this.
|
||||
|
||||
- `Generate similar examples <./use_cases/generate_examples.html>`_: Generating similar examples to a given input. This is a common use case for many applications, and LangChain provides some prompts/chains for assisting in this.
|
||||
@@ -117,6 +119,8 @@ The above modules can be used in a variety of ways. LangChain also provides guid
|
||||
./use_cases/combine_docs.md
|
||||
./use_cases/question_answering.md
|
||||
./use_cases/summarization.md
|
||||
./use_cases/tabular.rst
|
||||
./use_cases/extraction.md
|
||||
./use_cases/evaluation.rst
|
||||
./use_cases/model_laboratory.ipynb
|
||||
|
||||
|
||||
132
docs/modules/agents/examples/human_tools.ipynb
Normal file
132
docs/modules/agents/examples/human_tools.ipynb
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Human as a tool\n",
|
||||
"\n",
|
||||
"Human are AGI so they can certainly be used as a tool to help out AI agent \n",
|
||||
"when it is confused."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.agents import load_tools, initialize_agent\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(temperature=0.0)\n",
|
||||
"math_llm = OpenAI(temperature=0.0)\n",
|
||||
"tools = load_tools(\n",
|
||||
" [\"human\", \"llm-math\"], \n",
|
||||
" llm=math_llm,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"agent_chain = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=\"zero-shot-react-description\",\n",
|
||||
" verbose=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In the above code you can see the tool takes input directly from command line.\n",
|
||||
"You can customize `prompt_func` and `input_func` according to your need."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI don't know Eric Zhu, so I should ask a human for guidance.\n",
|
||||
"Action: Human\n",
|
||||
"Action Input: \"Do you know when Eric Zhu's birthday is?\"\u001b[0m\n",
|
||||
"\n",
|
||||
"Do you know when Eric Zhu's birthday is?\n",
|
||||
"last week\n",
|
||||
"\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mlast week\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mThat's not very helpful. I should ask for more information.\n",
|
||||
"Action: Human\n",
|
||||
"Action Input: \"Do you know the specific date of Eric Zhu's birthday?\"\u001b[0m\n",
|
||||
"\n",
|
||||
"Do you know the specific date of Eric Zhu's birthday?\n",
|
||||
"august 1st\n",
|
||||
"\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3maugust 1st\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mNow that I have the date, I can check if it's a leap year or not.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: \"Is 2021 a leap year?\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mAnswer: False\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mI have all the information I need to answer the original question.\n",
|
||||
"Final Answer: Eric Zhu's birthday is on August 1st and it is not a leap year in 2021.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Eric Zhu's birthday is on August 1st and it is not a leap year in 2021.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"agent_chain.run(\"What is Eric Zhu's birthday?\")\n",
|
||||
"# Answer with \"last week\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -61,7 +61,8 @@
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Intermediate Answer\",\n",
|
||||
" func=search.run\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to ask with search\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
|
||||
@@ -24,11 +24,13 @@
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Search\",\n",
|
||||
" func=docstore.search\n",
|
||||
" func=docstore.search,\n",
|
||||
" description=\"useful for when you need to ask with search\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"Lookup\",\n",
|
||||
" func=docstore.lookup\n",
|
||||
" func=docstore.lookup,\n",
|
||||
" description=\"useful for when you need to ask with lookup\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
|
||||
@@ -52,7 +52,8 @@
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Intermediate Answer\",\n",
|
||||
" func=search.run\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to ask with search\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
|
||||
@@ -13,3 +13,4 @@ For more detailed information on tools, and different types of tools in LangChai
|
||||
Toolkits are groups of tools that are best used together.
|
||||
They allow you to logically group and initialize a set of tools that share a particular resource (such as a database connection or json object).
|
||||
They can be used to construct an agent for a specific use-case.
|
||||
For more detailed information on toolkits and their use cases, see [this documentation](how_to_guides.rst#agent-toolkits) (the "Agent Toolkits" section).
|
||||
@@ -145,3 +145,10 @@ Below is a list of all supported tools and relevant information:
|
||||
- Requires LLM: No
|
||||
- Extra Parameters: `top_k_results`
|
||||
|
||||
**podcast-api**
|
||||
|
||||
- Tool Name: Podcast API
|
||||
- Tool Description: Use the Listen Notes Podcast API to search all podcasts or episodes. The input should be a question in natural language that this API can answer.
|
||||
- Notes: A natural language connection to the Listen Notes Podcast API (`https://www.PodcastAPI.com`), specifically the `/search/` endpoint.
|
||||
- Requires LLM: Yes
|
||||
- Extra Parameters: `listen_api_key` (your api key to access this endpoint)
|
||||
|
||||
@@ -149,6 +149,33 @@
|
||||
"chain.run(\"Search for 'Avatar'\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Listen API Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.chains.api import podcast_docs\n",
|
||||
"from langchain.chains import APIChain\n",
|
||||
"\n",
|
||||
"# Get api key here: https://www.listennotes.com/api/pricing/\n",
|
||||
"listen_api_key = 'xxx'\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"headers = {\"X-ListenAPI-Key\": listen_api_key}\n",
|
||||
"chain = APIChain.from_llm_and_api_docs(llm, podcast_docs.PODCAST_DOCS, headers=headers, verbose=True)\n",
|
||||
"chain.run(\"Search for 'silicon valley bank' podcast episodes, audio length is more than 30 minutes, return only 1 results\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -173,7 +200,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -532,7 +532,7 @@
|
||||
"id": "5fc6f507",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note how our custom table definition and sample rows for `Track` overrides the `sample_rows_in_table_info` parameter. Tables that are not overriden by `custom_table_info`, in this example `Playlist`, will have their table info gathered automatically as usual."
|
||||
"Note how our custom table definition and sample rows for `Track` overrides the `sample_rows_in_table_info` parameter. Tables that are not overridden by `custom_table_info`, in this example `Playlist`, will have their table info gathered automatically as usual."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -123,6 +123,40 @@
|
||||
"id": "05e9e2fe",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "c43803d1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using PromptLayer Track\n",
|
||||
"If you would like to use any of the [PromptLayer tracking features](https://magniv.notion.site/Track-4deee1b1f7a34c1680d085f82567dab9), you need to pass the argument `return_pl_id` when instantializing the PromptLayer LLM to get the request id. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b7d4db01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat = PromptLayerChatOpenAI(return_pl_id=True)\n",
|
||||
"chat_results = chat.generate([[HumanMessage(content=\"I am a cat and I want\")]])\n",
|
||||
"\n",
|
||||
"for res in chat_results.generations:\n",
|
||||
" pl_request_id = res[0].generation_info[\"pl_request_id\"]\n",
|
||||
" promptlayer.track.score(request_id=pl_request_id, score=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "13e56507",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using this allows you to track the performance of your model in the PromptLayer dashboard. If you are using a prompt template, you can attach a template to a request as well.\n",
|
||||
"Overall, this gives you the opportunity to track the performance of different templates and models in the PromptLayer dashboard."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -141,11 +175,11 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
"version": "3.8.8 (default, Apr 13 2021, 12:59:45) \n[Clang 10.0.0 ]"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "c4fe2cd85a8d9e8baaec5340ce66faff1c77581a9f43e6c45e85e09b6fced008"
|
||||
"hash": "8a5edab282632443219e051e4ade2d1d5bbc671c781051bf1437897cbdfea0f1"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "522686de",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -36,7 +36,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "62e0dbc3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -56,7 +56,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "76a6e7b0-e927-4bfb-a414-1332a4149106",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -68,7 +68,7 @@
|
||||
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -87,7 +87,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -99,7 +99,7 @@
|
||||
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -122,7 +122,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "2b21fc52-74b6-4950-ab78-45d12c68fb4d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -131,10 +131,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output=None)"
|
||||
"LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output={'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -150,7 +150,39 @@
|
||||
" HumanMessage(content=\"Translate this sentence from English to French. I love artificial intelligence.\")\n",
|
||||
" ],\n",
|
||||
"]\n",
|
||||
"chat.generate(batch_messages)"
|
||||
"result = chat.generate(batch_messages)\n",
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2960f50f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can recover things like token usage from this LLMResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a6186bee",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'token_usage': {'prompt_tokens': 71,\n",
|
||||
" 'completion_tokens': 18,\n",
|
||||
" 'total_tokens': 89}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result.llm_output"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
38
docs/modules/document_loaders/examples/blackboard.ipynb
Normal file
38
docs/modules/document_loaders/examples/blackboard.ipynb
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Blackboard\n",
|
||||
"\n",
|
||||
"This covers how to load data from a Blackboard Learn instance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import BlackboardLoader\n",
|
||||
"\n",
|
||||
"loader = BlackboardLoader(\n",
|
||||
" blackboard_course_url=\"https://blackboard.example.com/webapps/blackboard/execute/announcement?method=search&context=course_entry&course_id=_123456_1\",\n",
|
||||
" bbrouter=\"expires:12345...\",\n",
|
||||
" load_all_recursively=True,\n",
|
||||
")\n",
|
||||
"documents = loader.load()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
79
docs/modules/document_loaders/examples/figma.ipynb
Normal file
79
docs/modules/document_loaders/examples/figma.ipynb
Normal file
@@ -0,0 +1,79 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "33205b12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Figma\n",
|
||||
"\n",
|
||||
"This notebook covers how to load data from the Figma REST API into a format that can be ingested into LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "90b69c94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain.document_loaders import FigmaFileLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "13deb0f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = FigmaFileLoader(\n",
|
||||
" os.environ.get('ACCESS_TOKEN'),\n",
|
||||
" os.environ.get('NODE_IDS'),\n",
|
||||
" os.environ.get('FILE_KEY')\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ccc1e2f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3e64cac2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "34c90eed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Microsoft Word\n",
|
||||
"\n",
|
||||
"This notebook shows how to load text from Microsoft word documents."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "28ded768",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import UnstructuredDocxLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "f1f26035",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = UnstructuredDocxLoader('example_data/fake.docx')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "2c87dde9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "0e4a884c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Lorem ipsum dolor sit amet.', lookup_str='', metadata={'source': 'example_data/fake.docx'}, lookup_index=0)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5d1472e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Retain Elements\n",
|
||||
"\n",
|
||||
"Under the hood, Unstructured creates different \"elements\" for different chunks of text. By default we combine those together, but you can easily keep that separation by specifying `mode=\"elements\"`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "93abf60b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = UnstructuredDocxLoader('example_data/fake.docx', mode=\"elements\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c35cdbcc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "fae2d730",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Lorem ipsum dolor sit amet.', lookup_str='', metadata={'source': 'example_data/fake.docx'}, lookup_index=0)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "961a7b1d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -27,7 +27,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = UnstructuredWordDocumentLoader(\"fake.docx\")"
|
||||
"loader = UnstructuredWordDocumentLoader(\"example_data/fake.docx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -78,7 +78,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = UnstructuredWordDocumentLoader(\"fake.docx\", mode=\"elements\")"
|
||||
"loader = UnstructuredWordDocumentLoader(\"example_data/fake.docx\", mode=\"elements\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -21,8 +21,6 @@ There are a lot of different document loaders that LangChain supports. Below are
|
||||
|
||||
`GoogleDrive <./examples/googledrive.html>`_: A walkthrough of how to load data from Google drive.
|
||||
|
||||
`Microsoft Word <./examples/microsoft_word.html>`_: A walkthrough of how to load data from Microsoft Word files.
|
||||
|
||||
`Obsidian <./examples/obsidian.html>`_: A walkthrough of how to load data from an Obsidian file dump.
|
||||
|
||||
`Roam <./examples/roam.html>`_: A walkthrough of how to load data from a Roam file export.
|
||||
@@ -59,6 +57,28 @@ There are a lot of different document loaders that LangChain supports. Below are
|
||||
|
||||
`iFixit <./examples/ifixit.html>`_: A walkthrough of how to search and load data like guides, technical Q&A's, and device wikis from iFixit.com
|
||||
|
||||
`Notebook <./examples/notebook.html>`_: A walkthrough of how to load data from .ipynb notebook.
|
||||
|
||||
`Copypaste <./examples/copypaste.html>`_: A walkthrough of how to load a document object from something you just want to copy and paste.
|
||||
|
||||
`CSV <./examples/csv.html>`_: A walkthrough of how to load data from a .csv file.
|
||||
|
||||
`Facebook Chat <./examples/facebook_chat.html>`_: A walkthrough of how to load data from a Facebook Chat json file.
|
||||
|
||||
`Image <./examples/image.html>`_: A walkthrough of how to load images such as JPGs PNGs into a document format that can be used downstream.
|
||||
|
||||
`Markdown <./examples/markdown.html>`_: A walkthrough of how to load data from a markdown file.
|
||||
|
||||
`SRT <./examples/srt.html>`_: A walkthrough of how to load data from a subtitle (`.srt`) file.
|
||||
|
||||
`Telegram <./examples/telegram.html>`_: A walkthrough of how to load data from a Telegram Chat json file.
|
||||
|
||||
`URL <./examples/url.html>`_: A walkthrough of how to load HTML documents from a list of URLs into a document format that we can use downstream.
|
||||
|
||||
`Word Document <./examples/word_document.html>`_: A walkthrough of how to load data from Microsoft Word files.
|
||||
|
||||
`Blackboard <./examples/blackboard.html>`_: A walkthrough of how to load data from a Blackboard course.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:glob:
|
||||
|
||||
@@ -4,7 +4,7 @@ Indexes
|
||||
Indexes refer to ways to structure documents so that LLMs can best interact with them.
|
||||
This module contains utility functions for working with documents, different types of indexes, and then examples for using those indexes in chains.
|
||||
LangChain provides common indices for working with data (most prominently support for vector databases).
|
||||
For more complicated index structures, it is worth checking out `GPTIndex <https://gpt-index.readthedocs.io/en/latest/index.html>`_.
|
||||
For more complicated index structures, it is worth checking out `LlamaIndex <https://gpt-index.readthedocs.io/en/latest/index.html>`_.
|
||||
|
||||
The following sections of documentation are provided:
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
"from langchain.docstore.document import Document\n",
|
||||
"import requests\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import Chromama\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"import pathlib\n",
|
||||
|
||||
@@ -76,6 +76,129 @@
|
||||
"doc_result = embeddings.embed_documents([text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bb61bbeb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's load the OpenAI Embedding class with first generation models (e.g. text-search-ada-doc-001/text-search-ada-query-001). Note: These are not recommended models - see [here](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0b072cc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a56b70f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings(model_name=\"ada\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "14aefb64",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This is a test document.\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3c39ed33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = embeddings.embed_query(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e3221db6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_result = embeddings.embed_documents([text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c3852491",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## AzureOpenAI\n",
|
||||
"\n",
|
||||
"Let's load the OpenAI Embedding class with environment variables set to indicate to use Azure endpoints."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b40f827",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set the environment variables needed for openai package to know to reach out to azure\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_TYPE\"] = \"azure\"\n",
|
||||
"os.environ[\"OPENAI_API_BASE\"] = \"https://<your-endpoint.openai.azure.com/\"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"your AzureOpenAI key\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bb36d16c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings(model=\"your-embeddings-deployment-name\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "228abcbb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This is a test document.\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "60dd7fad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = embeddings.embed_query(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "83bc1a72",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_result = embeddings.embed_documents([text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "42f76e43",
|
||||
@@ -86,6 +209,12 @@
|
||||
"Let's load the Cohere Embedding class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ca9e2b3a",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
@@ -103,7 +232,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = CohereEmbeddings(cohere_api_key= cohere_api_key)"
|
||||
"embeddings = CohereEmbeddings(cohere_api_key=cohere_api_key)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -290,7 +419,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embeddings = HuggingFaceInstructEmbeddings(query_instruction=\"Represent the query for retrieval: \")"
|
||||
"embeddings = HuggingFaceInstructEmbeddings(\n",
|
||||
" query_instruction=\"Represent the query for retrieval: \"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -332,9 +463,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import (\n",
|
||||
" SelfHostedEmbeddings, \n",
|
||||
" SelfHostedHuggingFaceEmbeddings, \n",
|
||||
" SelfHostedHuggingFaceInstructEmbeddings\n",
|
||||
" SelfHostedEmbeddings,\n",
|
||||
" SelfHostedHuggingFaceEmbeddings,\n",
|
||||
" SelfHostedHuggingFaceInstructEmbeddings,\n",
|
||||
")\n",
|
||||
"import runhouse as rh"
|
||||
]
|
||||
@@ -353,7 +484,7 @@
|
||||
"# gpu = rh.cluster(name='rh-a10x', instance_type='g5.2xlarge', provider='aws')\n",
|
||||
"\n",
|
||||
"# For an existing cluster\n",
|
||||
"# gpu = rh.cluster(ips=['<ip of the cluster>'], \n",
|
||||
"# gpu = rh.cluster(ips=['<ip of the cluster>'],\n",
|
||||
"# ssh_creds={'ssh_user': '...', 'ssh_private_key':'<path_to_key>'},\n",
|
||||
"# name='my-cluster')"
|
||||
]
|
||||
@@ -424,16 +555,22 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_pipeline():\n",
|
||||
" from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline # Must be inside the function in notebooks\n",
|
||||
" from transformers import (\n",
|
||||
" AutoModelForCausalLM,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" pipeline,\n",
|
||||
" ) # Must be inside the function in notebooks\n",
|
||||
"\n",
|
||||
" model_id = \"facebook/bart-base\"\n",
|
||||
" tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
||||
" model = AutoModelForCausalLM.from_pretrained(model_id)\n",
|
||||
" return pipeline(\"feature-extraction\", model=model, tokenizer=tokenizer)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def inference_fn(pipeline, prompt):\n",
|
||||
" # Return last hidden state of the model\n",
|
||||
" if isinstance(prompt, list):\n",
|
||||
" return [emb[0][-1] for emb in pipeline(prompt)] \n",
|
||||
" return [emb[0][-1] for emb in pipeline(prompt)]\n",
|
||||
" return pipeline(prompt)[0][-1]"
|
||||
]
|
||||
},
|
||||
@@ -445,10 +582,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = SelfHostedEmbeddings(\n",
|
||||
" model_load_fn=get_pipeline, \n",
|
||||
" model_load_fn=get_pipeline,\n",
|
||||
" hardware=gpu,\n",
|
||||
" model_reqs=[\"./\", \"torch\", \"transformers\"],\n",
|
||||
" inference_fn=inference_fn\n",
|
||||
" inference_fn=inference_fn,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -514,12 +651,101 @@
|
||||
"doc_results = embeddings.embed_documents([\"foo\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f83f273",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## SageMaker Endpoint Embeddings\n",
|
||||
"\n",
|
||||
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
|
||||
"\n",
|
||||
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88d366bd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip3 install langchain boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "1e9b926a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict\n",
|
||||
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ContentHandler(ContentHandlerBase):\n",
|
||||
" content_type = \"application/json\"\n",
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
|
||||
" return input_str.encode('utf-8')\n",
|
||||
" \n",
|
||||
" def transform_output(self, output: bytes) -> str:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[\"embeddings\"]\n",
|
||||
"\n",
|
||||
"content_handler = ContentHandler()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"embeddings = SagemakerEndpointEmbeddings(\n",
|
||||
" # endpoint_name=\"endpoint-name\", \n",
|
||||
" # credentials_profile_name=\"credentials-profile-name\", \n",
|
||||
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n",
|
||||
" region_name=\"us-east-1\", \n",
|
||||
" content_handler=content_handler\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fe9797b8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = embeddings.embed_query(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "76f1b752",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_results = embeddings.embed_documents([\"foo\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fff99b21",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aaad49f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
@@ -543,7 +769,7 @@
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "ce6f9b0d7cdac41515b0e0c38d0e6e153a2edce81d579281cb1ab99da6e8ea6d"
|
||||
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -176,6 +176,77 @@
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a2f572e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Latex Text Splitter\n",
|
||||
"\n",
|
||||
"LatexTextSplitter splits text along Latex headings, headlines, enumerations and more. It's implemented as a simple subclass of RecursiveCharacterSplitter with Latex-specific separators. See the source code to see the Latex syntax expected by default.\n",
|
||||
"\n",
|
||||
"1. How the text is split: by list of latex specific tags\n",
|
||||
"2. How the chunk size is measured: by length function passed in (defaults to number of characters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c2503917",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import LatexTextSplitter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e46b753b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"latex_text = \"\"\"\n",
|
||||
"\\documentclass{article}\n",
|
||||
"\n",
|
||||
"\\begin{document}\n",
|
||||
"\n",
|
||||
"\\maketitle\n",
|
||||
"\n",
|
||||
"\\section{Introduction}\n",
|
||||
"Large language models (LLMs) are a type of machine learning model that can be trained on vast amounts of text data to generate human-like language. In recent years, LLMs have made significant advances in a variety of natural language processing tasks, including language translation, text generation, and sentiment analysis.\n",
|
||||
"\n",
|
||||
"\\subsection{History of LLMs}\n",
|
||||
"The earliest LLMs were developed in the 1980s and 1990s, but they were limited by the amount of data that could be processed and the computational power available at the time. In the past decade, however, advances in hardware and software have made it possible to train LLMs on massive datasets, leading to significant improvements in performance.\n",
|
||||
"\n",
|
||||
"\\subsection{Applications of LLMs}\n",
|
||||
"LLMs have many applications in industry, including chatbots, content creation, and virtual assistants. They can also be used in academia for research in linguistics, psychology, and computational linguistics.\n",
|
||||
"\n",
|
||||
"\\end{document}\n",
|
||||
"\"\"\"\n",
|
||||
"latex_splitter = LatexTextSplitter(chunk_size=400, chunk_overlap=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "73b5bd33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = latex_splitter.create_documents([latex_text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1c7fbd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c350765d",
|
||||
|
||||
@@ -7,6 +7,12 @@
|
||||
"source": [
|
||||
"# Getting Started\n",
|
||||
"\n",
|
||||
"By default, LangChain uses [Chroma](../../ecosystem/chroma.md) as the vectorstore to index and search embeddings. To walk through this tutorial, we'll first need to install `chromadb`.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"pip install chromadb\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This example showcases question answering over documents.\n",
|
||||
"We have chosen this as the example for getting started because it nicely combines a lot of different elements (Text splitters, embeddings, vectorstores) and then also shows how to use them in a chain.\n",
|
||||
"\n",
|
||||
|
||||
@@ -46,6 +46,8 @@ In the below guides, we cover different types of vectorstores and how to use the
|
||||
|
||||
`Milvus <./vectorstore_examples/milvus.html>`_: A walkthrough of how to use the Milvus vectorstore wrapper.
|
||||
|
||||
`Open Search <./vectorstore_examples/opensearch.html>`_: A walkthrough of how to use the OpenSearch wrapper.
|
||||
|
||||
`Pinecone <./vectorstore_examples/pinecone.html>`_: A walkthrough of how to use the Pinecone vectorstore wrapper.
|
||||
|
||||
`Qdrant <./vectorstore_examples/qdrant.html>`_: A walkthrough of how to use the Qdrant vectorstore wrapper.
|
||||
@@ -96,4 +98,4 @@ The examples here are all end-to-end chains that use indexes or utils covered ab
|
||||
:name: chains
|
||||
:hidden:
|
||||
|
||||
./chain_examples/*
|
||||
./chain_examples/*
|
||||
|
||||
@@ -62,10 +62,6 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
|
||||
"\n",
|
||||
"We cannot let this happen. \n",
|
||||
"\n",
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
@@ -200,10 +196,104 @@
|
||||
"docs[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57da60d4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Merging\n",
|
||||
"You can also merge two FAISS vectorstores"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6dfd2b78",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db1 = FAISS.from_texts([\"foo\"], embeddings)\n",
|
||||
"db2 = FAISS.from_texts([\"bar\"], embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "29960da7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0)}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db1.docstore._dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "83392605",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'bdc50ae3-a1bb-4678-9260-1b0979578f40': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db2.docstore._dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "a3fcc1c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db1.merge_from(db2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "41c51f89",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0),\n",
|
||||
" 'd5211050-c777-493d-8825-4800e74cfdb6': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db1.docstore._dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bc8b71f7",
|
||||
"id": "f80b60de",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
||||
@@ -8,10 +8,7 @@
|
||||
"This notebook shows how to use functionality related to the Redis database."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -24,10 +21,7 @@
|
||||
"from langchain.vectorstores.redis import Redis"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -44,10 +38,7 @@
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -55,13 +46,10 @@
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rds = Redis.from_documents(docs, embeddings,redis_url=\"redis://localhost:6379\")"
|
||||
"rds = Redis.from_documents(docs, embeddings, redis_url=\"redis://localhost:6379\", index_name='link')"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -81,10 +69,14 @@
|
||||
"rds.index_name"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -115,10 +107,7 @@
|
||||
"print(results[0].page_content)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -137,10 +126,7 @@
|
||||
"print(rds.add_texts([\"Ankush went to Princeton\"]))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -161,22 +147,23 @@
|
||||
"print(results[0].page_content)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"source": [
|
||||
"#Query\n",
|
||||
"rds = Redis.from_existing_index(embeddings, redis_url=\"redis://localhost:6379\", index_name='link')\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"results = rds.similarity_search(query)\n",
|
||||
"print(results[0].page_content)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -201,4 +188,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,10 +119,39 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "05e9e2fe",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
"source": [
|
||||
"## Using PromptLayer Track\n",
|
||||
"If you would like to use any of the [PromptLayer tracking features](https://magniv.notion.site/Track-4deee1b1f7a34c1680d085f82567dab9), you need to pass the argument `return_pl_id` when instantializing the PromptLayer LLM to get the request id. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a7315b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = PromptLayerOpenAI(return_pl_id=True)\n",
|
||||
"llm_results = llm.generate([\"Tell me a joke\"])\n",
|
||||
"\n",
|
||||
"for res in llm_results.generations:\n",
|
||||
" pl_request_id = res[0].generation_info[\"pl_request_id\"]\n",
|
||||
" promptlayer.track.score(request_id=pl_request_id, score=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "7eb19139",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using this allows you to track the performance of your model in the PromptLayer dashboard. If you are using a prompt template, you can attach a template to a request as well.\n",
|
||||
"Overall, this gives you the opportunity to track the performance of different templates and models in the PromptLayer dashboard."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -145,7 +174,7 @@
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "c4fe2cd85a8d9e8baaec5340ce66faff1c77581a9f43e6c45e85e09b6fced008"
|
||||
"hash": "8a5edab282632443219e051e4ade2d1d5bbc671c781051bf1437897cbdfea0f1"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
131
docs/modules/llms/integrations/sagemaker.ipynb
Normal file
131
docs/modules/llms/integrations/sagemaker.ipynb
Normal file
@@ -0,0 +1,131 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SageMakerEndpoint\n",
|
||||
"\n",
|
||||
"This notebooks goes over how to use an LLM hosted on a SageMaker endpoint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip3 install langchain boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.docstore.document import Document"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"example_doc_1 = \"\"\"\n",
|
||||
"Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
|
||||
"Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
|
||||
"Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=example_doc_1,\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict\n",
|
||||
"\n",
|
||||
"from langchain import PromptTemplate, SagemakerEndpoint\n",
|
||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"query = \"\"\"How long was Elizabeth hospitalized?\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
|
||||
"\n",
|
||||
"{context}\n",
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"Answer:\"\"\"\n",
|
||||
"PROMPT = PromptTemplate(\n",
|
||||
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"class ContentHandler(ContentHandlerBase):\n",
|
||||
" content_type = \"application/json\"\n",
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({prompt: prompt, **model_kwargs})\n",
|
||||
" return input_str.encode('utf-8')\n",
|
||||
" \n",
|
||||
" def transform_output(self, output: bytes) -> str:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[0][\"generated_text\"]\n",
|
||||
"\n",
|
||||
"content_handler = ContentHandler()\n",
|
||||
"\n",
|
||||
"chain = load_qa_chain(\n",
|
||||
" llm=SagemakerEndpoint(\n",
|
||||
" endpoint_name=\"endpoint-name\", \n",
|
||||
" credentials_profile_name=\"credentials-profile-name\", \n",
|
||||
" region_name=\"us-west-2\", \n",
|
||||
" model_kwargs={\"temperature\":1e-10},\n",
|
||||
" content_handler=content_handler\n",
|
||||
" ),\n",
|
||||
" prompt=PROMPT\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
288
docs/modules/memory/types/token_buffer.ipynb
Normal file
288
docs/modules/memory/types/token_buffer.ipynb
Normal file
@@ -0,0 +1,288 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff4be5f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ConversationTokenBufferMemory\n",
|
||||
"\n",
|
||||
"`ConversationTokenBufferMemory` keeps a buffer of recent interactions in memory, and uses token length rather than number of interactions to determine when to flush interactions.\n",
|
||||
"\n",
|
||||
"Let's first walk through how to use the utilities"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "da3384db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import ConversationTokenBufferMemory\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"llm = OpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "e00d4938",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=10)\n",
|
||||
"memory.save_context({\"input\": \"hi\"}, {\"ouput\": \"whats up\"})\n",
|
||||
"memory.save_context({\"input\": \"not much you\"}, {\"ouput\": \"not much\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "2fe28a28",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'history': 'Human: not much you\\nAI: not much'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"memory.load_memory_variables({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf57b97a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3422a3a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=10, return_messages=True)\n",
|
||||
"memory.save_context({\"input\": \"hi\"}, {\"ouput\": \"whats up\"})\n",
|
||||
"memory.save_context({\"input\": \"not much you\"}, {\"ouput\": \"not much\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a6d2569f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using in a chain\n",
|
||||
"Let's walk through an example, again setting `verbose=True` so we can see the prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "ebd68c10",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"\n",
|
||||
"Human: Hi, what's up?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" Hi there! I'm doing great, just enjoying the day. How about you?\""
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import ConversationChain\n",
|
||||
"conversation_with_summary = ConversationChain(\n",
|
||||
" llm=llm, \n",
|
||||
" # We set a very low max_token_limit for the purposes of testing.\n",
|
||||
" memory=ConversationTokenBufferMemory(llm=OpenAI(), max_token_limit=60),\n",
|
||||
" verbose=True\n",
|
||||
")\n",
|
||||
"conversation_with_summary.predict(input=\"Hi, what's up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "86207a61",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"Human: Hi, what's up?\n",
|
||||
"AI: Hi there! I'm doing great, just enjoying the day. How about you?\n",
|
||||
"Human: Just working on writing some documentation!\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Sounds like a productive day! What kind of documentation are you writing?'"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation_with_summary.predict(input=\"Just working on writing some documentation!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "76a0ab39",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"Human: Hi, what's up?\n",
|
||||
"AI: Hi there! I'm doing great, just enjoying the day. How about you?\n",
|
||||
"Human: Just working on writing some documentation!\n",
|
||||
"AI: Sounds like a productive day! What kind of documentation are you writing?\n",
|
||||
"Human: For LangChain! Have you heard of it?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" Yes, I have heard of LangChain! It is a decentralized language-learning platform that connects native speakers and learners in real time. Is that the documentation you're writing about?\""
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"conversation_with_summary.predict(input=\"For LangChain! Have you heard of it?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "8c669db1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"Human: For LangChain! Have you heard of it?\n",
|
||||
"AI: Yes, I have heard of LangChain! It is a decentralized language-learning platform that connects native speakers and learners in real time. Is that the documentation you're writing about?\n",
|
||||
"Human: Haha nope, although a lot of people confuse it for that\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" Oh, I see. Is there another language learning platform you're referring to?\""
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# We can see here that the buffer is updated\n",
|
||||
"conversation_with_summary.predict(input=\"Haha nope, although a lot of people confuse it for that\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8c09a239",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -21,16 +21,17 @@
|
||||
"id": "5d56ce86",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a custom prompt template\n",
|
||||
"## Creating a Custom Prompt Template\n",
|
||||
"\n",
|
||||
"The only two requirements for all prompt templates are:\n",
|
||||
"There are essentially two distinct prompt templates available - string prompt templates and chat prompt templates. String prompt templates provides a simple prompt in string format, while chat prompt templates produces a more structured prompt to be used with a chat API.\n",
|
||||
"\n",
|
||||
"1. They have a input_variables attribute that exposes what input variables this prompt template expects.\n",
|
||||
"2. They expose a format method which takes in keyword arguments corresponding to the expected input_variables and returns the formatted prompt.\n",
|
||||
"In this guide, we will create a custom prompt using a string prompt template. \n",
|
||||
"\n",
|
||||
"Let's create a custom prompt template that takes in the function name as input, and formats the prompt template to provide the source code of the function.\n",
|
||||
"To create a custom string prompt template, there are two requirements:\n",
|
||||
"1. It has an input_variables attribute that exposes what input variables the prompt template expects.\n",
|
||||
"2. It exposes a format method that takes in keyword arguments corresponding to the expected input_variables and returns the formatted prompt.\n",
|
||||
"\n",
|
||||
"First, let's create a function that will return the source code of a function given its name."
|
||||
"We will create a custom prompt template that takes in the function name as input and formats the prompt to provide the source code of the function. To achieve this, let's first create a function that will return the source code of a function given its name."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -62,11 +63,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import BasePromptTemplate\n",
|
||||
"from langchain.prompts import StringPromptTemplate\n",
|
||||
"from pydantic import BaseModel, validator\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class FunctionExplainerPromptTemplate(BasePromptTemplate, BaseModel):\n",
|
||||
"class FunctionExplainerPromptTemplate(StringPromptTemplate, BaseModel):\n",
|
||||
" \"\"\" A custom prompt template that takes in the function name as input, and formats the prompt template to provide the source code of the function. \"\"\"\n",
|
||||
"\n",
|
||||
" @validator(\"input_variables\")\n",
|
||||
|
||||
@@ -14,9 +14,459 @@
|
||||
"- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n",
|
||||
"- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n",
|
||||
"\n",
|
||||
"And then one optional one:\n",
|
||||
"\n",
|
||||
"- `parse_with_prompt(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and a prompt (assumed to the prompt that generated such a response) and parses it into some structure. The prompt is largely provided in the event the OutputParser wants to retry or fix the output in some way, and needs information from the prompt to do so.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Below we go over some examples of output parsers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "5f0c8a33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a1ae632a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## PydanticOutputParser\n",
|
||||
"This output parser allows users to specify an arbitrary JSON schema and query LLMs for JSON outputs that conform to that schema.\n",
|
||||
"\n",
|
||||
"Keep in mind that large language models are leaky abstractions! You'll have to use an LLM with sufficient capacity to generate well-formed JSON. In the OpenAI family, DaVinci can do reliably but Curie's ability already drops off dramatically. \n",
|
||||
"\n",
|
||||
"Use Pydantic to declare your data model. Pydantic's BaseModel like a Python dataclass, but with actual type checking + coercion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "cba6d8e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers import PydanticOutputParser\n",
|
||||
"from pydantic import BaseModel, Field, validator\n",
|
||||
"from typing import List"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "0a203100",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name = 'text-davinci-003'\n",
|
||||
"temperature = 0.0\n",
|
||||
"model = OpenAI(model_name=model_name, temperature=temperature)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b3f16168",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Define your desired data structure.\n",
|
||||
"class Joke(BaseModel):\n",
|
||||
" setup: str = Field(description=\"question to set up a joke\")\n",
|
||||
" punchline: str = Field(description=\"answer to resolve the joke\")\n",
|
||||
" \n",
|
||||
" # You can add custom validation logic easily with Pydantic.\n",
|
||||
" @validator('setup')\n",
|
||||
" def question_ends_with_question_mark(cls, field):\n",
|
||||
" if field[-1] != '?':\n",
|
||||
" raise ValueError(\"Badly formed question!\")\n",
|
||||
" return field\n",
|
||||
"\n",
|
||||
"# And a query intented to prompt a language model to populate the data structure.\n",
|
||||
"joke_query = \"Tell me a joke.\"\n",
|
||||
"\n",
|
||||
"# Set up a parser + inject instructions into the prompt template.\n",
|
||||
"parser = PydanticOutputParser(pydantic_object=Joke)\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
|
||||
" input_variables=[\"query\"],\n",
|
||||
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"_input = prompt.format_prompt(query=joke_query)\n",
|
||||
"\n",
|
||||
"output = model(_input.to_string())\n",
|
||||
"\n",
|
||||
"parser.parse(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "03049f88",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Here's another example, but with a compound typed field.\n",
|
||||
"class Actor(BaseModel):\n",
|
||||
" name: str = Field(description=\"name of an actor\")\n",
|
||||
" film_names: List[str] = Field(description=\"list of names of films they starred in\")\n",
|
||||
" \n",
|
||||
"actor_query = \"Generate the filmography for a random actor.\"\n",
|
||||
"\n",
|
||||
"parser = PydanticOutputParser(pydantic_object=Actor)\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
|
||||
" input_variables=[\"query\"],\n",
|
||||
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"_input = prompt.format_prompt(query=actor_query)\n",
|
||||
"\n",
|
||||
"output = model(_input.to_string())\n",
|
||||
"\n",
|
||||
"parser.parse(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d6c0c86",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fixing Output Parsing Mistakes\n",
|
||||
"\n",
|
||||
"The above guardrail simply tries to parse the LLM response. If it does not parse correctly, then it errors.\n",
|
||||
"\n",
|
||||
"But we can do other things besides throw errors. Specifically, we can pass the misformatted output, along with the formatted instructions, to the model and ask it to fix it.\n",
|
||||
"\n",
|
||||
"For this example, we'll use the above OutputParser. Here's what happens if we pass it a result that does not comply with the schema:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "73beb20d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"misformatted = \"{'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "f0e5ba80",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "OutputParserException",
|
||||
"evalue": "Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:23\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m json_str \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup()\n\u001b[0;32m---> 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39mparse_obj(json_object)\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
||||
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmisformatted\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser.parse(misformatted)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6c7c82b6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can construct and use a `OutputFixingParser`. This output parser takes as an argument another output parser but also an LLM with which to try to correct any formatting mistakes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "39b1a5ce",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers import OutputFixingParser\n",
|
||||
"\n",
|
||||
"new_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "0fd96d68",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Actor(name='Tom Hanks', film_names=['Forrest Gump'])"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_parser.parse(misformatted)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea34eeaa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fixing Output Parsing Mistakes with the original prompt\n",
|
||||
"\n",
|
||||
"While in some cases it is possible to fix any parsing mistakes by only looking at the output, in other cases it can't. An example of this is when the output is not just in the incorrect format, but is partially complete. Consider the below example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "67c5e1ac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Based on the user question, provide an Action and Action Input for what step should be taken.\n",
|
||||
"{format_instructions}\n",
|
||||
"Question: {query}\n",
|
||||
"Response:\"\"\"\n",
|
||||
"class Action(BaseModel):\n",
|
||||
" action: str = Field(description=\"action to take\")\n",
|
||||
" action_input: str = Field(description=\"input to the action\")\n",
|
||||
" \n",
|
||||
"parser = PydanticOutputParser(pydantic_object=Action)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "007aa87f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = PromptTemplate(\n",
|
||||
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
|
||||
" input_variables=[\"query\"],\n",
|
||||
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "10d207ff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt_value = prompt.format_prompt(query=\"who is leo di caprios gf?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "68622837",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"bad_response = '{\"action\": \"search\"}'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25631465",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we try to parse this response as is, we will get an error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "894967c1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "OutputParserException",
|
||||
"evalue": "Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:24\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(json_str)\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydantic_object\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_object\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (json\u001b[38;5;241m.\u001b[39mJSONDecodeError, ValidationError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:527\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:342\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbad_response\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser.parse(bad_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f6b64696",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we try to use the `OutputFixingParser` to fix this error, it will be confused - namely, it doesn't know what to actually put for action input."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "78b2b40d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fix_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "4fe1301d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Action(action='search', action_input='keyword')"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fix_parser.parse(bad_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9bd9ea7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Instead, we can use the RetryOutputParser, which passes in the prompt (as well as the original output) to try again to get a better response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "7e8a8a28",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers import RetryWithErrorOutputParser"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "5c86e141",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "9c04f731",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Action(action='search', action_input='leo di caprios girlfriend')"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retry_parser.parse_with_prompt(bad_response, prompt_value)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "61f67890",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "64bf525a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Older, less powerful parsers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91871002",
|
||||
@@ -24,12 +474,12 @@
|
||||
"source": [
|
||||
"## Structured Output Parser\n",
|
||||
"\n",
|
||||
"This output parser can be used when you want to return multiple fields."
|
||||
"While the Pydantic/JSON parser is more powerful, we initially experimented data structures having text fields only."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 16,
|
||||
"id": "b492997a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -37,18 +487,6 @@
|
||||
"from langchain.output_parsers import StructuredOutputParser, ResponseSchema"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "ffb7fc57",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "09473dce",
|
||||
@@ -59,7 +497,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 17,
|
||||
"id": "432ac44a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -81,7 +519,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 18,
|
||||
"id": "593cfc25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -104,7 +542,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 19,
|
||||
"id": "106f1ba6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -114,7 +552,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 20,
|
||||
"id": "86d9d24f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -125,7 +563,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 21,
|
||||
"id": "956bdc99",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -135,7 +573,7 @@
|
||||
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -154,7 +592,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 22,
|
||||
"id": "8f483d7d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -164,7 +602,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 23,
|
||||
"id": "f761cbf1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -180,7 +618,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 24,
|
||||
"id": "edd73ae3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -191,7 +629,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 25,
|
||||
"id": "a3c8b91e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -201,7 +639,7 @@
|
||||
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -217,12 +655,12 @@
|
||||
"source": [
|
||||
"## CommaSeparatedListOutputParser\n",
|
||||
"\n",
|
||||
"This output parser can be used to get a list of items as output."
|
||||
"Here's another parser strictly less powerful than Pydantic/JSON parsing."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 26,
|
||||
"id": "872246d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -232,7 +670,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 27,
|
||||
"id": "c3f9aee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -242,7 +680,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 28,
|
||||
"id": "e77871b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -257,7 +695,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 29,
|
||||
"id": "a71cb5d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -267,7 +705,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 30,
|
||||
"id": "783d7d98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -278,7 +716,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 31,
|
||||
"id": "fcb81344",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -292,7 +730,7 @@
|
||||
" 'Cookies and Cream']"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -300,14 +738,6 @@
|
||||
"source": [
|
||||
"output_parser.parse(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cba6d8e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -120,6 +120,25 @@
|
||||
"!cat simple_prompt.json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de75e959",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = load_prompt(\"simple_prompt.json\")\n",
|
||||
"print(prompt.format(adjective=\"funny\", content=\"chickens\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d1d788f9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Tell me a funny joke about chickens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d788a83c",
|
||||
|
||||
@@ -121,7 +121,8 @@
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Intermediate Answer\",\n",
|
||||
" func=search.run\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to ask with search\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"source": [
|
||||
"## Zapier Natural Language Actions API\n",
|
||||
"\\\n",
|
||||
"Full docs here: https://nla.zapier.com/api/v1/dynamic/docs\n",
|
||||
"Full docs here: https://nla.zapier.com/api/v1/docs\n",
|
||||
"\n",
|
||||
"**Zapier Natural Language Actions** gives you access to the 5k+ apps, 20k+ actions on Zapier's platform through a natural language API interface.\n",
|
||||
"\n",
|
||||
@@ -21,7 +21,7 @@
|
||||
"\n",
|
||||
"2. User-facing (Oauth): for production scenarios where you are deploying an end-user facing application and LangChain needs access to end-user's exposed actions and connected accounts on Zapier.com\n",
|
||||
"\n",
|
||||
"This quick start will focus on the server-side use case for brevity. Review [full docs](https://nla.zapier.com/api/v1/dynamic/docs) or reach out to nla@zapier.com for user-facing oauth developer support.\n",
|
||||
"This quick start will focus on the server-side use case for brevity. Review [full docs](https://nla.zapier.com/api/v1/docs) or reach out to nla@zapier.com for user-facing oauth developer support.\n",
|
||||
"\n",
|
||||
"This example goes over how to use the Zapier integration with a `SimpleSequentialChain`, then an `Agent`.\n",
|
||||
"In code, below:"
|
||||
|
||||
@@ -76,7 +76,7 @@ Examples of vector database companies include [Pinecone](https://www.pinecone.io
|
||||
|
||||
Although this is perhaps the most common way of document retrieval, people are starting to think about alternative
|
||||
data structures and indexing techniques specifically for working with language models. For a leading example of this,
|
||||
check out [GPT Index](https://github.com/jerryjliu/gpt_index) - a collection of data structures created by and optimized
|
||||
check out [LlamaIndex](https://github.com/jerryjliu/llama_index) - a collection of data structures created by and optimized
|
||||
for language models.
|
||||
|
||||
## Augmenting
|
||||
|
||||
306
docs/use_cases/evaluation/llm_math.ipynb
Normal file
306
docs/use_cases/evaluation/llm_math.ipynb
Normal file
@@ -0,0 +1,306 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a4734146",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LLM Math\n",
|
||||
"\n",
|
||||
"Evaluating chains that know how to do math."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "fdd7afae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Comment this out if you are NOT using tracing\n",
|
||||
"import os\n",
|
||||
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ce05ffea",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d028a511cede4de2b845b9a9954d6bea",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Downloading readme: 0%| | 0.00/21.0 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Downloading and preparing dataset json/LangChainDatasets--llm-math to /Users/harrisonchase/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--llm-math-509b11d101165afa/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a71c8e5a21dd4da5a20a354b544f7a58",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ae530ca624154a1a934075c47d1093a6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Downloading data: 0%| | 0.00/631 [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7a4968df05d84bc483aa2c5039aecafe",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0 examples [00:00, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset json downloaded and prepared to /Users/harrisonchase/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--llm-math-509b11d101165afa/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "9a2caed96225410fb1cc0f8f155eb766",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.evaluation.loading import load_dataset\n",
|
||||
"dataset = load_dataset(\"llm-math\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8a998d6f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setting up a chain\n",
|
||||
"Now we need to create some pipelines for doing math."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "7078f7f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.chains import LLMMathChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "2bd70c46",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "954c3270",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = LLMMathChain(llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "f252027e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictions = chain.apply(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "c8af7041",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"numeric_output = [float(p['answer'].strip().strip(\"Answer: \")) for p in predictions]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "cc09ffe4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"correct = [example['answer'] == numeric_output[i] for i, example in enumerate(dataset)]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "585244e4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sum(correct) / len(correct)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "0d14ac78",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"input: 5\n",
|
||||
"expected output : 5.0\n",
|
||||
"prediction: 5.0\n",
|
||||
"input: 5 + 3\n",
|
||||
"expected output : 8.0\n",
|
||||
"prediction: 8.0\n",
|
||||
"input: 2^3.171\n",
|
||||
"expected output : 9.006708689094099\n",
|
||||
"prediction: 9.006708689094099\n",
|
||||
"input: 2 ^3.171 \n",
|
||||
"expected output : 9.006708689094099\n",
|
||||
"prediction: 9.006708689094099\n",
|
||||
"input: two to the power of three point one hundred seventy one\n",
|
||||
"expected output : 9.006708689094099\n",
|
||||
"prediction: 9.006708689094099\n",
|
||||
"input: five + three squared minus 1\n",
|
||||
"expected output : 13.0\n",
|
||||
"prediction: 13.0\n",
|
||||
"input: 2097 times 27.31\n",
|
||||
"expected output : 57269.07\n",
|
||||
"prediction: 57269.07\n",
|
||||
"input: two thousand ninety seven times twenty seven point thirty one\n",
|
||||
"expected output : 57269.07\n",
|
||||
"prediction: 57269.07\n",
|
||||
"input: 209758 / 2714\n",
|
||||
"expected output : 77.28739867354459\n",
|
||||
"prediction: 77.28739867354459\n",
|
||||
"input: 209758.857 divided by 2714.31\n",
|
||||
"expected output : 77.27888745205964\n",
|
||||
"prediction: 77.27888745205964\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i, example in enumerate(dataset):\n",
|
||||
" print(\"input: \", example[\"question\"])\n",
|
||||
" print(\"expected output :\", example[\"answer\"])\n",
|
||||
" print(\"prediction: \", numeric_output[i])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b9021ffd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
20
docs/use_cases/extraction.md
Normal file
20
docs/use_cases/extraction.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Extraction
|
||||
|
||||
Most APIs and databases still deal with structured information.
|
||||
Therefore, in order to better work with those, it can be useful to extract structured information from text.
|
||||
Examples of this include:
|
||||
|
||||
- Extracting a structured row to insert into a database from a sentence
|
||||
- Extracting multiple rows to insert into a database from a long document
|
||||
- Extracting the correct API parameters from a user query
|
||||
|
||||
This work is extremely related to [output parsing](../modules/prompts/examples/output_parsers.ipynb).
|
||||
Output parsers are responsible for instructing the LLM to respond in a specific format.
|
||||
In this case, the output parsers specify the format of the data you would like to extract from the document.
|
||||
Then, in addition to the output format instructions, the prompt should also contain the data you would like to extract information from.
|
||||
|
||||
While normal output parsers are good enough for basic structuring of response data,
|
||||
when doing extraction you often want to extract more complicated or nested structures.
|
||||
For a deep dive on extraction, we recommend checking out [`kor`](https://eyurtsev.github.io/kor/),
|
||||
a library that uses the existing LangChain chain and OutputParser abstractions
|
||||
but deep dives on allowing extraction of more complicated schemas.
|
||||
31
docs/use_cases/tabular.md
Normal file
31
docs/use_cases/tabular.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# Querying Tabular Data
|
||||
|
||||
Lots of data and information is stored in tabular data, whether it be csvs, excel sheets, or SQL tables.
|
||||
This page covers all resources available in LangChain for working with data in this format.
|
||||
|
||||
## Document Loading
|
||||
If you have text data stored in a tabular format, you may want to load the data into a Document and then index it as you would
|
||||
other text/unstructured data. For this, you should use a document loader like the [CSVLoader](../modules/document_loaders/examples/csv.ipynb)
|
||||
and then you should [create an index](../modules/indexes.rst) over that data, and [query it that way](../modules/indexes/chain_examples/vector_db_qa.ipynb).
|
||||
|
||||
## Querying
|
||||
If you have more numeric tabular data, or have a large amount of data and don't want to index it, you should get started
|
||||
by looking at various chains and agents we have for dealing with this data.
|
||||
|
||||
### Chains
|
||||
|
||||
If you are just getting started, and you have relatively small/simple tabular data, you should get started with chains.
|
||||
Chains are a sequence of predetermined steps, so they are good to get started with as they give you more control and let you
|
||||
understand what is happening better.
|
||||
|
||||
- [SQL Database Chain](../modules/chains/examples/sqlite.ipynb)
|
||||
|
||||
### Agents
|
||||
|
||||
Agents are more complex, and involve multiple queries to the LLM to understand what to do.
|
||||
The downside of agents are that you have less control. The upside is that they are more powerful,
|
||||
which allows you to use them on larger databases and more complex schemas.
|
||||
|
||||
- [SQL Agent](../modules/agents/agent_toolkits/sql_database.ipynb)
|
||||
- [Pandas Agent](../modules/agents/agent_toolkits/pandas.ipynb)
|
||||
- [CSV Agent](../modules/agents/agent_toolkits/csv.ipynb)
|
||||
@@ -33,6 +33,7 @@ from langchain.llms import (
|
||||
Modal,
|
||||
OpenAI,
|
||||
Petals,
|
||||
SagemakerEndpoint,
|
||||
StochasticAI,
|
||||
Writer,
|
||||
)
|
||||
@@ -90,6 +91,7 @@ __all__ = [
|
||||
"ReActChain",
|
||||
"Wikipedia",
|
||||
"HuggingFaceHub",
|
||||
"SagemakerEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
"SQLDatabase",
|
||||
"SQLDatabaseChain",
|
||||
|
||||
@@ -453,9 +453,15 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
self.callback_manager.on_agent_action(
|
||||
output, verbose=self.verbose, color="green"
|
||||
)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_agent_action(
|
||||
output, verbose=self.verbose, color="green"
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_agent_action(
|
||||
output, verbose=self.verbose, color="green"
|
||||
)
|
||||
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[output.tool]
|
||||
|
||||
@@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@@ -26,6 +25,7 @@ from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
@@ -39,6 +39,8 @@ class AgentOutputParser(BaseOutputParser):
|
||||
cleaned_output = text.strip()
|
||||
if "```json" in cleaned_output:
|
||||
_, cleaned_output = cleaned_output.split("```json")
|
||||
if "```" in cleaned_output:
|
||||
cleaned_output, _ = cleaned_output.split("```")
|
||||
if cleaned_output.startswith("```json"):
|
||||
cleaned_output = cleaned_output[len("```json") :]
|
||||
if cleaned_output.startswith("```"):
|
||||
|
||||
@@ -27,7 +27,9 @@ def initialize_agent(
|
||||
`react-docstore`
|
||||
`self-ask-with-search`
|
||||
`conversational-react-description`
|
||||
If None and agent_path is also None, will default to
|
||||
`chat-zero-shot-react-description`,
|
||||
`chat-conversational-react-description`,
|
||||
If None and agent_path is also None, will default to
|
||||
`zero-shot-react-description`.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, List, Optional
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.api import news_docs, open_meteo_docs, tmdb_docs
|
||||
from langchain.chains.api import news_docs, open_meteo_docs, tmdb_docs, podcast_docs
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.chains.pal.base import PALChain
|
||||
@@ -13,6 +13,7 @@ from langchain.requests import RequestsWrapper
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.bing_search.tool import BingSearchRun
|
||||
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
|
||||
from langchain.tools.human.tool import HumanInputRun
|
||||
from langchain.tools.python.tool import PythonREPLTool
|
||||
from langchain.tools.requests.tool import RequestsGetTool
|
||||
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
||||
@@ -118,6 +119,20 @@ def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
|
||||
)
|
||||
|
||||
|
||||
def _get_podcast_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
|
||||
listen_api_key = kwargs["listen_api_key"]
|
||||
chain = APIChain.from_llm_and_api_docs(
|
||||
llm,
|
||||
podcast_docs.PODCAST_DOCS,
|
||||
headers={"X-ListenAPI-Key": listen_api_key},
|
||||
)
|
||||
return Tool(
|
||||
name="Podcast API",
|
||||
description="Use the Listen Notes Podcast API to search all podcasts or episodes. The input should be a question in natural language that this API can answer.",
|
||||
func=chain.run,
|
||||
)
|
||||
|
||||
|
||||
def _get_wolfram_alpha(**kwargs: Any) -> BaseTool:
|
||||
return WolframAlphaQueryRun(api_wrapper=WolframAlphaAPIWrapper(**kwargs))
|
||||
|
||||
@@ -163,9 +178,14 @@ def _get_bing_search(**kwargs: Any) -> BaseTool:
|
||||
return BingSearchRun(api_wrapper=BingSearchAPIWrapper(**kwargs))
|
||||
|
||||
|
||||
def _get_human_tool(**kwargs: Any) -> BaseTool:
|
||||
return HumanInputRun(**kwargs)
|
||||
|
||||
|
||||
_EXTRA_LLM_TOOLS = {
|
||||
"news-api": (_get_news_api, ["news_api_key"]),
|
||||
"tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]),
|
||||
"podcast-api": (_get_podcast_api, ["listen_api_key"]),
|
||||
}
|
||||
|
||||
_EXTRA_OPTIONAL_TOOLS = {
|
||||
@@ -180,6 +200,7 @@ _EXTRA_OPTIONAL_TOOLS = {
|
||||
"serpapi": (_get_serpapi, ["serpapi_api_key", "aiosession"]),
|
||||
"searx-search": (_get_searx_search, ["searx_host"]),
|
||||
"wikipedia": (_get_wikipedia, ["top_k_results"]),
|
||||
"human": (_get_human_tool, ["prompt_func", "input_func"]),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,16 @@ import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers import SharedLangChainTracer
|
||||
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
@@ -58,3 +63,17 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||
manager.add_handler(handler)
|
||||
yield handler
|
||||
manager.remove_handler(handler)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CallbackManager",
|
||||
"OpenAICallbackHandler",
|
||||
"SharedCallbackManager",
|
||||
"StdOutCallbackHandler",
|
||||
"WandbCallbackHandler",
|
||||
"get_openai_callback",
|
||||
"set_tracing_callback_manager",
|
||||
"set_default_callback_manager",
|
||||
"set_handler",
|
||||
"get_callback_manager",
|
||||
]
|
||||
|
||||
819
langchain/callbacks/wandb_callback.py
Normal file
819
langchain/callbacks/wandb_callback.py
Normal file
@@ -0,0 +1,819 @@
|
||||
import hashlib
|
||||
import json
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
try:
|
||||
import wandb # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the wandb callback manager you need to have the `wandb` python "
|
||||
"package installed. Please install it with `pip install wandb`"
|
||||
)
|
||||
return wandb
|
||||
|
||||
|
||||
def import_spacy() -> Any:
|
||||
try:
|
||||
import spacy # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the wandb callback manager you need to have the `spacy` python "
|
||||
"package installed. Please install it with `pip install spacy`"
|
||||
)
|
||||
return spacy
|
||||
|
||||
|
||||
def import_pandas() -> Any:
|
||||
try:
|
||||
import pandas # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the wandb callback manager you need to have the `pandas` python "
|
||||
"package installed. Please install it with `pip install pandas`"
|
||||
)
|
||||
return pandas
|
||||
|
||||
|
||||
def import_textstat() -> Any:
|
||||
try:
|
||||
import textstat # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the wandb callback manager you need to have the `textstat` python "
|
||||
"package installed. Please install it with `pip install textstat`"
|
||||
)
|
||||
return textstat
|
||||
|
||||
|
||||
def _flatten_dict(
|
||||
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
|
||||
) -> Iterable[Tuple[str, Any]]:
|
||||
"""
|
||||
Generator that yields flattened items from a nested dictionary for a flat dict.
|
||||
|
||||
Parameters:
|
||||
nested_dict (dict): The nested dictionary to flatten.
|
||||
parent_key (str): The prefix to prepend to the keys of the flattened dict.
|
||||
sep (str): The separator to use between the parent key and the key of the
|
||||
flattened dictionary.
|
||||
|
||||
Yields:
|
||||
(str, any): A key-value pair from the flattened dictionary.
|
||||
"""
|
||||
for key, value in nested_dict.items():
|
||||
new_key = parent_key + sep + key if parent_key else key
|
||||
if isinstance(value, dict):
|
||||
yield from _flatten_dict(value, new_key, sep)
|
||||
else:
|
||||
yield new_key, value
|
||||
|
||||
|
||||
def flatten_dict(
|
||||
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
|
||||
) -> Dict[str, Any]:
|
||||
"""Flattens a nested dictionary into a flat dictionary.
|
||||
|
||||
Parameters:
|
||||
nested_dict (dict): The nested dictionary to flatten.
|
||||
parent_key (str): The prefix to prepend to the keys of the flattened dict.
|
||||
sep (str): The separator to use between the parent key and the key of the
|
||||
flattened dictionary.
|
||||
|
||||
Returns:
|
||||
(dict): A flat dictionary.
|
||||
|
||||
"""
|
||||
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
|
||||
return flat_dict
|
||||
|
||||
|
||||
def hash_string(s: str) -> str:
|
||||
"""Hash a string using sha1.
|
||||
|
||||
Parameters:
|
||||
s (str): The string to hash.
|
||||
|
||||
Returns:
|
||||
(str): The hashed string.
|
||||
"""
|
||||
return hashlib.sha1(s.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
|
||||
"""Load json file to a dictionary.
|
||||
|
||||
Parameters:
|
||||
json_path (str): The path to the json file.
|
||||
|
||||
Returns:
|
||||
(dict): The dictionary representation of the json file.
|
||||
"""
|
||||
with open(json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: str,
|
||||
complexity_metrics: bool = True,
|
||||
visualize: bool = True,
|
||||
nlp: Any = None,
|
||||
output_dir: Optional[Union[str, Path]] = None,
|
||||
) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
complexity_metrics (bool): Whether to compute complexity metrics.
|
||||
visualize (bool): Whether to visualize the text.
|
||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||
output_dir (str): The directory to save the visualization files to.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics and visualization
|
||||
files serialized in a wandb.Html element.
|
||||
"""
|
||||
resp = {}
|
||||
textstat = import_textstat()
|
||||
wandb = import_wandb()
|
||||
spacy = import_spacy()
|
||||
if complexity_metrics:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
"text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if visualize and nlp and output_dir is not None:
|
||||
doc = nlp(text)
|
||||
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
)
|
||||
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
|
||||
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
||||
|
||||
ent_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="ent", jupyter=False, page=True
|
||||
)
|
||||
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
|
||||
ent_output_path.open("w", encoding="utf-8").write(ent_out)
|
||||
|
||||
text_visualizations = {
|
||||
"dependency_tree": wandb.Html(str(dep_output_path)),
|
||||
"entities": wandb.Html(str(ent_output_path)),
|
||||
}
|
||||
resp.update(text_visualizations)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
|
||||
"""Construct an html element from a prompt and a generation.
|
||||
|
||||
Parameters:
|
||||
prompt (str): The prompt.
|
||||
generation (str): The generation.
|
||||
|
||||
Returns:
|
||||
(wandb.Html): The html element."""
|
||||
wandb = import_wandb()
|
||||
formatted_prompt = prompt.replace("\n", "<br>")
|
||||
formatted_generation = generation.replace("\n", "<br>")
|
||||
|
||||
return wandb.Html(
|
||||
f"""
|
||||
<p style="color:black;">{formatted_prompt}:</p>
|
||||
<blockquote>
|
||||
<p style="color:green;">
|
||||
{formatted_generation}
|
||||
</p>
|
||||
</blockquote>
|
||||
""",
|
||||
inject=False,
|
||||
)
|
||||
|
||||
|
||||
class BaseMetadataCallbackHandler:
|
||||
"""This class handles the metadata and associated function states for callbacks.
|
||||
|
||||
Attributes:
|
||||
step (int): The current step.
|
||||
starts (int): The number of times the start method has been called.
|
||||
ends (int): The number of times the end method has been called.
|
||||
errors (int): The number of times the error method has been called.
|
||||
text_ctr (int): The number of times the text method has been called.
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
llm_starts (int): The number of times the llm start method has been called.
|
||||
llm_ends (int): The number of times the llm end method has been called.
|
||||
llm_streams (int): The number of times the text method has been called.
|
||||
tool_starts (int): The number of times the tool start method has been called.
|
||||
tool_ends (int): The number of times the tool end method has been called.
|
||||
agent_ends (int): The number of times the agent end method has been called.
|
||||
on_llm_start_records (list): A list of records of the on_llm_start method.
|
||||
on_llm_token_records (list): A list of records of the on_llm_token method.
|
||||
on_llm_end_records (list): A list of records of the on_llm_end method.
|
||||
on_chain_start_records (list): A list of records of the on_chain_start method.
|
||||
on_chain_end_records (list): A list of records of the on_chain_end method.
|
||||
on_tool_start_records (list): A list of records of the on_tool_start method.
|
||||
on_tool_end_records (list): A list of records of the on_tool_end method.
|
||||
on_agent_finish_records (list): A list of records of the on_agent_end method.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
self.on_llm_start_records: list = []
|
||||
self.on_llm_token_records: list = []
|
||||
self.on_llm_end_records: list = []
|
||||
|
||||
self.on_chain_start_records: list = []
|
||||
self.on_chain_end_records: list = []
|
||||
|
||||
self.on_tool_start_records: list = []
|
||||
self.on_tool_end_records: list = []
|
||||
|
||||
self.on_text_records: list = []
|
||||
self.on_agent_finish_records: list = []
|
||||
self.on_agent_action_records: list = []
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
def get_custom_callback_meta(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step": self.step,
|
||||
"starts": self.starts,
|
||||
"ends": self.ends,
|
||||
"errors": self.errors,
|
||||
"text_ctr": self.text_ctr,
|
||||
"chain_starts": self.chain_starts,
|
||||
"chain_ends": self.chain_ends,
|
||||
"llm_starts": self.llm_starts,
|
||||
"llm_ends": self.llm_ends,
|
||||
"llm_streams": self.llm_streams,
|
||||
"tool_starts": self.tool_starts,
|
||||
"tool_ends": self.tool_ends,
|
||||
"agent_ends": self.agent_ends,
|
||||
}
|
||||
|
||||
def reset_callback_meta(self) -> None:
|
||||
"""Reset the callback metadata."""
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
self.on_llm_start_records = []
|
||||
self.on_llm_token_records = []
|
||||
self.on_llm_end_records = []
|
||||
|
||||
self.on_chain_start_records = []
|
||||
self.on_chain_end_records = []
|
||||
|
||||
self.on_tool_start_records = []
|
||||
self.on_tool_end_records = []
|
||||
|
||||
self.on_text_records = []
|
||||
self.on_agent_finish_records = []
|
||||
self.on_agent_action_records = []
|
||||
return None
|
||||
|
||||
|
||||
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Weights and Biases.
|
||||
|
||||
Parameters:
|
||||
job_type (str): The type of job.
|
||||
project (str): The project to log to.
|
||||
entity (str): The entity to log to.
|
||||
tags (list): The tags to log.
|
||||
group (str): The group to log to.
|
||||
name (str): The name of the run.
|
||||
notes (str): The notes to log.
|
||||
visualize (bool): Whether to visualize the run.
|
||||
complexity_metrics (bool): Whether to log complexity metrics.
|
||||
stream_logs (bool): Whether to stream callback actions to W&B
|
||||
|
||||
This handler will utilize the associated callback method called and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run,
|
||||
and adds the response to the list of records for both the {method}_records and
|
||||
action. It then logs the response using the run.log() method to Weights and Biases.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job_type: Optional[str] = None,
|
||||
project: Optional[str] = "langchain_callback_demo",
|
||||
entity: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
group: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
visualize: bool = False,
|
||||
complexity_metrics: bool = False,
|
||||
stream_logs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
wandb = import_wandb()
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
spacy = import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.job_type = job_type
|
||||
self.project = project
|
||||
self.entity = entity
|
||||
self.tags = tags
|
||||
self.group = group
|
||||
self.name = name
|
||||
self.notes = notes
|
||||
self.visualize = visualize
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.stream_logs = stream_logs
|
||||
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore
|
||||
job_type=self.job_type,
|
||||
project=self.project,
|
||||
entity=self.entity,
|
||||
tags=self.tags,
|
||||
group=self.group,
|
||||
name=self.name,
|
||||
notes=self.notes,
|
||||
)
|
||||
warning = (
|
||||
"The wandb callback is currently in beta and is subject to change "
|
||||
"based on updates to `langchain`. Please report any issues to "
|
||||
"https://github.com/wandb/wandb/issues with the tag `langchain`."
|
||||
)
|
||||
wandb.termwarn(
|
||||
warning,
|
||||
repeat=False,
|
||||
)
|
||||
self.callback_columns: list = []
|
||||
self.action_records: list = []
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.visualize = visualize
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def _init_resp(self) -> Dict:
|
||||
return {k: None for k in self.callback_columns}
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompts"] = prompt
|
||||
self.on_llm_start_records.append(prompt_resp)
|
||||
self.action_records.append(prompt_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(prompt_resp)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_llm_token_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
generation_resp.update(
|
||||
analyze_text(
|
||||
generation.text,
|
||||
complexity_metrics=self.complexity_metrics,
|
||||
visualize=self.visualize,
|
||||
nlp=self.nlp,
|
||||
output_dir=self.temp_dir.name,
|
||||
)
|
||||
)
|
||||
self.on_llm_end_records.append(generation_resp)
|
||||
self.action_records.append(generation_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(generation_resp)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
chain_input = inputs["input"]
|
||||
|
||||
if isinstance(chain_input, str):
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["input"] = chain_input
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(input_resp)
|
||||
elif isinstance(chain_input, list):
|
||||
for inp in chain_input:
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp.update(inp)
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(input_resp)
|
||||
else:
|
||||
raise ValueError("Unexpected data format provided!")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.step += 1
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_chain_end_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_start", "input_str": input_str})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_tool_start_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_end", "output": output})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_tool_end_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.step += 1
|
||||
self.text_ctr += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_text", "text": text})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_text_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.step += 1
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_finish",
|
||||
"output": finish.return_values["output"],
|
||||
"log": finish.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_agent_finish_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
"tool_input": action.tool_input,
|
||||
"log": action.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
self.on_agent_action_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
|
||||
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||
|
||||
llm_input_prompts_df = (
|
||||
on_llm_start_records_df[["step", "prompts", "name"]]
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "prompt_step"}, axis=1)
|
||||
)
|
||||
complexity_metrics_columns = []
|
||||
visualizations_columns = []
|
||||
|
||||
if self.complexity_metrics:
|
||||
complexity_metrics_columns = [
|
||||
"flesch_reading_ease",
|
||||
"flesch_kincaid_grade",
|
||||
"smog_index",
|
||||
"coleman_liau_index",
|
||||
"automated_readability_index",
|
||||
"dale_chall_readability_score",
|
||||
"difficult_words",
|
||||
"linsear_write_formula",
|
||||
"gunning_fog",
|
||||
"text_standard",
|
||||
"fernandez_huerta",
|
||||
"szigriszt_pazos",
|
||||
"gutierrez_polini",
|
||||
"crawford",
|
||||
"gulpease_index",
|
||||
"osman",
|
||||
]
|
||||
|
||||
if self.visualize:
|
||||
visualizations_columns = ["dependency_tree", "entities"]
|
||||
|
||||
llm_outputs_df = (
|
||||
on_llm_end_records_df[
|
||||
[
|
||||
"step",
|
||||
"text",
|
||||
"token_usage_total_tokens",
|
||||
"token_usage_prompt_tokens",
|
||||
"token_usage_completion_tokens",
|
||||
]
|
||||
+ complexity_metrics_columns
|
||||
+ visualizations_columns
|
||||
]
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "output_step", "text": "output"}, axis=1)
|
||||
)
|
||||
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
|
||||
session_analysis_df["chat_html"] = session_analysis_df[
|
||||
["prompts", "output"]
|
||||
].apply(
|
||||
lambda row: construct_html_from_prompt_and_generation(
|
||||
row["prompts"], row["output"]
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
return session_analysis_df
|
||||
|
||||
def flush_tracker(
|
||||
self,
|
||||
langchain_asset: Any = None,
|
||||
reset: bool = True,
|
||||
finish: bool = False,
|
||||
job_type: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
entity: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
group: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
visualize: Optional[bool] = None,
|
||||
complexity_metrics: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""Flush the tracker and reset the session.
|
||||
|
||||
Args:
|
||||
langchain_asset: The langchain asset to save.
|
||||
reset: Whether to reset the session.
|
||||
finish: Whether to finish the run.
|
||||
job_type: The job type.
|
||||
project: The project.
|
||||
entity: The entity.
|
||||
tags: The tags.
|
||||
group: The group.
|
||||
name: The name.
|
||||
notes: The notes.
|
||||
visualize: Whether to visualize.
|
||||
complexity_metrics: Whether to compute complexity metrics.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
pd = import_pandas()
|
||||
wandb = import_wandb()
|
||||
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
|
||||
session_analysis_table = wandb.Table(
|
||||
dataframe=self._create_session_analysis_df()
|
||||
)
|
||||
self.run.log(
|
||||
{
|
||||
"action_records": action_records_table,
|
||||
"session_analysis": session_analysis_table,
|
||||
}
|
||||
)
|
||||
|
||||
if langchain_asset:
|
||||
langchain_asset_path = Path(self.temp_dir.name, "model.json")
|
||||
model_artifact = wandb.Artifact(name="model", type="model")
|
||||
model_artifact.add(action_records_table, name="action_records")
|
||||
model_artifact.add(session_analysis_table, name="session_analysis")
|
||||
try:
|
||||
langchain_asset.save(langchain_asset_path)
|
||||
model_artifact.add_file(str(langchain_asset_path))
|
||||
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
|
||||
except ValueError:
|
||||
langchain_asset.save_agent(langchain_asset_path)
|
||||
model_artifact.add_file(str(langchain_asset_path))
|
||||
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
|
||||
except NotImplementedError as e:
|
||||
print("Could not save model.")
|
||||
print(repr(e))
|
||||
pass
|
||||
self.run.log_artifact(model_artifact)
|
||||
|
||||
if finish or reset:
|
||||
self.run.finish()
|
||||
self.temp_dir.cleanup()
|
||||
self.reset_callback_meta()
|
||||
if reset:
|
||||
self.__init__( # type: ignore
|
||||
job_type=job_type if job_type else self.job_type,
|
||||
project=project if project else self.project,
|
||||
entity=entity if entity else self.entity,
|
||||
tags=tags if tags else self.tags,
|
||||
group=group if group else self.group,
|
||||
name=name if name else self.name,
|
||||
notes=notes if notes else self.notes,
|
||||
visualize=visualize if visualize else self.visualize,
|
||||
complexity_metrics=complexity_metrics
|
||||
if complexity_metrics
|
||||
else self.complexity_metrics,
|
||||
)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Chains are easily reusable components which can be linked together."""
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.chat_vector_db.base import ChatVectorDBChain
|
||||
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
|
||||
from langchain.chains.constitutional_ai.base import ConstitutionalChain
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.conversational_retrieval.base import (
|
||||
ChatVectorDBChain,
|
||||
ConversationalRetrievalChain,
|
||||
)
|
||||
from langchain.chains.graph_qa.base import GraphQAChain
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -18,14 +21,15 @@ from langchain.chains.moderation import OpenAIModerationChain
|
||||
from langchain.chains.pal.base import PALChain
|
||||
from langchain.chains.qa_generation.base import QAGenerationChain
|
||||
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
|
||||
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
|
||||
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
||||
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.chains.sql_database.base import (
|
||||
SQLDatabaseChain,
|
||||
SQLDatabaseSequentialChain,
|
||||
)
|
||||
from langchain.chains.transform import TransformChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
|
||||
__all__ = [
|
||||
"ConversationChain",
|
||||
@@ -54,4 +58,7 @@ __all__ = [
|
||||
"GraphQAChain",
|
||||
"ConstitutionalChain",
|
||||
"QAGenerationChain",
|
||||
"RetrievalQA",
|
||||
"RetrievalQAWithSourcesChain",
|
||||
"ConversationalRetrievalChain",
|
||||
]
|
||||
|
||||
@@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, root_validator
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class APIChain(Chain, BaseModel):
|
||||
@@ -84,7 +84,7 @@ class APIChain(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_llm_and_api_docs(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
api_docs: str,
|
||||
headers: Optional[dict] = None,
|
||||
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
|
||||
|
||||
28
langchain/chains/api/podcast_docs.py
Normal file
28
langchain/chains/api/podcast_docs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# flake8: noqa
|
||||
PODCAST_DOCS = """API documentation:
|
||||
Endpoint: https://listen-api.listennotes.com/api/v2
|
||||
GET /search
|
||||
|
||||
This API is for searching podcasts or episodes.
|
||||
|
||||
Query parameters table:
|
||||
q | string | Search term, e.g., person, place, topic... You can use double quotes to do verbatim match, e.g., "game of thrones". Otherwise, it's fuzzy search. | required
|
||||
type | string | What type of contents do you want to search for? Available values: episode, podcast, curated. default: episode | optional
|
||||
page_size | integer | The maximum number of search results per page. A valid value should be an integer between 1 and 10 (inclusive). default: 3 | optional
|
||||
language | string | Limit search results to a specific language, e.g., English, Chinese ... If not specified, it'll be any language. It works only when type is episode or podcast. | optional
|
||||
region | string | Limit search results to a specific region (e.g., us, gb, in...). If not specified, it'll be any region. It works only when type is episode or podcast. | optional
|
||||
len_min | integer | Minimum audio length in minutes. Applicable only when type parameter is episode or podcast. If type parameter is episode, it's for audio length of an episode. If type parameter is podcast, it's for average audio length of all episodes in a podcast. | optional
|
||||
len_max | integer | Maximum audio length in minutes. Applicable only when type parameter is episode or podcast. If type parameter is episode, it's for audio length of an episode. If type parameter is podcast, it's for average audio length of all episodes in a podcast. | optional
|
||||
|
||||
Response schema (JSON object):
|
||||
next_offset | integer | optional
|
||||
total | integer | optional
|
||||
results | array[object] (Episode / Podcast List Result Object)
|
||||
|
||||
Each object in the "results" key has the following schema:
|
||||
listennotes_url | string | optional
|
||||
id | integer | optional
|
||||
title_highlighted | string | optional
|
||||
|
||||
Use page_size: 3
|
||||
"""
|
||||
@@ -1,12 +1,13 @@
|
||||
"""Chain for applying constitutional principles to the outputs of another chain."""
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.constitutional_ai.principles import PRINCIPLES
|
||||
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class ConstitutionalChain(Chain):
|
||||
@@ -42,10 +43,19 @@ class ConstitutionalChain(Chain):
|
||||
critique_chain: LLMChain
|
||||
revision_chain: LLMChain
|
||||
|
||||
@classmethod
|
||||
def get_principles(
|
||||
cls, names: Optional[List[str]] = None
|
||||
) -> List[ConstitutionalPrinciple]:
|
||||
if names is None:
|
||||
return list(PRINCIPLES.values())
|
||||
else:
|
||||
return [PRINCIPLES[name] for name in names]
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain: LLMChain,
|
||||
critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT,
|
||||
revision_prompt: BasePromptTemplate = REVISION_PROMPT,
|
||||
|
||||
5
langchain/chains/constitutional_ai/principles.py
Normal file
5
langchain/chains/constitutional_ai/principles.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
|
||||
PRINCIPLES: Dict[str, ConstitutionalPrinciple] = {}
|
||||
@@ -1,18 +1,19 @@
|
||||
"""Chain for chatting with a vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@@ -25,21 +26,22 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
|
||||
return buffer
|
||||
|
||||
|
||||
class ChatVectorDBChain(Chain, BaseModel):
|
||||
"""Chain for chatting with a vector database."""
|
||||
class BaseConversationalRetrievalChain(Chain, BaseModel):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
vectorstore: VectorStore
|
||||
combine_docs_chain: BaseCombineDocumentsChain
|
||||
question_generator: LLMChain
|
||||
output_key: str = "answer"
|
||||
return_source_documents: bool = False
|
||||
top_k_docs_for_context: int = 4
|
||||
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
|
||||
"""Return the source documents."""
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "chat-vector-db"
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -57,44 +59,22 @@ class ChatVectorDBChain(Chain, BaseModel):
|
||||
_output_keys = _output_keys + ["source_documents"]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
vectorstore: VectorStore,
|
||||
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
|
||||
qa_prompt: Optional[BasePromptTemplate] = None,
|
||||
chain_type: str = "stuff",
|
||||
**kwargs: Any,
|
||||
) -> ChatVectorDBChain:
|
||||
"""Load chain from LLM."""
|
||||
doc_chain = load_qa_chain(
|
||||
llm,
|
||||
chain_type=chain_type,
|
||||
prompt=qa_prompt,
|
||||
)
|
||||
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
|
||||
return cls(
|
||||
vectorstore=vectorstore,
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=condense_question_chain,
|
||||
**kwargs,
|
||||
)
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
question = inputs["question"]
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
chat_history_str = get_chat_history(inputs["chat_history"])
|
||||
vectordbkwargs = inputs.get("vectordbkwargs", {})
|
||||
|
||||
if chat_history_str:
|
||||
new_question = self.question_generator.run(
|
||||
question=question, chat_history=chat_history_str
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
docs = self.vectorstore.similarity_search(
|
||||
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
|
||||
)
|
||||
docs = self._get_docs(new_question, inputs)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -108,7 +88,6 @@ class ChatVectorDBChain(Chain, BaseModel):
|
||||
question = inputs["question"]
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
chat_history_str = get_chat_history(inputs["chat_history"])
|
||||
vectordbkwargs = inputs.get("vectordbkwargs", {})
|
||||
if chat_history_str:
|
||||
new_question = await self.question_generator.arun(
|
||||
question=question, chat_history=chat_history_str
|
||||
@@ -116,9 +95,7 @@ class ChatVectorDBChain(Chain, BaseModel):
|
||||
else:
|
||||
new_question = question
|
||||
# TODO: This blocks the event loop, but it's not clear how to avoid it.
|
||||
docs = self.vectorstore.similarity_search(
|
||||
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
|
||||
)
|
||||
docs = self._get_docs(new_question, inputs)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -132,3 +109,79 @@ class ChatVectorDBChain(Chain, BaseModel):
|
||||
if self.get_chat_history:
|
||||
raise ValueError("Chain not savable when `get_chat_history` is not None.")
|
||||
super().save(file_path)
|
||||
|
||||
|
||||
class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
retriever: BaseRetriever
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
return self.retriever.get_relevant_texts(question)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
retriever: BaseRetriever,
|
||||
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
|
||||
qa_prompt: Optional[BasePromptTemplate] = None,
|
||||
chain_type: str = "stuff",
|
||||
**kwargs: Any,
|
||||
) -> BaseConversationalRetrievalChain:
|
||||
"""Load chain from LLM."""
|
||||
doc_chain = load_qa_chain(
|
||||
llm,
|
||||
chain_type=chain_type,
|
||||
prompt=qa_prompt,
|
||||
)
|
||||
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=condense_question_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
"""Chain for chatting with a vector database."""
|
||||
|
||||
vectorstore: VectorStore = Field(alias="vectorstore")
|
||||
top_k_docs_for_context: int = 4
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "chat-vector-db"
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
vectordbkwargs = inputs.get("vectordbkwargs", {})
|
||||
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
|
||||
return self.vectorstore.similarity_search(
|
||||
question, k=self.top_k_docs_for_context, **full_kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
vectorstore: VectorStore,
|
||||
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
|
||||
qa_prompt: Optional[BasePromptTemplate] = None,
|
||||
chain_type: str = "stuff",
|
||||
**kwargs: Any,
|
||||
) -> BaseConversationalRetrievalChain:
|
||||
"""Load chain from LLM."""
|
||||
doc_chain = load_qa_chain(
|
||||
llm,
|
||||
chain_type=chain_type,
|
||||
prompt=qa_prompt,
|
||||
)
|
||||
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
|
||||
return cls(
|
||||
vectorstore=vectorstore,
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=condense_question_chain,
|
||||
**kwargs,
|
||||
)
|
||||
20
langchain/chains/conversational_retrieval/prompts.py
Normal file
20
langchain/chains/conversational_retrieval/prompts.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
Follow Up Input: {question}
|
||||
Standalone question:"""
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
|
||||
|
||||
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
QA_PROMPT = PromptTemplate(
|
||||
template=prompt_template, input_variables=["context", "question"]
|
||||
)
|
||||
@@ -6,8 +6,8 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
|
||||
llm_bash = LLMBashChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
@@ -1,26 +1,17 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_PROMPT_TEMPLATE = """You are GPT-3, and you can't do math.
|
||||
_PROMPT_TEMPLATE = """Translate a math problem into Python code that can be executed in Python 3 REPL. Use the output of running this code to answer the question.
|
||||
|
||||
You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.
|
||||
|
||||
So we hooked you up to a Python 3 kernel, and now you can execute code. If anyone gives you a hard math problem, just use this format and we’ll take care of the rest:
|
||||
|
||||
Question: ${{Question with hard calculation.}}
|
||||
Question: ${{Question with math problem.}}
|
||||
```python
|
||||
${{Code that prints what you need to know}}
|
||||
${{Code that solves the problem and prints the solution}}
|
||||
```
|
||||
```output
|
||||
${{Output of your code}}
|
||||
${{Output of running the code}}
|
||||
```
|
||||
Answer: ${{Answer}}
|
||||
|
||||
Otherwise, use this simpler format:
|
||||
|
||||
Question: ${{Question without hard calculation}}
|
||||
Answer: ${{Answer}}
|
||||
|
||||
Begin.
|
||||
|
||||
Question: What is 37593 * 67?
|
||||
|
||||
@@ -20,8 +20,8 @@ from langchain.chains.llm_requests import LLMRequestsChain
|
||||
from langchain.chains.pal.base import PALChain
|
||||
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
|
||||
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
||||
from langchain.chains.retrieval_qa.base import VectorDBQA
|
||||
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
from langchain.llms.loading import load_llm, load_llm_from_config
|
||||
from langchain.prompts.loading import load_prompt, load_prompt_from_config
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
@@ -12,15 +12,15 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.python import PythonREPL
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class PALChain(Chain, BaseModel):
|
||||
"""Implements Program-Aided Language Models."""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate
|
||||
stop: str = "\n\n"
|
||||
get_answer_expr: str = "print(solution())"
|
||||
@@ -68,7 +68,7 @@ class PALChain(Chain, BaseModel):
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
|
||||
"""Load PAL from math prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
@@ -79,7 +79,9 @@ class PALChain(Chain, BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
def from_colored_object_prompt(
|
||||
cls, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> PALChain:
|
||||
"""Load PAL from colored object prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
|
||||
@@ -19,8 +19,8 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
QUESTION_PROMPT,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@@ -38,7 +38,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
|
||||
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
|
||||
@@ -65,7 +65,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import (
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_map_rerank_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
||||
verbose: bool = False,
|
||||
document_variable_name: str = "context",
|
||||
@@ -44,7 +46,7 @@ def _load_map_rerank_chain(
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "summaries",
|
||||
@@ -62,15 +64,15 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
map_reduce_document_variable_name: str = "context",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
reduce_llm: Optional[BaseLanguageModel] = None,
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
@@ -112,13 +114,13 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "context_str",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
refine_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
@@ -137,7 +139,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_with_sources_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
|
||||
46
langchain/chains/qa_with_sources/retrieval.py
Normal file
46
langchain/chains/qa_with_sources/retrieval.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Question-answering with sources over an index."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
"""Question-answering with sources over an index."""
|
||||
|
||||
retriever: BaseRetriever = Field(exclude=True)
|
||||
"""Index to connect to."""
|
||||
reduce_k_below_max_tokens: bool = False
|
||||
"""Reduce the number of results to return from store based on tokens limit"""
|
||||
max_tokens_limit: int = 3375
|
||||
"""Restrict the docs to return from store based on tokens,
|
||||
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
|
||||
|
||||
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
|
||||
num_docs = len(docs)
|
||||
|
||||
if self.reduce_k_below_max_tokens and isinstance(
|
||||
self.combine_documents_chain, StuffDocumentsChain
|
||||
):
|
||||
tokens = [
|
||||
self.combine_documents_chain.llm_chain.llm.get_num_tokens(
|
||||
doc.page_content
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
token_count = sum(tokens[:num_docs])
|
||||
while token_count > self.max_tokens_limit:
|
||||
num_docs -= 1
|
||||
token_count -= tokens[num_docs]
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = self.retriever.get_relevant_texts(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
@@ -155,7 +155,7 @@ def _load_refine_chain(
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
_question_prompt = (
|
||||
question_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(llm)
|
||||
question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
|
||||
)
|
||||
_refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(
|
||||
llm
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Chain for question-answering against a vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
@@ -11,44 +12,25 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
class VectorDBQA(Chain, BaseModel):
|
||||
"""Chain for question-answering against a vector database.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import OpenAI, VectorDBQA
|
||||
from langchain.faiss import FAISS
|
||||
vectordb = FAISS(...)
|
||||
vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)
|
||||
|
||||
"""
|
||||
|
||||
vectorstore: VectorStore = Field(exclude=True)
|
||||
"""Vector Database to connect to."""
|
||||
k: int = 4
|
||||
"""Number of documents to query for."""
|
||||
class BaseRetrievalQA(Chain, BaseModel):
|
||||
combine_documents_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine the documents."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
return_source_documents: bool = False
|
||||
"""Return the source documents."""
|
||||
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Extra search args."""
|
||||
search_type: str = "similarity"
|
||||
"""Search type to use over vectorstore. `similarity` or `mmr`."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -69,42 +51,13 @@ class VectorDBQA(Chain, BaseModel):
|
||||
_output_keys = _output_keys + ["source_documents"]
|
||||
return _output_keys
|
||||
|
||||
# TODO: deprecate this
|
||||
@root_validator(pre=True)
|
||||
def load_combine_documents_chain(cls, values: Dict) -> Dict:
|
||||
"""Validate question chain."""
|
||||
if "combine_documents_chain" not in values:
|
||||
if "llm" not in values:
|
||||
raise ValueError(
|
||||
"If `combine_documents_chain` not provided, `llm` should be."
|
||||
)
|
||||
llm = values.pop("llm")
|
||||
prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
document_prompt = PromptTemplate(
|
||||
input_variables=["page_content"], template="Context:\n{page_content}"
|
||||
)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name="context",
|
||||
document_prompt=document_prompt,
|
||||
)
|
||||
values["combine_documents_chain"] = combine_documents_chain
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_search_type(cls, values: Dict) -> Dict:
|
||||
"""Validate search type."""
|
||||
if "search_type" in values:
|
||||
search_type = values["search_type"]
|
||||
if search_type not in ("similarity", "mmr"):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
|
||||
) -> VectorDBQA:
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetrievalQA:
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
||||
@@ -122,11 +75,11 @@ class VectorDBQA(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> VectorDBQA:
|
||||
) -> BaseRetrievalQA:
|
||||
"""Load chain from chain type."""
|
||||
_chain_type_kwargs = chain_type_kwargs or {}
|
||||
combine_documents_chain = load_qa_chain(
|
||||
@@ -134,8 +87,12 @@ class VectorDBQA(Chain, BaseModel):
|
||||
)
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Run similarity search and llm on input query.
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
@@ -143,11 +100,62 @@ class VectorDBQA(Chain, BaseModel):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = vectordbqa({'query': 'This is my query'})
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = self._get_docs(question)
|
||||
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
|
||||
class RetrievalQA(BaseRetrievalQA, BaseModel):
|
||||
"""Chain for question-answering against an index.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.faiss import FAISS
|
||||
vectordb = FAISS(...)
|
||||
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vectordb)
|
||||
|
||||
"""
|
||||
|
||||
retriever: BaseRetriever = Field(exclude=True)
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_texts(question)
|
||||
|
||||
|
||||
class VectorDBQA(BaseRetrievalQA, BaseModel):
|
||||
"""Chain for question-answering against a vector database."""
|
||||
|
||||
vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
|
||||
"""Vector Database to connect to."""
|
||||
k: int = 4
|
||||
"""Number of documents to query for."""
|
||||
search_type: str = "similarity"
|
||||
"""Search type to use over vectorstore. `similarity` or `mmr`."""
|
||||
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Extra search args."""
|
||||
|
||||
@root_validator()
|
||||
def validate_search_type(cls, values: Dict) -> Dict:
|
||||
"""Validate search type."""
|
||||
if "search_type" in values:
|
||||
search_type = values["search_type"]
|
||||
if search_type not in ("similarity", "mmr"):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(
|
||||
question, k=self.k, **self.search_kwargs
|
||||
@@ -158,12 +166,7 @@ class VectorDBQA(Chain, BaseModel):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
return docs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@@ -8,8 +8,8 @@ from pydantic import BaseModel, Extra, Field
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
|
||||
"""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
database: SQLDatabase = Field(exclude=True)
|
||||
"""SQL Database to connect to."""
|
||||
@@ -122,7 +122,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
database: SQLDatabase,
|
||||
query_prompt: BasePromptTemplate = PROMPT,
|
||||
decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
|
||||
|
||||
@@ -7,19 +7,21 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
verbose: Optional[bool] = None,
|
||||
@@ -36,14 +38,14 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_document_variable_name: str = "text",
|
||||
map_reduce_document_variable_name: str = "text",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
reduce_llm: Optional[BaseLanguageModel] = None,
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
@@ -84,12 +86,12 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
refine_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
@@ -107,7 +109,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_summarize_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||
|
||||
__all__ = ["ChatOpenAI", "PromptLayerChatOpenAI"]
|
||||
__all__ = ["ChatOpenAI", "AzureChatOpenAI", "PromptLayerChatOpenAI"]
|
||||
|
||||
105
langchain/chat_models/azure_openai.py
Normal file
105
langchain/chat_models/azure_openai.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Azure OpenAI chat wrapper."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class AzureChatOpenAI(ChatOpenAI):
|
||||
"""Wrapper around Azure OpenAI Chat Completion API. To use this class you
|
||||
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
|
||||
constructor to refer to the "Model deployment name" in the Azure portal.
|
||||
|
||||
In addition, you should have the ``openai`` python package installed, and the
|
||||
following environment variables set or passed in constructor in lower case:
|
||||
- ``OPENAI_API_TYPE`` (default: ``azure``)
|
||||
- ``OPENAI_API_KEY``
|
||||
- ``OPENAI_API_BASE``
|
||||
- ``OPENAI_API_VERSION``
|
||||
|
||||
For exmaple, if you have `gpt-35-turbo` deployed, with the deployment name
|
||||
`35-turbo-dev`, the constructor should look like:
|
||||
|
||||
.. code-block:: python
|
||||
AzureChatOpenAI(
|
||||
deployment_name="35-turbo-dev",
|
||||
openai_api_version="2023-03-15-preview",
|
||||
)
|
||||
|
||||
Be aware the API version may change.
|
||||
|
||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
"""
|
||||
|
||||
deployment_name: str = ""
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_base: str = ""
|
||||
openai_api_version: str = ""
|
||||
openai_api_key: str = ""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_key",
|
||||
"OPENAI_API_KEY",
|
||||
)
|
||||
openai_api_base = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_base",
|
||||
"OPENAI_API_BASE",
|
||||
)
|
||||
openai_api_version = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_version",
|
||||
"OPENAI_API_VERSION",
|
||||
)
|
||||
openai_api_type = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
openai.api_type = openai_api_type
|
||||
openai.api_base = openai_api_base
|
||||
openai.api_version = openai_api_version
|
||||
openai.api_key = openai_api_key
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please it install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"engine": self.deployment_name,
|
||||
}
|
||||
@@ -43,19 +43,26 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def generate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
results = [self._generate(m, stop=stop) for m in messages]
|
||||
return LLMResult(generations=[res.generations for res in results])
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def agenerate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
results = [await self._agenerate(m, stop=stop) for m in messages]
|
||||
return LLMResult(generations=[res.generations for res in results])
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
|
||||
@@ -97,7 +97,8 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
return ChatResult(generations=generations)
|
||||
llm_output = {"token_usage": response["usage"]}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
@@ -122,13 +123,15 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
openai_api_key: Optional[str] = None
|
||||
request_timeout: int = 60
|
||||
"""Timeout in seconds for the OpenAPI request."""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: int = 256
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
@@ -184,6 +187,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"request_timeout": self.request_timeout,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
@@ -221,6 +225,20 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage}
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
@@ -317,3 +335,41 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self, messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301"
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo with tiktoken package."""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to calculate get_num_tokens. "
|
||||
"Please it install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo-0301": # note: future models may deviate from this
|
||||
num_tokens = 0
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
num_tokens += 4
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name": # if there's a name, the role is omitted
|
||||
num_tokens += -1 # role is always required and always 1 token
|
||||
num_tokens += 2 # every reply is primed with <im_start>assistant
|
||||
return num_tokens
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
|
||||
@@ -17,8 +17,12 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
promptlayer key respectively.
|
||||
|
||||
All parameters that can be passed to the OpenAI LLM can also
|
||||
be passed here. The PromptLayerChatOpenAI LLM adds an extra
|
||||
``pl_tags`` parameter that can be used to tag the request.
|
||||
be passed here. The PromptLayerChatOpenAI adds to optional
|
||||
parameters:
|
||||
``pl_tags``: List of strings to tag the request with.
|
||||
``return_pl_id``: If True, the PromptLayer request ID will be
|
||||
returned in the ``generation_info`` field of the
|
||||
``Generation`` object.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -28,6 +32,7 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
"""
|
||||
|
||||
pl_tags: Optional[List[str]]
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
@@ -43,7 +48,7 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
response_dict, params = super()._create_message_dicts(
|
||||
[generation.message], stop
|
||||
)
|
||||
promptlayer_api_request(
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerChatOpenAI",
|
||||
"langchain",
|
||||
message_dicts,
|
||||
@@ -53,7 +58,14 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
@@ -70,7 +82,7 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
response_dict, params = super()._create_message_dicts(
|
||||
[generation.message], stop
|
||||
)
|
||||
promptlayer_api_request(
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerChatOpenAI.async",
|
||||
"langchain",
|
||||
message_dicts,
|
||||
@@ -80,5 +92,12 @@ class PromptLayerChatOpenAI(ChatOpenAI, BaseModel):
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
@@ -1,39 +1,3 @@
|
||||
"""Interface for interacting with a document."""
|
||||
from typing import List
|
||||
from langchain.schema import Document
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Interface for interacting with a document."""
|
||||
|
||||
page_content: str
|
||||
lookup_str: str = ""
|
||||
lookup_index = 0
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def paragraphs(self) -> List[str]:
|
||||
"""Paragraphs of the page."""
|
||||
return self.page_content.split("\n\n")
|
||||
|
||||
@property
|
||||
def summary(self) -> str:
|
||||
"""Summary of the page (the first paragraph)."""
|
||||
return self.paragraphs[0]
|
||||
|
||||
def lookup(self, string: str) -> str:
|
||||
"""Lookup a term in the page, imitating cmd-F functionality."""
|
||||
if string.lower() != self.lookup_str:
|
||||
self.lookup_str = string.lower()
|
||||
self.lookup_index = 0
|
||||
else:
|
||||
self.lookup_index += 1
|
||||
lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
|
||||
if len(lookups) == 0:
|
||||
return "No Results"
|
||||
elif self.lookup_index >= len(lookups):
|
||||
return "No More Results"
|
||||
else:
|
||||
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
|
||||
return f"{result_prefix} {lookups[self.lookup_index]}"
|
||||
__all__ = ["Document"]
|
||||
|
||||
@@ -17,7 +17,7 @@ class Wikipedia(Docstore):
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import wikipedia python package. "
|
||||
"Please it install it with `pip install wikipedia`."
|
||||
"Please install it with `pip install wikipedia`."
|
||||
)
|
||||
|
||||
def search(self, search: str) -> Union[str, Document]:
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
from langchain.document_loaders.airbyte_json import AirbyteJSONLoader
|
||||
from langchain.document_loaders.azlyrics import AZLyricsLoader
|
||||
from langchain.document_loaders.blackboard import BlackboardLoader
|
||||
from langchain.document_loaders.college_confidential import CollegeConfidentialLoader
|
||||
from langchain.document_loaders.conllu import CoNLLULoader
|
||||
from langchain.document_loaders.csv import CSVLoader
|
||||
from langchain.document_loaders.csv_loader import CSVLoader
|
||||
from langchain.document_loaders.directory import DirectoryLoader
|
||||
from langchain.document_loaders.docx import UnstructuredDocxLoader
|
||||
from langchain.document_loaders.email import UnstructuredEmailLoader
|
||||
from langchain.document_loaders.evernote import EverNoteLoader
|
||||
from langchain.document_loaders.facebook_chat import FacebookChatLoader
|
||||
@@ -17,6 +17,7 @@ from langchain.document_loaders.googledrive import GoogleDriveLoader
|
||||
from langchain.document_loaders.gutenberg import GutenbergLoader
|
||||
from langchain.document_loaders.hn import HNLoader
|
||||
from langchain.document_loaders.html import UnstructuredHTMLLoader
|
||||
from langchain.document_loaders.html_bs import BSHTMLLoader
|
||||
from langchain.document_loaders.ifixit import IFixitLoader
|
||||
from langchain.document_loaders.image import UnstructuredImageLoader
|
||||
from langchain.document_loaders.imsdb import IMSDbLoader
|
||||
@@ -64,12 +65,12 @@ __all__ = [
|
||||
"ReadTheDocsLoader",
|
||||
"GoogleDriveLoader",
|
||||
"UnstructuredHTMLLoader",
|
||||
"BSHTMLLoader",
|
||||
"UnstructuredPowerPointLoader",
|
||||
"UnstructuredWordDocumentLoader",
|
||||
"UnstructuredPDFLoader",
|
||||
"UnstructuredImageLoader",
|
||||
"ObsidianLoader",
|
||||
"UnstructuredDocxLoader",
|
||||
"UnstructuredEmailLoader",
|
||||
"UnstructuredMarkdownLoader",
|
||||
"RoamLoader",
|
||||
@@ -102,4 +103,5 @@ __all__ = [
|
||||
"GoogleApiYoutubeLoader",
|
||||
"GoogleApiClient",
|
||||
"CSVLoader",
|
||||
"BlackboardLoader",
|
||||
]
|
||||
|
||||
293
langchain/document_loaders/blackboard.py
Normal file
293
langchain/document_loaders/blackboard.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Loader that loads all documents from a blackboard course."""
|
||||
import contextlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.directory import DirectoryLoader
|
||||
from langchain.document_loaders.pdf import PyPDFLoader
|
||||
from langchain.document_loaders.web_base import WebBaseLoader
|
||||
|
||||
|
||||
class BlackboardLoader(WebBaseLoader):
|
||||
"""Loader that loads all documents from a Blackboard course.
|
||||
|
||||
This loader is not compatible with all Blackboard courses. It is only
|
||||
compatible with courses that use the new Blackboard interface.
|
||||
To use this loader, you must have the BbRouter cookie. You can get this
|
||||
cookie by logging into the course and then copying the value of the
|
||||
BbRouter cookie from the browser's developer tools.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.document_loaders import BlackboardLoader
|
||||
|
||||
loader = BlackboardLoader(
|
||||
blackboard_course_url="https://blackboard.example.com/webapps/blackboard/execute/announcement?method=search&context=course_entry&course_id=_123456_1",
|
||||
bbrouter="expires:12345...",
|
||||
)
|
||||
documents = loader.load()
|
||||
|
||||
"""
|
||||
|
||||
base_url: str
|
||||
folder_path: str
|
||||
load_all_recursively: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
blackboard_course_url: str,
|
||||
bbrouter: str,
|
||||
load_all_recursively: bool = True,
|
||||
basic_auth: Optional[Tuple[str, str]] = None,
|
||||
cookies: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with blackboard course url.
|
||||
|
||||
The BbRouter cookie is required for most blackboard courses.
|
||||
|
||||
Args:
|
||||
blackboard_course_url: Blackboard course url.
|
||||
bbrouter: BbRouter cookie.
|
||||
load_all_recursively: If True, load all documents recursively.
|
||||
basic_auth: Basic auth credentials.
|
||||
cookies: Cookies.
|
||||
|
||||
Raises:
|
||||
ValueError: If blackboard course url is invalid.
|
||||
"""
|
||||
super().__init__(blackboard_course_url)
|
||||
# Get base url
|
||||
try:
|
||||
self.base_url = blackboard_course_url.split("/webapps/blackboard")[0]
|
||||
except IndexError:
|
||||
raise ValueError(
|
||||
"Invalid blackboard course url. "
|
||||
"Please provide a url that starts with "
|
||||
"https://<blackboard_url>/webapps/blackboard"
|
||||
)
|
||||
if basic_auth is not None:
|
||||
self.session.auth = basic_auth
|
||||
# Combine cookies
|
||||
if cookies is None:
|
||||
cookies = {}
|
||||
cookies.update({"BbRouter": bbrouter})
|
||||
self.session.cookies.update(cookies)
|
||||
self.load_all_recursively = load_all_recursively
|
||||
self.check_bs4()
|
||||
|
||||
def check_bs4(self) -> None:
|
||||
"""Check if BeautifulSoup4 is installed.
|
||||
|
||||
Raises:
|
||||
ImportError: If BeautifulSoup4 is not installed.
|
||||
"""
|
||||
try:
|
||||
import bs4 # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"BeautifulSoup4 is required for BlackboardLoader. "
|
||||
"Please install it with `pip install beautifulsoup4`."
|
||||
)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load data into document objects.
|
||||
|
||||
Returns:
|
||||
List of documents.
|
||||
"""
|
||||
if self.load_all_recursively:
|
||||
soup_info = self.scrape()
|
||||
self.folder_path = self._get_folder_path(soup_info)
|
||||
relative_paths = self._get_paths(soup_info)
|
||||
documents = []
|
||||
for path in relative_paths:
|
||||
url = self.base_url + path
|
||||
print(f"Fetching documents from {url}")
|
||||
soup_info = self._scrape(url)
|
||||
with contextlib.suppress(ValueError):
|
||||
documents.extend(self._get_documents(soup_info))
|
||||
return documents
|
||||
else:
|
||||
print(f"Fetching documents from {self.web_path}")
|
||||
soup_info = self.scrape()
|
||||
self.folder_path = self._get_folder_path(soup_info)
|
||||
return self._get_documents(soup_info)
|
||||
|
||||
def _get_folder_path(self, soup: Any) -> str:
|
||||
"""Get the folder path to save the documents in.
|
||||
|
||||
Args:
|
||||
soup: BeautifulSoup4 soup object.
|
||||
|
||||
Returns:
|
||||
Folder path.
|
||||
"""
|
||||
# Get the course name
|
||||
course_name = soup.find("span", {"id": "crumb_1"})
|
||||
if course_name is None:
|
||||
raise ValueError("No course name found.")
|
||||
course_name = course_name.text.strip()
|
||||
# Prepare the folder path
|
||||
course_name_clean = (
|
||||
unquote(course_name)
|
||||
.replace(" ", "_")
|
||||
.replace("/", "_")
|
||||
.replace(":", "_")
|
||||
.replace(",", "_")
|
||||
.replace("?", "_")
|
||||
.replace("'", "_")
|
||||
.replace("!", "_")
|
||||
.replace('"', "_")
|
||||
)
|
||||
# Get the folder path
|
||||
folder_path = Path(".") / course_name_clean
|
||||
return str(folder_path)
|
||||
|
||||
def _get_documents(self, soup: Any) -> List[Document]:
|
||||
"""Fetch content from page and return Documents.
|
||||
|
||||
Args:
|
||||
soup: BeautifulSoup4 soup object.
|
||||
|
||||
Returns:
|
||||
List of documents.
|
||||
"""
|
||||
attachments = self._get_attachments(soup)
|
||||
self._download_attachments(attachments)
|
||||
documents = self._load_documents()
|
||||
return documents
|
||||
|
||||
def _get_attachments(self, soup: Any) -> List[str]:
|
||||
"""Get all attachments from a page.
|
||||
|
||||
Args:
|
||||
soup: BeautifulSoup4 soup object.
|
||||
|
||||
Returns:
|
||||
List of attachments.
|
||||
"""
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
# Get content list
|
||||
content_list = soup.find("ul", {"class": "contentList"})
|
||||
if content_list is None:
|
||||
raise ValueError("No content list found.")
|
||||
content_list: BeautifulSoup # type: ignore
|
||||
# Get all attachments
|
||||
attachments = []
|
||||
for attachment in content_list.find_all("ul", {"class": "attachments"}):
|
||||
attachment: Tag # type: ignore
|
||||
for link in attachment.find_all("a"):
|
||||
link: Tag # type: ignore
|
||||
href = link.get("href")
|
||||
# Only add if href is not None and does not start with #
|
||||
if href is not None and not href.startswith("#"):
|
||||
attachments.append(href)
|
||||
return attachments
|
||||
|
||||
def _download_attachments(self, attachments: List[str]) -> None:
|
||||
"""Download all attachments.
|
||||
|
||||
Args:
|
||||
attachments: List of attachments.
|
||||
"""
|
||||
# Make sure the folder exists
|
||||
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
|
||||
# Download all attachments
|
||||
for attachment in attachments:
|
||||
self.download(attachment)
|
||||
|
||||
def _load_documents(self) -> List[Document]:
|
||||
"""Load all documents in the folder.
|
||||
|
||||
Returns:
|
||||
List of documents.
|
||||
"""
|
||||
# Create the document loader
|
||||
loader = DirectoryLoader(
|
||||
path=self.folder_path, glob="*.pdf", loader_cls=PyPDFLoader # type: ignore
|
||||
)
|
||||
# Load the documents
|
||||
documents = loader.load()
|
||||
# Return all documents
|
||||
return documents
|
||||
|
||||
def _get_paths(self, soup: Any) -> List[str]:
|
||||
"""Get all relative paths in the navbar."""
|
||||
relative_paths = []
|
||||
course_menu = soup.find("ul", {"class": "courseMenu"})
|
||||
if course_menu is None:
|
||||
raise ValueError("No course menu found.")
|
||||
for link in course_menu.find_all("a"):
|
||||
href = link.get("href")
|
||||
if href is not None and href.startswith("/"):
|
||||
relative_paths.append(href)
|
||||
return relative_paths
|
||||
|
||||
def download(self, path: str) -> None:
|
||||
"""Download a file from a url.
|
||||
|
||||
Args:
|
||||
path: Path to the file.
|
||||
"""
|
||||
# Get the file content
|
||||
response = self.session.get(self.base_url + path, allow_redirects=True)
|
||||
# Get the filename
|
||||
filename = self.parse_filename(response.url)
|
||||
# Write the file to disk
|
||||
with open(Path(self.folder_path) / filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
def parse_filename(self, url: str) -> str:
|
||||
"""Parse the filename from a url.
|
||||
|
||||
Args:
|
||||
url: Url to parse the filename from.
|
||||
|
||||
Returns:
|
||||
The filename.
|
||||
"""
|
||||
if (url_path := Path(url)) and url_path.suffix == ".pdf":
|
||||
return url_path.name
|
||||
else:
|
||||
return self._parse_filename_from_url(url)
|
||||
|
||||
def _parse_filename_from_url(self, url: str) -> str:
|
||||
"""Parse the filename from a url.
|
||||
|
||||
Args:
|
||||
url: Url to parse the filename from.
|
||||
|
||||
Returns:
|
||||
The filename.
|
||||
|
||||
Raises:
|
||||
ValueError: If the filename could not be parsed.
|
||||
"""
|
||||
filename_matches = re.search(r"filename%2A%3DUTF-8%27%27(.+)", url)
|
||||
if filename_matches:
|
||||
filename = filename_matches.group(1)
|
||||
else:
|
||||
raise ValueError(f"Could not parse filename from {url}")
|
||||
if ".pdf" not in filename:
|
||||
raise ValueError(f"Incorrect file type: {filename}")
|
||||
filename = filename.split(".pdf")[0] + ".pdf"
|
||||
filename = unquote(filename)
|
||||
filename = filename.replace("%20", " ")
|
||||
return filename
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = BlackboardLoader(
|
||||
"https://<YOUR BLACKBOARD URL"
|
||||
" HERE>/webapps/blackboard/content/listContent.jsp?course_id=_<YOUR COURSE ID"
|
||||
" HERE>_1&content_id=_<YOUR CONTENT ID HERE>_1&mode=reset",
|
||||
"<YOUR BBROUTER COOKIE HERE>",
|
||||
load_all_recursively=True,
|
||||
)
|
||||
documents = loader.load()
|
||||
print(f"Loaded {len(documents)} pages of PDFs from {loader.web_path}")
|
||||
@@ -1,47 +0,0 @@
|
||||
from csv import DictReader
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class CSVLoader(BaseLoader):
|
||||
"""Loads a CSV file into a list of documents.
|
||||
|
||||
Each document represents one row of the CSV file. Every row is converted into a
|
||||
key/value pair and outputted to a new line in the document's page_content.
|
||||
|
||||
Output Example:
|
||||
.. code-block:: txt
|
||||
|
||||
column1: value1
|
||||
column2: value2
|
||||
column3: value3
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, csv_args: Optional[Dict] = None):
|
||||
self.file_path = file_path
|
||||
if csv_args is None:
|
||||
self.csv_args = {
|
||||
"delimiter": ",",
|
||||
"quotechar": '"',
|
||||
}
|
||||
else:
|
||||
self.csv_args = csv_args
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
docs = []
|
||||
|
||||
with open(self.file_path, newline="") as csvfile:
|
||||
csv = DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv):
|
||||
docs.append(
|
||||
Document(
|
||||
page_content="\n".join(
|
||||
f"{k.strip()}: {v.strip()}" for k, v in row.items()
|
||||
),
|
||||
metadata={"source": self.file_path, "row": i},
|
||||
)
|
||||
)
|
||||
|
||||
return docs
|
||||
62
langchain/document_loaders/csv_loader.py
Normal file
62
langchain/document_loaders/csv_loader.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from csv import DictReader
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class CSVLoader(BaseLoader):
|
||||
"""Loads a CSV file into a list of documents.
|
||||
|
||||
Each document represents one row of the CSV file. Every row is converted into a
|
||||
key/value pair and outputted to a new line in the document's page_content.
|
||||
|
||||
The source for each document loaded from csv is set to the value of the
|
||||
`file_path` argument for all doucments by default.
|
||||
You can override this by setting the `source_column` argument to the
|
||||
name of a column in the CSV file.
|
||||
The source of each document will then be set to the value of the column
|
||||
with the name specified in `source_column`.
|
||||
|
||||
Output Example:
|
||||
.. code-block:: txt
|
||||
|
||||
column1: value1
|
||||
column2: value2
|
||||
column3: value3
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[Dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.source_column = source_column
|
||||
self.encoding = encoding
|
||||
if csv_args is None:
|
||||
self.csv_args = {
|
||||
"delimiter": ",",
|
||||
"quotechar": '"',
|
||||
}
|
||||
else:
|
||||
self.csv_args = csv_args
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
docs = []
|
||||
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
csv = DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
if self.source_column is not None:
|
||||
source = row[self.source_column]
|
||||
else:
|
||||
source = self.file_path
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -5,10 +5,13 @@ from typing import List, Type, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.html_bs import BSHTMLLoader
|
||||
from langchain.document_loaders.text import TextLoader
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
|
||||
FILE_LOADER_TYPE = Union[Type[UnstructuredFileLoader], Type[TextLoader]]
|
||||
FILE_LOADER_TYPE = Union[
|
||||
Type[UnstructuredFileLoader], Type[TextLoader], Type[BSHTMLLoader]
|
||||
]
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Loader that loads Microsoft Word files."""
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
|
||||
|
||||
class UnstructuredDocxLoader(UnstructuredFileLoader):
|
||||
"""Loader that uses unstructured to load Microsoft Word files."""
|
||||
|
||||
def _get_elements(self) -> List:
|
||||
from unstructured.partition.docx import partition_docx
|
||||
|
||||
return partition_docx(filename=self.file_path)
|
||||
59
langchain/document_loaders/figma.py
Normal file
59
langchain/document_loaders/figma.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Loader that loads Figma files json dump."""
|
||||
import json
|
||||
import urllib.request
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def _stringify_value(val: Any) -> str:
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
elif isinstance(val, dict):
|
||||
return "\n" + _stringify_dict(val)
|
||||
elif isinstance(val, list):
|
||||
return "\n".join(_stringify_value(v) for v in val)
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
|
||||
def _stringify_dict(data: dict) -> str:
|
||||
text = ""
|
||||
for key, value in data.items():
|
||||
text += key + ": " + _stringify_value(data[key]) + "\n"
|
||||
return text
|
||||
|
||||
|
||||
class FigmaFileLoader(BaseLoader):
|
||||
"""Loader that loads Figma file json."""
|
||||
|
||||
def __init__(self, access_token: str, ids: str, key: str):
|
||||
"""Initialize with access token, ids, and key."""
|
||||
self.access_token = access_token
|
||||
self.ids = ids
|
||||
self.key = key
|
||||
|
||||
def _construct_figma_api_url(self) -> str:
|
||||
api_url = "https://api.figma.com/v1/files/%s/nodes?ids=%s" % (
|
||||
self.key,
|
||||
self.ids,
|
||||
)
|
||||
return api_url
|
||||
|
||||
def _get_figma_file(self) -> Any:
|
||||
"""Get Figma file from Figma REST API."""
|
||||
headers = {"X-Figma-Token": self.access_token}
|
||||
request = urllib.request.Request(
|
||||
self._construct_figma_api_url(), headers=headers
|
||||
)
|
||||
with urllib.request.urlopen(request) as response:
|
||||
json_data = json.loads(response.read().decode())
|
||||
return json_data
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load file"""
|
||||
data = self._get_figma_file()
|
||||
text = _stringify_dict(data)
|
||||
metadata = {"source": self._construct_figma_api_url()}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
@@ -6,7 +6,8 @@
|
||||
# https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com
|
||||
# 3. Authorize credentials for desktop app:
|
||||
# https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application # noqa: E501
|
||||
|
||||
# 4. For service accounts visit
|
||||
# https://cloud.google.com/iam/docs/service-accounts-create
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -22,6 +23,7 @@ SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
"""Loader that loads Google Docs from Google Drive."""
|
||||
|
||||
service_account_key: Path = Path.home() / ".credentials" / "keys.json"
|
||||
credentials_path: Path = Path.home() / ".credentials" / "credentials.json"
|
||||
token_path: Path = Path.home() / ".credentials" / "token.json"
|
||||
folder_id: Optional[str] = None
|
||||
@@ -60,6 +62,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
# Adapted from https://developers.google.com/drive/api/v3/quickstart/python
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
except ImportError:
|
||||
@@ -72,6 +75,11 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
||||
)
|
||||
|
||||
creds = None
|
||||
if self.service_account_key.exists():
|
||||
return service_account.Credentials.from_service_account_file(
|
||||
str(self.service_account_key), scopes=SCOPES
|
||||
)
|
||||
|
||||
if self.token_path.exists():
|
||||
creds = Credentials.from_authorized_user_file(str(self.token_path), SCOPES)
|
||||
|
||||
|
||||
36
langchain/document_loaders/html_bs.py
Normal file
36
langchain/document_loaders/html_bs.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Loader that uses bs4 to load HTML files, enriching metadata with page title."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class BSHTMLLoader(BaseLoader):
|
||||
"""Loader that uses beautiful soup to parse HTML files."""
|
||||
|
||||
def __init__(self, file_path: str) -> None:
|
||||
self.file_path = file_path
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load HTML document into document objects."""
|
||||
with open(self.file_path, "r") as f:
|
||||
soup = BeautifulSoup(f, features="lxml")
|
||||
|
||||
text = soup.get_text()
|
||||
|
||||
if soup.title:
|
||||
title = str(soup.title.string)
|
||||
else:
|
||||
title = ""
|
||||
|
||||
metadata: Dict[str, Union[str, None]] = {
|
||||
"source": self.file_path,
|
||||
"title": title,
|
||||
}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
@@ -9,16 +9,17 @@ from langchain.document_loaders.base import BaseLoader
|
||||
class ObsidianLoader(BaseLoader):
|
||||
"""Loader that loads Obsidian files from disk."""
|
||||
|
||||
def __init__(self, path: str):
|
||||
def __init__(self, path: str, encoding: str = "UTF-8"):
|
||||
"""Initialize with path."""
|
||||
self.file_path = path
|
||||
self.encoding = encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load documents."""
|
||||
ps = list(Path(self.file_path).glob("**/*.md"))
|
||||
docs = []
|
||||
for p in ps:
|
||||
with open(p) as f:
|
||||
with open(p, encoding=self.encoding) as f:
|
||||
text = f.read()
|
||||
metadata = {"source": str(p)}
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
@@ -114,7 +114,7 @@ class YoutubeLoader(BaseLoader):
|
||||
def load(self) -> List[Document]:
|
||||
"""Load documents."""
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api import NoTranscriptFound, YouTubeTranscriptApi
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import youtube_transcript_api python package. "
|
||||
@@ -129,9 +129,15 @@ class YoutubeLoader(BaseLoader):
|
||||
video_info = self._get_video_info()
|
||||
metadata.update(video_info)
|
||||
|
||||
transcript_pieces = YouTubeTranscriptApi.get_transcript(
|
||||
self.video_id, languages=[self.language]
|
||||
)
|
||||
transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)
|
||||
try:
|
||||
transcript = transcript_list.find_transcript([self.language])
|
||||
except NoTranscriptFound:
|
||||
en_transcript = transcript_list.find_transcript(["en"])
|
||||
transcript = en_transcript.translate(self.language)
|
||||
|
||||
transcript_pieces = transcript.fetch()
|
||||
|
||||
transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
|
||||
|
||||
return [Document(page_content=transcript, metadata=metadata)]
|
||||
@@ -177,7 +183,7 @@ class GoogleApiYoutubeLoader(BaseLoader):
|
||||
As the service needs a google_api_client, you first have to initialize
|
||||
the GoogleApiClient.
|
||||
|
||||
Additonali you have to either provide a channel name or a list of videoids
|
||||
Additionally you have to either provide a channel name or a list of videoids
|
||||
"https://developers.google.com/docs/api/quickstart/python"
|
||||
|
||||
|
||||
@@ -233,9 +239,16 @@ class GoogleApiYoutubeLoader(BaseLoader):
|
||||
return values
|
||||
|
||||
def _get_transcripe_for_video_id(self, video_id: str) -> str:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api import NoTranscriptFound, YouTubeTranscriptApi
|
||||
|
||||
transcript_pieces = YouTubeTranscriptApi.get_transcript(video_id)
|
||||
transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_ids)
|
||||
try:
|
||||
transcript = transcript_list.find_transcript([self.captions_language])
|
||||
except NoTranscriptFound:
|
||||
en_transcript = transcript_list.find_transcript(["en"])
|
||||
transcript = en_transcript.translate(self.captions_language)
|
||||
|
||||
transcript_pieces = transcript.fetch()
|
||||
return " ".join([t["text"].strip(" ") for t in transcript_pieces])
|
||||
|
||||
def _get_document_for_video_id(self, video_id: str, **kwargs: Any) -> Document:
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain.embeddings.huggingface import (
|
||||
)
|
||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
|
||||
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
|
||||
from langchain.embeddings.self_hosted_hugging_face import (
|
||||
SelfHostedHuggingFaceEmbeddings,
|
||||
@@ -25,6 +26,7 @@ __all__ = [
|
||||
"CohereEmbeddings",
|
||||
"HuggingFaceHubEmbeddings",
|
||||
"TensorflowHubEmbeddings",
|
||||
"SagemakerEndpointEmbeddings",
|
||||
"HuggingFaceInstructEmbeddings",
|
||||
"SelfHostedEmbeddings",
|
||||
"SelfHostedHuggingFaceEmbeddings",
|
||||
|
||||
@@ -65,9 +65,35 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
openai = OpenAIEmbeddings(openai_api_key="my-api-key")
|
||||
|
||||
In order to use the library with Microsoft Azure endpoints, you need to set
|
||||
the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and optionally and
|
||||
API_VERSION.
|
||||
The OPENAI_API_TYPE must be set to 'azure' and the others correspond to
|
||||
the properties of your endpoint.
|
||||
In addition, the deployment name must be passed as the model parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
|
||||
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings(model="your-embeddings-deployment-name")
|
||||
text = "This is a test query."
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str = "text-embedding-ada-002"
|
||||
|
||||
# TODO: deprecate these two in favor of model
|
||||
# https://community.openai.com/t/api-update-engines-models/18597
|
||||
# https://github.com/openai/openai-python/issues/132
|
||||
document_model_name: str = "text-embedding-ada-002"
|
||||
query_model_name: str = "text-embedding-ada-002"
|
||||
embedding_ctx_length: int = -1
|
||||
@@ -85,6 +111,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# TODO: deprecate this
|
||||
@root_validator(pre=True)
|
||||
def get_model_names(cls, values: Dict) -> Dict:
|
||||
# model_name is for first generation, and model is for second generation.
|
||||
# Both are not allowed together.
|
||||
if "model_name" in values and "model" in values:
|
||||
raise ValueError(
|
||||
"Both `model_name` and `model` were provided, "
|
||||
"but only one should be."
|
||||
)
|
||||
|
||||
"""Get model names from just old model name."""
|
||||
if "model_name" in values:
|
||||
if "document_model_name" in values:
|
||||
@@ -100,6 +134,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
model_name = values.pop("model_name")
|
||||
values["document_model_name"] = f"text-search-{model_name}-doc-001"
|
||||
values["query_model_name"] = f"text-search-{model_name}-query-001"
|
||||
|
||||
# Set document/query model names from model parameter.
|
||||
if "model" in values:
|
||||
if "document_model_name" in values:
|
||||
raise ValueError(
|
||||
"Both `model` and `document_model_name` were provided, "
|
||||
"but only one should be."
|
||||
)
|
||||
if "query_model_name" in values:
|
||||
raise ValueError(
|
||||
"Both `model` and `query_model_name` were provided, "
|
||||
"but only one should be."
|
||||
)
|
||||
model = values.get("model")
|
||||
values["document_model_name"] = model
|
||||
values["query_model_name"] = model
|
||||
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
|
||||
194
langchain/embeddings/sagemaker_endpoint.py
Normal file
194
langchain/embeddings/sagemaker_endpoint.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
|
||||
|
||||
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
Sagemaker model & the region where it is deployed.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Sagemaker endpoint.
|
||||
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import SagemakerEndpointEmbeddings
|
||||
endpoint_name = (
|
||||
"my-endpoint-name"
|
||||
)
|
||||
region_name = (
|
||||
"us-west-2"
|
||||
)
|
||||
credentials_profile_name = (
|
||||
"default"
|
||||
)
|
||||
se = SagemakerEndpointEmbeddings(
|
||||
endpoint_name=endpoint_name,
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name
|
||||
)
|
||||
"""
|
||||
client: Any #: :meta private:
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""The name of the endpoint from the deployed Sagemaker model.
|
||||
Must be unique within an AWS Region."""
|
||||
|
||||
region_name: str = ""
|
||||
"""The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
|
||||
|
||||
credentials_profile_name: Optional[str] = None
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: ContentHandlerBase
|
||||
"""The content handler class that provides an input and
|
||||
output transform functions to handle formats between LLM
|
||||
and the endpoint.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
endpoint_kwargs: Optional[Dict] = None
|
||||
"""Optional attributes passed to the invoke_endpoint
|
||||
function. See `boto3`_. docs for more info.
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
||||
try:
|
||||
if values["credentials_profile_name"] is not None:
|
||||
session = boto3.Session(
|
||||
profile_name=values["credentials_profile_name"]
|
||||
)
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
values["client"] = session.client(
|
||||
"sagemaker-runtime", region_name=values["region_name"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please it install it with `pip install boto3`."
|
||||
)
|
||||
return values
|
||||
|
||||
def _embedding_func(self, texts: List[str]) -> List[float]:
|
||||
"""Call out to SageMaker Inference embedding endpoint."""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_endpoint_kwargs = self.endpoint_kwargs or {}
|
||||
|
||||
body = self.content_handler.transform_input(texts, _model_kwargs)
|
||||
content_type = self.content_handler.content_type
|
||||
accepts = self.content_handler.accepts
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = self.client.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=body,
|
||||
ContentType=content_type,
|
||||
Accept=accepts,
|
||||
**_endpoint_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
return self.content_handler.transform_output(response["Body"])
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int = 64
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a SageMaker Inference Endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size defines how many input texts will
|
||||
be grouped together as request. If None, will use the
|
||||
chunk size specified by the class.
|
||||
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
results = []
|
||||
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
|
||||
for i in range(0, len(texts), _chunk_size):
|
||||
response = self._embedding_func(texts[i : i + _chunk_size])
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a SageMaker inference endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self._embedding_func([text])
|
||||
@@ -2,8 +2,8 @@ from typing import Any, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
|
||||
from langchain.chains.retrieval_qa.base import RetrievalQA
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
@@ -32,7 +32,9 @@ class VectorStoreIndexWrapper(BaseModel):
|
||||
def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
|
||||
"""Query the vectorstore."""
|
||||
llm = llm or OpenAI(temperature=0)
|
||||
chain = VectorDBQA.from_chain_type(llm, vectorstore=self.vectorstore, **kwargs)
|
||||
chain = RetrievalQA.from_chain_type(
|
||||
llm, retriever=self.vectorstore.as_retriever(), **kwargs
|
||||
)
|
||||
return chain.run(question)
|
||||
|
||||
def query_with_sources(
|
||||
@@ -40,8 +42,8 @@ class VectorStoreIndexWrapper(BaseModel):
|
||||
) -> dict:
|
||||
"""Query the vectorstore and get back sources."""
|
||||
llm = llm or OpenAI(temperature=0)
|
||||
chain = VectorDBQAWithSourcesChain.from_chain_type(
|
||||
llm, vectorstore=self.vectorstore, **kwargs
|
||||
chain = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm, retriever=self.vectorstore.as_retriever(), **kwargs
|
||||
)
|
||||
return chain({chain.question_key: question})
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
|
||||
from langchain.llms.petals import Petals
|
||||
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
|
||||
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
||||
from langchain.llms.self_hosted import SelfHostedPipeline
|
||||
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
|
||||
from langchain.llms.stochasticai import StochasticAI
|
||||
@@ -40,6 +41,7 @@ __all__ = [
|
||||
"Petals",
|
||||
"HuggingFaceEndpoint",
|
||||
"HuggingFaceHub",
|
||||
"SagemakerEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
"AI21",
|
||||
"AzureOpenAI",
|
||||
@@ -64,6 +66,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"huggingface_hub": HuggingFaceHub,
|
||||
"huggingface_endpoint": HuggingFaceEndpoint,
|
||||
"modal": Modal,
|
||||
"sagemaker_endpoint": SagemakerEndpoint,
|
||||
"nlpcloud": NLPCloud,
|
||||
"openai": OpenAI,
|
||||
"petals": Petals,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -163,7 +164,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
"""Initialize the OpenAI object."""
|
||||
if data.get("model_name", "").startswith("gpt-3.5-turbo"):
|
||||
model_name = data.get("model_name", "")
|
||||
if model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4"):
|
||||
warnings.warn(
|
||||
"You are trying to use a chat model. This way of initializing it is "
|
||||
"no longer supported. Instead, please use: "
|
||||
"`from langchain.chat_models import ChatOpenAI`"
|
||||
)
|
||||
return OpenAIChat(**data)
|
||||
return super().__new__(cls)
|
||||
|
||||
@@ -362,9 +369,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
||||
for choice in sub_choices
|
||||
]
|
||||
)
|
||||
return LLMResult(
|
||||
generations=generations, llm_output={"token_usage": token_usage}
|
||||
)
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
|
||||
"""Call OpenAI with streaming flag and return the resulting generator.
|
||||
@@ -599,6 +605,11 @@ class OpenAIChat(BaseLLM, BaseModel):
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
warnings.warn(
|
||||
"You are trying to use a chat model. This way of initializing it is "
|
||||
"no longer supported. Instead, please use: "
|
||||
"`from langchain.chat_models import ChatOpenAI`"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
@@ -643,11 +654,15 @@ class OpenAIChat(BaseLLM, BaseModel):
|
||||
)
|
||||
else:
|
||||
full_response = completion_with_retry(self, messages=messages, **params)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output={"token_usage": full_response["usage"]},
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
@@ -679,11 +694,15 @@ class OpenAIChat(BaseLLM, BaseModel):
|
||||
full_response = await acompletion_with_retry(
|
||||
self, messages=messages, **params
|
||||
)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output={"token_usage": full_response["usage"]},
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -17,8 +17,12 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
promptlayer key respectively.
|
||||
|
||||
All parameters that can be passed to the OpenAI LLM can also
|
||||
be passed here. The PromptLayerOpenAI LLM adds an extra
|
||||
``pl_tags`` parameter that can be used to tag the request.
|
||||
be passed here. The PromptLayerOpenAI LLM adds two optional
|
||||
parameters:
|
||||
``pl_tags``: List of strings to tag the request with.
|
||||
``return_pl_id``: If True, the PromptLayer request ID will be
|
||||
returned in the ``generation_info`` field of the
|
||||
``Generation`` object.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -28,6 +32,7 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
"""
|
||||
|
||||
pl_tags: Optional[List[str]]
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
@@ -40,11 +45,12 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
generation = generated_responses.generations[i][0]
|
||||
resp = {
|
||||
"text": generated_responses.generations[i][0].text,
|
||||
"text": generation.text,
|
||||
"llm_output": generated_responses.llm_output,
|
||||
}
|
||||
promptlayer_api_request(
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerOpenAI",
|
||||
"langchain",
|
||||
[prompt],
|
||||
@@ -54,7 +60,14 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
@@ -67,11 +80,12 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
generation = generated_responses.generations[i][0]
|
||||
resp = {
|
||||
"text": generated_responses.generations[i][0].text,
|
||||
"text": generation.text,
|
||||
"llm_output": generated_responses.llm_output,
|
||||
}
|
||||
promptlayer_api_request(
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerOpenAI.async",
|
||||
"langchain",
|
||||
[prompt],
|
||||
@@ -81,7 +95,14 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
|
||||
@@ -94,8 +115,12 @@ class PromptLayerOpenAIChat(OpenAIChat, BaseModel):
|
||||
promptlayer key respectively.
|
||||
|
||||
All parameters that can be passed to the OpenAIChat LLM can also
|
||||
be passed here. The PromptLayerOpenAIChat LLM adds an extra
|
||||
``pl_tags`` parameter that can be used to tag the request.
|
||||
be passed here. The PromptLayerOpenAIChat adds two optional
|
||||
parameters:
|
||||
``pl_tags``: List of strings to tag the request with.
|
||||
``return_pl_id``: If True, the PromptLayer request ID will be
|
||||
returned in the ``generation_info`` field of the
|
||||
``Generation`` object.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -105,6 +130,7 @@ class PromptLayerOpenAIChat(OpenAIChat, BaseModel):
|
||||
"""
|
||||
|
||||
pl_tags: Optional[List[str]]
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
@@ -117,11 +143,12 @@ class PromptLayerOpenAIChat(OpenAIChat, BaseModel):
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
generation = generated_responses.generations[i][0]
|
||||
resp = {
|
||||
"text": generated_responses.generations[i][0].text,
|
||||
"text": generation.text,
|
||||
"llm_output": generated_responses.llm_output,
|
||||
}
|
||||
promptlayer_api_request(
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerOpenAIChat",
|
||||
"langchain",
|
||||
[prompt],
|
||||
@@ -131,7 +158,14 @@ class PromptLayerOpenAIChat(OpenAIChat, BaseModel):
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
@@ -144,16 +178,27 @@ class PromptLayerOpenAIChat(OpenAIChat, BaseModel):
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
resp = generated_responses.generations[i]
|
||||
promptlayer_api_request(
|
||||
generation = generated_responses.generations[i][0]
|
||||
resp = {
|
||||
"text": generation.text,
|
||||
"llm_output": generated_responses.llm_output,
|
||||
}
|
||||
pl_request_id = promptlayer_api_request(
|
||||
"langchain.PromptLayerOpenAIChat.async",
|
||||
"langchain",
|
||||
[prompt],
|
||||
self._identifying_params,
|
||||
self.pl_tags,
|
||||
resp[0].text,
|
||||
resp,
|
||||
request_start_time,
|
||||
request_end_time,
|
||||
get_api_key(),
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
|
||||
if self.return_pl_id:
|
||||
if generation.generation_info is None or not isinstance(
|
||||
generation.generation_info, dict
|
||||
):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
237
langchain/llms/sagemaker_endpoint.py
Normal file
237
langchain/llms/sagemaker_endpoint.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
class ContentHandlerBase(ABC):
|
||||
"""A handler class to transform input from LLM to a
|
||||
format that SageMaker endpoint expects. Similarily,
|
||||
the class also handles transforming output from the
|
||||
SageMaker endpoint to a format that LLM class expects.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
|
||||
content_type: Optional[str] = "text/plain"
|
||||
"""The MIME type of the input data passed to endpoint"""
|
||||
|
||||
accepts: Optional[str] = "text/plain"
|
||||
"""The MIME type of the response data returned from endpoint"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_input(
|
||||
self, prompt: Union[str, List[str]], model_kwargs: Dict
|
||||
) -> bytes:
|
||||
"""Transforms the input to a format that model can accept
|
||||
as the request Body. Should return bytes or seekable file
|
||||
like object in the format specified in the content_type
|
||||
request header.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_output(self, output: bytes) -> Any:
|
||||
"""Transforms the output from the model to string that
|
||||
the LLM class expects.
|
||||
"""
|
||||
|
||||
|
||||
class SagemakerEndpoint(LLM, BaseModel):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
Sagemaker model & the region where it is deployed.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Sagemaker endpoint.
|
||||
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import SagemakerEndpoint
|
||||
endpoint_name = (
|
||||
"my-endpoint-name"
|
||||
)
|
||||
region_name = (
|
||||
"us-west-2"
|
||||
)
|
||||
credentials_profile_name = (
|
||||
"default"
|
||||
)
|
||||
se = SagemakerEndpoint(
|
||||
endpoint_name=endpoint_name,
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name
|
||||
)
|
||||
"""
|
||||
client: Any #: :meta private:
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""The name of the endpoint from the deployed Sagemaker model.
|
||||
Must be unique within an AWS Region."""
|
||||
|
||||
region_name: str = ""
|
||||
"""The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
|
||||
|
||||
credentials_profile_name: Optional[str] = None
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: ContentHandlerBase
|
||||
"""The content handler class that provides an input and
|
||||
output transform functions to handle formats between LLM
|
||||
and the endpoint.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
endpoint_kwargs: Optional[Dict] = None
|
||||
"""Optional attributes passed to the invoke_endpoint
|
||||
function. See `boto3`_. docs for more info.
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
||||
try:
|
||||
if values["credentials_profile_name"] is not None:
|
||||
session = boto3.Session(
|
||||
profile_name=values["credentials_profile_name"]
|
||||
)
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
values["client"] = session.client(
|
||||
"sagemaker-runtime", region_name=values["region_name"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please it install it with `pip install boto3`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"endpoint_name": self.endpoint_name},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "sagemaker_endpoint"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Call out to Sagemaker inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = se("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_endpoint_kwargs = self.endpoint_kwargs or {}
|
||||
|
||||
body = self.content_handler.transform_input(prompt, _model_kwargs)
|
||||
content_type = self.content_handler.content_type
|
||||
accepts = self.content_handler.accepts
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = self.client.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=body,
|
||||
ContentType=content_type,
|
||||
Accept=accepts,
|
||||
**_endpoint_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
text = self.content_handler.transform_output(response["Body"])
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to the sagemaker endpoint.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
@@ -11,6 +11,7 @@ from langchain.memory.readonly import ReadOnlySharedMemory
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
from langchain.memory.summary import ConversationSummaryMemory
|
||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"CombinedMemory",
|
||||
@@ -24,4 +25,5 @@ __all__ = [
|
||||
"ChatMessageHistory",
|
||||
"ConversationStringBufferMemory",
|
||||
"ReadOnlySharedMemory",
|
||||
"ConversationTokenBufferMemory",
|
||||
]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user