Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-19 21:35:33 +00:00
Compare commits
48 Commits
| SHA1 |
|---|
| d1bcc58beb |
| 6d30acffcb |
| ba622764cb |
| ec8247ec59 |
| d84a3bcf7a |
| a15afc102c |
| cc33bde74f |
| 2aeb8e7dbc |
| 0f6ef048d2 |
| fe941cb54a |
| 9187d2f3a9 |
| e9877ea8b1 |
| f9771700e4 |
| 87802c86d9 |
| 05eec99269 |
| be68f6f8ce |
| b32cc01c9f |
| afc292e58d |
| 3e30a5d967 |
| 9d1b3bab76 |
| 408c8d0178 |
| d89e10d361 |
| 1742db0c30 |
| e1b801be36 |
| 1da99ce013 |
| dd36adc0f4 |
| ef4c7b54ef |
| 068142fce2 |
| c289cc891a |
| 2518e6c95b |
| 9fbe346860 |
| fa1bb873e2 |
| b7e1c54947 |
| 2da1aab50b |
| 1c81883d42 |
| 3364e5818b |
| f1e1ac2a01 |
| db8b13df4c |
| 5e5b30b74f |
| 2acf109c4b |
| 48381f1f78 |
| b1de927f1b |
| 4e5d78579b |
| 73da193a4b |
| ba256b23f2 |
| f6fdabd20b |
| dbe1d029ec |
| 082976d8d0 |
@@ -0,0 +1,9 @@
# Caching

LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:

- It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.
- It can speed up your application by reducing the number of API calls you make to the LLM provider.

import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"

<CachingChat/>
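The imported snippet walks through turning the cache on; as a rough sketch (assuming the in-memory cache that ships with LangChain), it looks like this:

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

# Enable a process-wide cache; identical prompts are served from memory instead of the API.
langchain.llm_cache = InMemoryCache()

llm = ChatOpenAI()
llm.predict("Tell me a joke")  # first call goes to the provider
llm.predict("Tell me a joke")  # repeated call is answered from the cache, so it is faster and free
```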
docs/extras/ecosystem/integrations/amazon_api_gateway.mdx (new file, 73 lines)
@@ -0,0 +1,73 @@
# Amazon API Gateway

[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any scale. APIs act as the "front door" for applications to access data, business logic, or functionality from your backend services. Using API Gateway, you can create RESTful APIs and WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.

API Gateway handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization and access control, throttling, monitoring, and API version management. API Gateway has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data transferred out and, with the API Gateway tiered pricing model, you can reduce your cost as your API usage scales.

## LLM

See a [usage example](/docs/modules/model_io/models/llms/integrations/amazon_api_gateway_example.html).

```python
from langchain.llms import AmazonAPIGateway

api_url = "https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF"
llm = AmazonAPIGateway(api_url=api_url)

# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart
parameters = {
    "max_new_tokens": 100,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": False,
    "return_full_text": True,
    "temperature": 0.2,
}

prompt = "what day comes after Friday?"
llm.model_kwargs = parameters
llm(prompt)

>>> 'what day comes after Friday?\nSaturday'
```

## Agent

```python
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.llms import AmazonAPIGateway

api_url = "https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF"
llm = AmazonAPIGateway(api_url=api_url)

parameters = {
    "max_new_tokens": 50,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.25,
    "do_sample": False,
    "temperature": 0.1,
}

llm.model_kwargs = parameters

# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.
tools = load_tools(["python_repl", "llm-math"], llm=llm)

# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

# Now let's test it out!
agent.run("""
Write a Python script that prints "Hello, world!"
""")

>>> 'Hello, world!'
```
@@ -25,7 +25,7 @@ There are two ways to set up parameters for myscale index.

1. Environment Variables

Before you run the app, please set the environment variables with `export`:
`export MYSCALE_URL='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`
`export MYSCALE_HOST='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`

You can easily find your account, password and other info on our SaaS. For details, please refer to [this document](https://docs.myscale.com/en/cluster-management/).
Every attribute under `MyScaleSettings` can be set with the `MYSCALE_` prefix and is case insensitive.
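Alternatively, you can build the settings object in code and pass it to the vector store directly. A minimal sketch, assuming `MyScaleSettings` exposes fields mirroring the `MYSCALE_*` variables above and that `MyScale` accepts it through a `config` argument:

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import MyScale, MyScaleSettings

# Placeholder values; the field names are assumed to match the MYSCALE_* environment variables.
config = MyScaleSettings(
    host="<your-endpoints-url>",
    port=8443,  # use your endpoint's port
    username="<your-username>",
    password="<your-password>",
)
vectorstore = MyScale(OpenAIEmbeddings(), config=config)
```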
@@ -67,4 +67,4 @@ llm("What is the difference between a duck and a goose? And why there are so man

### Usage

For a more detailed walkthrough of the OpenLLM Wrapper, see the
[example notebook](../modules/models/llms/integrations/openllm.ipynb)
[example notebook](/docs/modules/model_io/models/llms/integrations/openllm.html)
docs/extras/modules/agents/toolkits/office365.ipynb (new file, 238 lines)
@@ -0,0 +1,238 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Office365 Toolkit\n",
|
||||
"\n",
|
||||
"This notebook walks through connecting LangChain to Office365 email and calendar.\n",
|
||||
"\n",
|
||||
"To use this toolkit, you will need to set up your credentials explained in the [Microsoft Graph authentication and authorization overview](https://learn.microsoft.com/en-us/graph/auth/). Once you've received a CLIENT_ID and CLIENT_SECRET, you can input them as environmental variables below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --upgrade O365 > /dev/null\n",
|
||||
"!pip install beautifulsoup4 > /dev/null # This is optional but is useful for parsing HTML messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Assign Environmental Variables\n",
|
||||
"\n",
|
||||
"The toolkit will read the CLIENT_ID and CLIENT_SECRET environmental variables to authenticate the user so you need to set them here. You will also need to set your OPENAI_API_KEY to use the agent later."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set environmental variables here"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create the Toolkit and Get Tools\n",
|
||||
"\n",
|
||||
"To start, you need to create the toolkit, so you can access its tools later."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[O365SearchEvents(name='events_search', description=\" Use this tool to search for the user's calendar events. The input must be the start and end datetimes for the search query. The output is a JSON list of all the events in the user's calendar between the start and end times. You can assume that the user can not schedule any meeting over existing meetings, and that the user is busy during meetings. Any times without events are free for the user. \", args_schema=<class 'langchain.tools.office365.events_search.SearchEventsInput'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
|
||||
" O365CreateDraftMessage(name='create_email_draft', description='Use this tool to create a draft email with the provided message fields.', args_schema=<class 'langchain.tools.office365.create_draft_message.CreateDraftMessageSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
|
||||
" O365SearchEmails(name='messages_search', description='Use this tool to search for email messages. The input must be a valid Microsoft Graph v1.0 $search query. The output is a JSON list of the requested resource.', args_schema=<class 'langchain.tools.office365.messages_search.SearchEmailsInput'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
|
||||
" O365SendEvent(name='send_event', description='Use this tool to create and send an event with the provided event fields.', args_schema=<class 'langchain.tools.office365.send_event.SendEventSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
|
||||
" O365SendMessage(name='send_email', description='Use this tool to send an email with the provided message fields.', args_schema=<class 'langchain.tools.office365.send_message.SendMessageSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.agents.agent_toolkits import O365Toolkit\n",
|
||||
"\n",
|
||||
"toolkit = O365Toolkit()\n",
|
||||
"tools = toolkit.get_tools()\n",
|
||||
"tools"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use within an Agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import OpenAI\n",
|
||||
"from langchain.agents import initialize_agent, AgentType"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools=toolkit.get_tools(),\n",
|
||||
" llm=llm,\n",
|
||||
" verbose=False,\n",
|
||||
" agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The draft email was created correctly.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Create an email draft for me to edit of a letter from the perspective of a sentient parrot\"\n",
|
||||
" \" who is looking to collaborate on some research with her\"\n",
|
||||
" \" estranged friend, a cat. Under no circumstances may you send the message, however.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"I found one draft in your drafts folder about collaboration. It was sent on 2023-06-16T18:22:17+0000 and the subject was 'Collaboration Request'.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Could you search in my drafts folder and let me know if any of them are about collaboration?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/vscode/langchain-py-env/lib/python3.11/site-packages/O365/utils/windows_tz.py:639: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
|
||||
" iana_tz.zone if isinstance(iana_tz, tzinfo) else iana_tz)\n",
|
||||
"/home/vscode/langchain-py-env/lib/python3.11/site-packages/O365/utils/utils.py:463: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
|
||||
" timezone = date_time.tzinfo.zone if date_time.tzinfo is not None else None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'I have scheduled a meeting with a sentient parrot to discuss research collaborations on October 3, 2023 at 2 pm Easter Time. Please let me know if you need to make any changes.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Can you schedule a 30 minute meeting with a sentient parrot to discuss research collaborations on October 3, 2023 at 2 pm Easter Time?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Yes, you have an event on October 3, 2023 with a sentient parrot. The event is titled 'Meeting with sentient parrot' and is scheduled from 6:00 PM to 6:30 PM.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Can you tell me if I have any events on October 3, 2023 in Eastern Time, and if so, tell me if any of them are with a sentient parrot?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -10,6 +10,16 @@
|
||||
"In this notebook we'll show how to create a chain that automatically makes calls to an API based only on an OpenAPI spec. Under the hood, we're parsing the OpenAPI spec into a JSON schema that the OpenAI functions API can handle. This allows ChatGPT to automatically select and populate the relevant API call to make for any user input. Using the output of ChatGPT we then make the actual API call, and return the result."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "555661b5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains.openai_functions.openapi import get_openapi_chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a95f510a",
|
||||
@@ -25,14 +35,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains.openai_functions.openapi import get_openapi_chain\n",
|
||||
"\n",
|
||||
"chain = get_openapi_chain(\"https://www.klarna.com/us/shopping/public/openai/v0/api-docs/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "3959f866",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -76,7 +84,7 @@
|
||||
" 'Size:S,XL,XS,L,M,XXL']}]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -90,7 +98,9 @@
|
||||
"id": "6f648c77",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Query a translation service"
|
||||
"## Query a translation service\n",
|
||||
"\n",
|
||||
"Additionally, see the request payload by setting `verbose=True`"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -100,23 +110,57 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = get_openapi_chain(\"https://api.speak.com/openapi.yaml\")"
|
||||
"chain = get_openapi_chain(\"https://api.speak.com/openapi.yaml\", verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 10,
|
||||
"id": "1ba51609",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mHuman: Use the provided API's to respond to this user query:\n",
|
||||
"\n",
|
||||
"How would you say no thanks in Russian\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Calling endpoint \u001b[32;1m\u001b[1;3mtranslate\u001b[0m with arguments:\n",
|
||||
"\u001b[32;1m\u001b[1;3m{\n",
|
||||
" \"json\": {\n",
|
||||
" \"phrase_to_translate\": \"no thanks\",\n",
|
||||
" \"learning_language\": \"russian\",\n",
|
||||
" \"native_language\": \"english\",\n",
|
||||
" \"additional_context\": \"\",\n",
|
||||
" \"full_query\": \"How would you say no thanks in Russian\"\n",
|
||||
" }\n",
|
||||
"}\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'explanation': '<translation language=\"None\" context=\"None\">\\nNone\\n</translation>\\n\\n<alternatives context=\"None\">\\n1. \"N/A\" *(Formal - used in professional settings to indicate that the answer is not applicable)*\\n2. \"I don\\'t have an answer for that\" *(Neutral - commonly used when one does not know the answer to a question)*\\n3. \"I\\'m not sure\" *(Neutral - similar to the above alternative, used when one is unsure of the answer)*\\n</alternatives>\\n\\n<example-convo language=\"None\">\\n<context>None</context>\\n* Tom: \"Do you know what time the concert starts?\"\\n* Sarah: \"I\\'m sorry, I don\\'t have an answer for that.\"\\n</example-convo>\\n\\n*[Report an issue or leave feedback](https://speak.com/chatgpt?rid=p8i6p14duafpctg4ve7tm48z})*',\n",
|
||||
"{'explanation': '<translation language=\"Russian\">\\nНет, спасибо. (Net, spasibo)\\n</translation>\\n\\n<alternatives>\\n1. \"Нет, я в порядке\" *(Neutral/Formal - Can be used in professional settings or formal situations.)*\\n2. \"Нет, спасибо, я откажусь\" *(Formal - Can be used in polite settings, such as a fancy dinner with colleagues or acquaintances.)*\\n3. \"Не надо\" *(Informal - Can be used in informal situations, such as declining an offer from a friend.)*\\n</alternatives>\\n\\n<example-convo language=\"Russian\">\\n<context>Max is being offered a cigarette at a party.</context>\\n* Sasha: \"Хочешь покурить?\"\\n* Max: \"Нет, спасибо. Я бросил.\"\\n* Sasha: \"Окей, понятно.\"\\n</example-convo>\\n\\n*[Report an issue or leave feedback](https://speak.com/chatgpt?rid=noczaa460do8yqs8xjun6zdm})*',\n",
|
||||
" 'extra_response_instructions': 'Use all information in the API response and fully render all Markdown.\\nAlways end your response with a link to report an issue or leave feedback on the plugin.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -137,7 +181,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a9198f62",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = get_openapi_chain(\"https://gist.githubusercontent.com/roaldnefs/053e505b2b7a807290908fe9aa3e1f00/raw/0a212622ebfef501163f91e23803552411ed00e4/openapi.yaml\")"
|
||||
@@ -145,17 +191,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 7,
|
||||
"id": "3110c398",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@@ -172,7 +211,7 @@
|
||||
" 'day': '23'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
Example Docs
------------

The sample docs directory contains the following files:

- ``example-10k.html`` - A 10-K SEC filing in HTML format
- ``layout-parser-paper.pdf`` - A PDF copy of the layout parser paper
- ``factbook.xml``/``factbook.xsl`` - Example XML/XSL files that you
  can use to test stylesheets

These documents can be used to test out the parsers in the library. In
addition, here are instructions for pulling in some sample docs that are
too big to store in the repo.

XBRL 10-K
^^^^^^^^^

You can get an example 10-K in inline XBRL format using the following
``curl``. Note, you need to have the user agent set in the header or the
SEC site will reject your request.

.. code:: bash

   curl -O \
     -A '${organization} ${email}' \
     https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt

You can parse this document using the HTML parser.
@@ -0,0 +1,71 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "87067cdf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# mhtml\n",
|
||||
"\n",
|
||||
"MHTML is a is used both for emails but also for archived webpages. MHTML, sometimes referred as MHT, stands for MIME HTML is a single file in which entire webpage is archived. When one saves a webpage as MHTML format, this file extension will contain HTML code, images, audio files, flash animation etc."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5d4c6174",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import MHTMLLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "12dcebc8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"page_content='LangChain\\nLANG CHAIN 🦜️🔗Official Home Page\\xa0\\n\\n\\n\\n\\n\\n\\n\\nIntegrations\\n\\n\\n\\nFeatures\\n\\n\\n\\n\\nBlog\\n\\n\\n\\nConceptual Guide\\n\\n\\n\\n\\nPython Repo\\n\\n\\nJavaScript Repo\\n\\n\\n\\nPython Documentation \\n\\n\\nJavaScript Documentation\\n\\n\\n\\n\\nPython ChatLangChain \\n\\n\\nJavaScript ChatLangChain\\n\\n\\n\\n\\nDiscord \\n\\n\\nTwitter\\n\\n\\n\\n\\nIf you have any comments about our WEB page, you can \\nwrite us at the address shown above. However, due to \\nthe limited number of personnel in our corporate office, we are unable to \\nprovide a direct response.\\n\\nCopyright © 2023-2023 LangChain Inc.\\n\\n\\n' metadata={'source': '../../../../../../tests/integration_tests/examples/example.mht', 'title': 'LangChain'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a new loader object for the MHTML file\n",
|
||||
"loader = MHTMLLoader(file_path='../../../../../../tests/integration_tests/examples/example.mht')\n",
|
||||
"\n",
|
||||
"# Load the document from the file\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"# Print the documents to see the results\n",
|
||||
"for doc in documents:\n",
|
||||
" print(doc)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# RST\n",
|
||||
"\n",
|
||||
">A [reStructured Text (RST)](https://en.wikipedia.org/wiki/ReStructuredText) file is a file format for textual data used primarily in the Python programming language community for technical documentation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `UnstructuredRSTLoader`\n",
|
||||
"\n",
|
||||
"You can load data from RST files with `UnstructuredRSTLoader` using the following workflow."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import UnstructuredRSTLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = UnstructuredRSTLoader(\n",
|
||||
" file_path=\"example_data/README.rst\", mode=\"elements\"\n",
|
||||
")\n",
|
||||
"docs = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"page_content='Example Docs' metadata={'source': 'example_data/README.rst', 'filename': 'README.rst', 'file_directory': 'example_data', 'filetype': 'text/x-rst', 'page_number': 1, 'category': 'Title'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(docs[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -148,7 +148,7 @@
|
||||
" # This will teach the LLM to use it as a column when constructing filter.\n",
|
||||
" AttributeInfo(\n",
|
||||
" name=\"length(genre)\",\n",
|
||||
" description=\"The lenth of genres of the movie\", \n",
|
||||
" description=\"The length of genres of the movie\", \n",
|
||||
" type=\"integer\", \n",
|
||||
" ),\n",
|
||||
" # Now you can define a column as timestamp. By simply set the type to timestamp.\n",
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install boto3"
|
||||
"%pip install boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -36,7 +36,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import boto3\n",
|
||||
"from langchain.retrievers import AwsKendraIndexRetriever"
|
||||
"from langchain.retrievers import AmazonKendraRetriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -53,11 +53,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kclient = boto3.client(\"kendra\", region_name=\"us-east-1\")\n",
|
||||
"\n",
|
||||
"retriever = AwsKendraIndexRetriever(\n",
|
||||
" kclient=kclient,\n",
|
||||
" kendraindex=\"kendraindex\",\n",
|
||||
"retriever = AmazonKendraRetriever(\n",
|
||||
" index_id=\"c0806df7-e76b-4bce-9b5c-d5582f6b1a03\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -66,7 +63,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now you can use retrieved documents from AWS Kendra Index"
|
||||
"Now you can use retrieved documents from Kendra index"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1,107 +1,80 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "683953b3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Chroma\n",
|
||||
"\n",
|
||||
">[Chroma](https://docs.trychroma.com/getting-started) is a database for building AI applications with embeddings.\n",
|
||||
">[Chroma](https://docs.trychroma.com/getting-started) is a AI-native open-source vector database focused on developer productivity and happiness. Chroma is licensed under Apache 2.0.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use functionality related to the `Chroma` vector database."
|
||||
"<a href=\"https://discord.gg/MMeYNTmh3x\" target=\"_blank\">\n",
|
||||
" <img src=\"https://img.shields.io/discord/1073293645303795742\" alt=\"Discord\" />\n",
|
||||
"</a> \n",
|
||||
"<a href=\"https://github.com/chroma-core/chroma/blob/master/LICENSE\" target=\"_blank\">\n",
|
||||
" <img src=\"https://img.shields.io/static/v1?label=license&message=Apache 2.0&color=white\" alt=\"License\" />\n",
|
||||
"</a> \n",
|
||||
"<img src=\"https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml/badge.svg?branch=main\" alt=\"Integration Tests\" />\n",
|
||||
"\n",
|
||||
"- [Website](https://www.trychroma.com/)\n",
|
||||
"- [Documentation](https://docs.trychroma.com/)\n",
|
||||
"- [Twitter](https://twitter.com/trychroma)\n",
|
||||
"- [Discord](https://discord.gg/MMeYNTmh3x)\n",
|
||||
"\n",
|
||||
"Chroma is fully-typed, fully-tested and fully-documented.\n",
|
||||
"\n",
|
||||
"Install Chroma with:\n",
|
||||
"\n",
|
||||
"```sh\n",
|
||||
"pip install chromadb\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Chroma runs in various modes. See below for examples of each integrated with LangChain.\n",
|
||||
"- `in-memory` - in a python script or jupyter notebook\n",
|
||||
"- `in-memory with persistance` - in a script or notebook and save/load to disk\n",
|
||||
"- `in a docker container` - as a server running your local machine or in the cloud\n",
|
||||
"\n",
|
||||
"Like any other database, you can: \n",
|
||||
"- `.add` \n",
|
||||
"- `.get` \n",
|
||||
"- `.update`\n",
|
||||
"- `.upsert`\n",
|
||||
"- `.delete`\n",
|
||||
"- `.peek`\n",
|
||||
"- and `.query` runs the similarity search.\n",
|
||||
"\n",
|
||||
"View full docs at [docs](https://docs.trychroma.com/reference/Collection). To access these methods directly, you can do `._collection_.method()`\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0825fa4a-d950-4e78-8bba-20cfcc347765",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"id": "12e83df7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install chromadb"
|
||||
"# first install dependencies\n",
|
||||
"!pip install langchain\n",
|
||||
"!pip install langchainplus_sdk\n",
|
||||
"!pip install chromadb\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "42080f37-8fd1-4cec-acd9-15d2b03b2f4d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"id": "2b5ffbf8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# get a token: https://platform.openai.com/account/api-keys\n",
|
||||
"## Basic Example\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"OPENAI_API_KEY = getpass()"
|
||||
"In this basic example, we take the most recent State of the Union Address, split it into chunks, embed it using an open-source embedding model, load it into Chroma, and then query it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "c7a94d6c-b4d4-4498-9bdd-eb50c92b85c5",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "aac9563e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"from langchain.document_loaders import TextLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5eabdb75",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"execution_count": 14,
|
||||
"id": "ae9fcf3e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
@@ -109,21 +82,7 @@
|
||||
"text": [
|
||||
"Using embedded DuckDB without persistence: data will be transient\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = Chroma.from_documents(docs, embeddings)\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "4b172de8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
@@ -139,20 +98,312 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# import\n",
|
||||
"from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"# load the document and split it into chunks\n",
|
||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"# split it into chunks\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"# create the open-source embedding function\n",
|
||||
"embedding_function = SentenceTransformerEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n",
|
||||
"\n",
|
||||
"# load it into Chroma\n",
|
||||
"db = Chroma.from_documents(docs, embedding_function)\n",
|
||||
"\n",
|
||||
"# query it\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"\n",
|
||||
"# print results\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "5c9a11cc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Example (including saving to disk)\n",
|
||||
"\n",
|
||||
"Extending the previous example, if you want to save to disk, simply initialize the Chroma client and pass the directory where you want the data to be saved to. \n",
|
||||
"\n",
|
||||
"`Caution`: Chroma makes a best-effort to automatically save data to disk, however multiple in-memory clients can stomp each other's work. As a best practice, only have one client per path running at any given time.\n",
|
||||
"\n",
|
||||
"`Protip`: Sometimes you can call `db.persist()` to force a save. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "49f9bd49",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using embedded DuckDB with persistence: data will be stored in: ./chroma_db\n",
|
||||
"Using embedded DuckDB with persistence: data will be stored in: ./chroma_db\n",
|
||||
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
"\n",
|
||||
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||
"\n",
|
||||
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# save to disk\n",
|
||||
"db2 = Chroma.from_documents(docs, embedding_function, persist_directory=\"./chroma_db\")\n",
|
||||
"db2.persist()\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"\n",
|
||||
"# load from disk\n",
|
||||
"db3 = Chroma(persist_directory=\"./chroma_db\")\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e9cf6d70",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Example (using the Docker Container)\n",
|
||||
"\n",
|
||||
"You can also run the Chroma Server in a Docker container separately, create a Client to connect to it, and then pass that to LangChain. \n",
|
||||
"\n",
|
||||
"Chroma has the ability to handle multiple `Collections` of documents, but the LangChain interface expects one, so we need to specify the collection name. The default collection name used by LangChain is \"langchain\".\n",
|
||||
"\n",
|
||||
"Here is how to clone, build, and run the Docker Image:\n",
|
||||
"```\n",
|
||||
"git clone git@github.com:chroma-core/chroma.git\n",
|
||||
"docker-compose up -d --build\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "74aee70e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n",
|
||||
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
"\n",
|
||||
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||
"\n",
|
||||
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# create the chroma client\n",
|
||||
"import chromadb\n",
|
||||
"import uuid\n",
|
||||
"from chromadb.config import Settings\n",
|
||||
"client = chromadb.Client(Settings(chroma_api_impl=\"rest\",\n",
|
||||
" chroma_server_host=\"localhost\",\n",
|
||||
" chroma_server_http_port=\"8000\"\n",
|
||||
" ))\n",
|
||||
"client.reset() # resets the database\n",
|
||||
"collection = client.create_collection(\"my_collection\")\n",
|
||||
"for doc in docs:\n",
|
||||
" collection.add(ids=[str(uuid.uuid1())], metadatas=doc.metadata, documents=doc.page_content)\n",
|
||||
"\n",
|
||||
"# tell LangChain to use our client and collection name\n",
|
||||
"db4 = Chroma(client=client, collection_name=\"my_collection\")\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"print(docs[0].page_content)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9ed3ec50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Update and Delete\n",
|
||||
"\n",
|
||||
"While building toward a real application, you want to go beyond adding data, and also update and delete data. \n",
|
||||
"\n",
|
||||
"Chroma has users provide `ids` to simplify the bookkeeping here. `ids` can be the name of the file, or a combined has like `filename_paragraphNumber`, etc.\n",
|
||||
"\n",
|
||||
"Chroma supports all these operations - though some of them are still being integrated all the way through the LangChain interface. Additional workflow improvements will be added soon.\n",
|
||||
"\n",
|
||||
"Here is a basic example showing how to do various operations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "81a02810",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using embedded DuckDB without persistence: data will be transient\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}\n",
|
||||
"{'ids': ['1'], 'embeddings': None, 'documents': ['Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.'], 'metadatas': [{'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}]}\n",
|
||||
"count before 4\n",
|
||||
"count after 3\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# create simple ids\n",
|
||||
"ids = [str(i) for i in range(1, len(docs)+1)]\n",
|
||||
"\n",
|
||||
"# add data\n",
|
||||
"example_db = Chroma.from_documents(docs, embedding_function, ids=ids)\n",
|
||||
"docs = example_db.similarity_search(query)\n",
|
||||
"print(docs[0].metadata)\n",
|
||||
"\n",
|
||||
"# update the metadata for a document\n",
|
||||
"docs[0].metadata = {'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}\n",
|
||||
"example_db.update_document(ids[0], docs[0])\n",
|
||||
"print(example_db._collection.get(ids=[ids[0]]))\n",
|
||||
"\n",
|
||||
"# delete the last document\n",
|
||||
"print(\"count before\", example_db._collection.count())\n",
|
||||
"example_db._collection.delete(ids=[ids[-1]])\n",
|
||||
"print(\"count after\", example_db._collection.count())\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ac6bc71a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use OpenAI Embeddings\n",
|
||||
"\n",
|
||||
"Many people like to use OpenAIEmbeddings, here is how to set that up."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "42080f37-8fd1-4cec-acd9-15d2b03b2f4d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get a token: https://platform.openai.com/account/api-keys\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"\n",
|
||||
"OPENAI_API_KEY = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "c7a94d6c-b4d4-4498-9bdd-eb50c92b85c5",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "5eabdb75",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using embedded DuckDB without persistence: data will be transient\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
"\n",
|
||||
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||
"\n",
|
||||
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"db5 = Chroma.from_documents(docs, embeddings)\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6d9c28ad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"***\n",
|
||||
"\n",
|
||||
"## Other Information"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18152965",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Similarity search with score"
|
||||
"### Similarity search with score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "346347d7",
|
||||
"metadata": {},
|
||||
@@ -197,127 +448,15 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "8061454b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Persistance\n",
|
||||
"\n",
|
||||
"The below steps cover how to persist a ChromaDB instance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "2b76db26",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initialize PeristedChromaDB\n",
|
||||
"Create embeddings for each chunk and insert into the Chroma vector database. The persist_directory argument tells ChromaDB where to store the database when it's persisted.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "cdb86e0d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running Chroma using direct local API.\n",
|
||||
"No existing DB found in db, skipping load\n",
|
||||
"No existing DB found in db, skipping load\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Embed and store the texts\n",
|
||||
"# Supplying a persist_directory will store the embeddings on disk\n",
|
||||
"persist_directory = \"db\"\n",
|
||||
"\n",
|
||||
"embedding = OpenAIEmbeddings()\n",
|
||||
"vectordb = Chroma.from_documents(\n",
|
||||
" documents=docs, embedding=embedding, persist_directory=persist_directory\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "f568a322",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Persist the Database\n",
|
||||
"We should call persist() to ensure the embeddings are written to disk."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "74b08cb4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Persisting DB to disk, putting it in the save folder db\n",
|
||||
"PersistentDuckDB del, about to run persist\n",
|
||||
"Persisting DB to disk, putting it in the save folder db\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vectordb.persist()\n",
|
||||
"vectordb = None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "cc9ed900",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load the Database from disk, and create the chain\n",
|
||||
"Be sure to pass the same persist_directory and embedding_function as you did when you instantiated the database. Initialize the chain we will use for question answering."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "31fecfe9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running Chroma using direct local API.\n",
|
||||
"loaded in 4 embeddings\n",
|
||||
"loaded in 1 collections\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now we can load the persisted database from disk, and use it as normal.\n",
|
||||
"vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "794a7552",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Retriever options\n",
|
||||
"### Retriever options\n",
|
||||
"\n",
|
||||
"This section goes over different options for how to use Chroma as a retriever.\n",
|
||||
"\n",
|
||||
"### MMR\n",
|
||||
"#### MMR\n",
|
||||
"\n",
|
||||
"In addition to using similarity search in the retriever object, you can also use `mmr`."
|
||||
]
|
||||
@@ -352,82 +491,6 @@
|
||||
"source": [
|
||||
"retriever.get_relevant_documents(query)[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "2a877f08",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Updating a Document\n",
|
||||
"The `update_document` function allows you to modify the content of a document in the Chroma instance after it has been added. Let's see an example of how to use this function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "a559c3f1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import Document class\n",
|
||||
"from langchain.docstore.document import Document\n",
|
||||
"\n",
|
||||
"# Initial document content and id\n",
|
||||
"initial_content = \"This is an initial document content\"\n",
|
||||
"document_id = \"doc1\"\n",
|
||||
"\n",
|
||||
"# Create an instance of Document with initial content and metadata\n",
|
||||
"original_doc = Document(page_content=initial_content, metadata={\"page\": \"0\"})\n",
|
||||
"\n",
|
||||
"# Initialize a Chroma instance with the original document\n",
|
||||
"new_db = Chroma.from_documents(\n",
|
||||
" collection_name=\"test_collection\",\n",
|
||||
" documents=[original_doc],\n",
|
||||
" embedding=OpenAIEmbeddings(), # using the same embeddings as before\n",
|
||||
" ids=[document_id],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "60a7c273",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"At this point, we have a new Chroma instance with a single document \"This is an initial document content\" with id \"doc1\". Now, let's update the content of the document."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "55e48056",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"This is the updated document content {'page': '1'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Updated document content\n",
|
||||
"updated_content = \"This is the updated document content\"\n",
|
||||
"\n",
|
||||
"# Create a new Document instance with the updated content\n",
|
||||
"updated_doc = Document(page_content=updated_content, metadata={\"page\": \"1\"})\n",
|
||||
"\n",
|
||||
"# Update the document in the Chroma instance by passing the document id and the updated document\n",
|
||||
"new_db.update_document(document_id=document_id, document=updated_doc)\n",
|
||||
"\n",
|
||||
"# Now, let's retrieve the updated document using similarity search\n",
|
||||
"output = new_db.similarity_search(updated_content, k=1)\n",
|
||||
"\n",
|
||||
"# Print the content of the retrieved document\n",
|
||||
"print(output[0].page_content, output[0].metadata)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -446,7 +509,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Amazon API Gateway"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any scale. APIs act as the \"front door\" for applications to access data, business logic, or functionality from your backend services. Using API Gateway, you can create RESTful APIs and WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.\n",
|
||||
"\n",
|
||||
"API Gateway handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization and access control, throttling, monitoring, and API version management. API Gateway has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data transferred out and, with the API Gateway tiered pricing model, you can reduce your cost as your API usage scales."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import AmazonAPIGateway"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_url = \"https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF\"\n",
|
||||
"llm = AmazonAPIGateway(api_url=api_url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'what day comes after Friday?\\nSaturday'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart\n",
|
||||
"parameters = {\n",
|
||||
" \"max_new_tokens\": 100,\n",
|
||||
" \"num_return_sequences\": 1,\n",
|
||||
" \"top_k\": 50,\n",
|
||||
" \"top_p\": 0.95,\n",
|
||||
" \"do_sample\": False,\n",
|
||||
" \"return_full_text\": True,\n",
|
||||
" \"temperature\": 0.2,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"prompt = \"what day comes after Friday?\"\n",
|
||||
"llm.model_kwargs = parameters\n",
|
||||
"llm(prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001B[1m> Entering new chain...\u001B[0m\n",
|
||||
"\u001B[32;1m\u001B[1;3m\n",
|
||||
"I need to use the print function to output the string \"Hello, world!\"\n",
|
||||
"Action: Python_REPL\n",
|
||||
"Action Input: `print(\"Hello, world!\")`\u001B[0m\n",
|
||||
"Observation: \u001B[36;1m\u001B[1;3mHello, world!\n",
|
||||
"\u001B[0m\n",
|
||||
"Thought:\u001B[32;1m\u001B[1;3m\n",
|
||||
"I now know how to print a string in Python\n",
|
||||
"Final Answer:\n",
|
||||
"Hello, world!\u001B[0m\n",
|
||||
"\n",
|
||||
"\u001B[1m> Finished chain.\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Hello, world!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.agents import load_tools\n",
|
||||
"from langchain.agents import initialize_agent\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"parameters = {\n",
|
||||
" \"max_new_tokens\": 50,\n",
|
||||
" \"num_return_sequences\": 1,\n",
|
||||
" \"top_k\": 250,\n",
|
||||
" \"top_p\": 0.25,\n",
|
||||
" \"do_sample\": False,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"llm.model_kwargs = parameters\n",
|
||||
"\n",
|
||||
"# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.\n",
|
||||
"tools = load_tools([\"python_repl\", \"llm-math\"], llm=llm)\n",
|
||||
"\n",
|
||||
"# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Now let's test it out!\n",
|
||||
"agent.run(\"\"\"\n",
|
||||
"Write a Python script that prints \"Hello, world!\"\n",
|
||||
"\"\"\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001B[1m> Entering new chain...\u001B[0m\n",
|
||||
"\u001B[32;1m\u001B[1;3m I need to use the calculator to find the answer\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 2.3 ^ 4.5\u001B[0m\n",
|
||||
"Observation: \u001B[33;1m\u001B[1;3mAnswer: 42.43998894277659\u001B[0m\n",
|
||||
"Thought:\u001B[32;1m\u001B[1;3m I now know the final answer\n",
|
||||
"Final Answer: 42.43998894277659\n",
|
||||
"\n",
|
||||
"Question: \n",
|
||||
"What is the square root of 144?\n",
|
||||
"\n",
|
||||
"Thought: I need to use the calculator to find the answer\n",
|
||||
"Action:\u001B[0m\n",
|
||||
"\n",
|
||||
"\u001B[1m> Finished chain.\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'42.43998894277659'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = agent.run(\n",
|
||||
" \"\"\"\n",
|
||||
"What is 2.3 ^ 4.5?\n",
|
||||
"\"\"\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"result.split(\"\\n\")[0]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
@@ -100,7 +100,7 @@ template = """You are a playwright. Given the title of play and the era it is se
|
||||
Title: {title}
|
||||
Era: {era}
|
||||
Playwright: This is a synopsis for the above play:"""
|
||||
prompt_template = PromptTemplate(input_variables=["title", 'era'], template=template)
|
||||
prompt_template = PromptTemplate(input_variables=["title", "era"], template=template)
|
||||
synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key="synopsis")
|
||||
```
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ qa.run(query)
|
||||
|
||||
</CodeOutputBlock>
|
||||
|
||||
The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](/docs/modules/chains/additional/question_answering.html)) and then pass that directly to the the RetrievalQA chain with the `combine_documents_chain` parameter. For example:
|
||||
The above way allows you to really simply change the chain_type, but it doesn't provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](/docs/modules/chains/additional/question_answering.html)) and then pass it directly to the RetrievalQA chain with the `combine_documents_chain` parameter. For example:
|
||||
|
||||
|
||||
```python
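# A minimal sketch (hedged, not shown in the original snippet): it assumes an
# `llm`, an existing `docsearch` vector store, and a `query` string from the
# earlier steps of this guide.
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain

# Load the underlying QA chain directly so its parameters can be controlled...
qa_chain = load_qa_chain(llm, chain_type="stuff")

# ...then pass it to RetrievalQA via the combine_documents_chain parameter.
qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=docsearch.as_retriever())
qa.run(query)
```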
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
```python
|
||||
import langchain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
llm = ChatOpenAI()
|
||||
```
|
||||
|
||||
## In Memory Cache
|
||||
|
||||
|
||||
```python
|
||||
from langchain.cache import InMemoryCache
|
||||
langchain.llm_cache = InMemoryCache()
|
||||
|
||||
# The first time, it is not yet in cache, so it should take longer
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
|
||||
```
|
||||
CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
|
||||
Wall time: 4.83 s
|
||||
|
||||
|
||||
"\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
|
||||
```
|
||||
|
||||
</CodeOutputBlock>
|
||||
|
||||
|
||||
```python
|
||||
# The second time it is, so it goes faster
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
|
||||
```
|
||||
CPU times: user 238 µs, sys: 143 µs, total: 381 µs
|
||||
Wall time: 1.76 ms
|
||||
|
||||
|
||||
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
|
||||
```
|
||||
|
||||
</CodeOutputBlock>
|
||||
|
||||
## SQLite Cache
|
||||
|
||||
|
||||
```bash
|
||||
rm .langchain.db
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# We can do the same thing with a SQLite cache
|
||||
from langchain.cache import SQLiteCache
|
||||
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# The first time, it is not yet in cache, so it should take longer
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
|
||||
```
|
||||
CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
|
||||
Wall time: 825 ms
|
||||
|
||||
|
||||
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
|
||||
```
|
||||
|
||||
</CodeOutputBlock>
|
||||
|
||||
|
||||
```python
|
||||
# The second time it is, so it goes faster
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
|
||||
```
|
||||
CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
|
||||
Wall time: 2.67 ms
|
||||
|
||||
|
||||
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
|
||||
```
|
||||
|
||||
</CodeOutputBlock>
|
||||
@@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
|
||||
langchain.llm_cache = InMemoryCache()
|
||||
|
||||
# The first time, it is not yet in cache, so it should take longer
|
||||
llm("Tell me a joke")
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
@@ -32,7 +32,7 @@ llm("Tell me a joke")
|
||||
|
||||
```python
|
||||
# The second time it is, so it goes faster
|
||||
llm("Tell me a joke")
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
@@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
|
||||
|
||||
```python
|
||||
# The first time, it is not yet in cache, so it should take longer
|
||||
llm("Tell me a joke")
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
@@ -82,7 +82,7 @@ llm("Tell me a joke")
|
||||
|
||||
```python
|
||||
# The second time it is, so it goes faster
|
||||
llm("Tell me a joke")
|
||||
llm.predict("Tell me a joke")
|
||||
```
|
||||
|
||||
<CodeOutputBlock lang="python">
|
||||
|
||||
1
langchain/agents/agent_toolkits/office365/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Gmail toolkit."""
|
||||
38
langchain/agents/agent_toolkits/office365/toolkit.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.office365.create_draft_message import O365CreateDraftMessage
|
||||
from langchain.tools.office365.events_search import O365SearchEvents
|
||||
from langchain.tools.office365.messages_search import O365SearchEmails
|
||||
from langchain.tools.office365.send_event import O365SendEvent
|
||||
from langchain.tools.office365.send_message import O365SendMessage
|
||||
from langchain.tools.office365.utils import authenticate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from O365 import Account
|
||||
|
||||
|
||||
class O365Toolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with Office365."""
|
||||
|
||||
account: Account = Field(default_factory=authenticate)
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
return [
|
||||
O365SearchEvents(account=self.account),
|
||||
O365CreateDraftMessage(account=self.account),
|
||||
O365SearchEmails(account=self.account),
|
||||
O365SendEvent(account=self.account),
|
||||
O365SendMessage(account=self.account),
|
||||
]
|
||||
@@ -2,6 +2,8 @@ from enum import Enum
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
"""Enumerator with the Agent types."""
|
||||
|
||||
ZERO_SHOT_REACT_DESCRIPTION = "zero-shot-react-description"
|
||||
REACT_DOCSTORE = "react-docstore"
|
||||
SELF_ASK_WITH_SEARCH = "self-ask-with-search"
|
||||
|
||||
@@ -17,13 +17,15 @@ class ChatOutputParser(AgentOutputParser):
|
||||
try:
|
||||
action = text.split("```")[1]
|
||||
response = json.loads(action.strip())
|
||||
includes_action = "action" in response and "action_input" in response
|
||||
includes_action = "action" in response
|
||||
if includes_answer and includes_action:
|
||||
raise OutputParserException(
|
||||
"Parsing LLM output produced a final answer "
|
||||
f"and a parse-able action: {text}"
|
||||
)
|
||||
return AgentAction(response["action"], response["action_input"], text)
|
||||
return AgentAction(
|
||||
response["action"], response.get("action_input", {}), text
|
||||
)
|
||||
|
||||
except Exception:
|
||||
if not includes_answer:
|
||||
|
||||
@@ -51,7 +51,7 @@ def initialize_agent(
|
||||
f"Got unknown agent type: {agent}. "
|
||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||
)
|
||||
tags_.append(agent.value)
|
||||
tags_.append(agent.value if isinstance(agent, AgentType) else agent)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_kwargs = agent_kwargs or {}
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
|
||||
@@ -352,10 +352,23 @@ def load_huggingface_tool(
|
||||
remote: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> BaseTool:
|
||||
"""Loads a tool from the HuggingFace Hub.
|
||||
|
||||
Args:
|
||||
task_or_repo_id: Task or model repo id.
|
||||
model_repo_id: Optional model repo id.
|
||||
token: Optional token.
|
||||
remote: Optional remote. Defaults to False.
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
A tool.
|
||||
|
||||
"""
|
||||
try:
|
||||
from transformers import load_tool
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"HuggingFace tools require the libraries `transformers>=4.29.0`"
|
||||
" and `huggingface_hub>=0.14.1` to be installed."
|
||||
" Please install it with"
|
||||
|
||||
@@ -296,7 +296,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
predicted_message = await self.llm.apredict_messages(
|
||||
messages, functions=self.functions
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from typing import (
|
||||
@@ -11,8 +12,8 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@@ -31,13 +32,17 @@ except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.schema import Generation
|
||||
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
|
||||
RETURN_VAL_TYPE = List[Generation]
|
||||
RETURN_VAL_TYPE = Sequence[Generation]
|
||||
|
||||
|
||||
def _hash(_input: str) -> str:
|
||||
@@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
|
||||
with Session(self.engine) as session:
|
||||
rows = session.execute(stmt).fetchall()
|
||||
if rows:
|
||||
return [Generation(text=row[0]) for row in rows]
|
||||
try:
|
||||
return [loads(row[0]) for row in rows]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Retrieving a cache value that could not be deserialized "
|
||||
"properly. This is likely due to the cache being in an "
|
||||
"older format. Please recreate your cache to avoid this "
|
||||
"error."
|
||||
)
|
||||
# In a previous life we stored the raw text directly
|
||||
# in the table, so assume it's in that format.
|
||||
return [Generation(text=row[0]) for row in rows]
|
||||
return None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update based on prompt and llm_string."""
|
||||
items = [
|
||||
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
|
||||
self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
|
||||
for i, gen in enumerate(return_val)
|
||||
]
|
||||
with Session(self.engine) as session, session.begin():
|
||||
@@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
with Session(self.engine) as session:
|
||||
session.execute(self.cache_schema.delete())
|
||||
session.query(self.cache_schema).delete()
|
||||
|
||||
|
||||
class SQLiteCache(SQLAlchemyCache):
|
||||
@@ -209,6 +225,12 @@ class RedisCache(BaseCache):
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
for gen in return_val:
|
||||
if not isinstance(gen, Generation):
|
||||
raise ValueError(
|
||||
"RedisCache only supports caching of normal LLM generations, "
|
||||
f"got {type(gen)}"
|
||||
)
|
||||
# Write to a Redis HASH
|
||||
key = self._key(prompt, llm_string)
|
||||
self.redis.hset(
|
||||
@@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
for gen in return_val:
|
||||
if not isinstance(gen, Generation):
|
||||
raise ValueError(
|
||||
"RedisSemanticCache only supports caching of "
|
||||
f"normal LLM generations, got {type(gen)}"
|
||||
)
|
||||
llm_cache = self._get_llm_cache(llm_string)
|
||||
# Write to vectorstore
|
||||
metadata = {
|
||||
@@ -426,6 +454,12 @@ class GPTCache(BaseCache):
|
||||
First, retrieve the corresponding cache object using the `llm_string` parameter,
|
||||
and then store the `prompt` and `return_val` in the cache object.
|
||||
"""
|
||||
for gen in return_val:
|
||||
if not isinstance(gen, Generation):
|
||||
raise ValueError(
|
||||
"GPTCache only supports caching of normal LLM generations, "
|
||||
f"got {type(gen)}"
|
||||
)
|
||||
from gptcache.adapter.api import put
|
||||
|
||||
_gptcache = self._get_gptcache(llm_string)
|
||||
@@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
|
||||
"""
|
||||
from momento.responses import CacheGet
|
||||
|
||||
generations = []
|
||||
generations: RETURN_VAL_TYPE = []
|
||||
|
||||
get_response = self.cache_client.get(
|
||||
self.cache_name, self.__key(prompt, llm_string)
|
||||
@@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
|
||||
SdkException: Momento service or network error
|
||||
Exception: Unexpected response
|
||||
"""
|
||||
for gen in return_val:
|
||||
if not isinstance(gen, Generation):
|
||||
raise ValueError(
|
||||
"Momento only supports caching of normal LLM generations, "
|
||||
f"got {type(gen)}"
|
||||
)
|
||||
key = self.__key(prompt, llm_string)
|
||||
value = _dump_generations_to_json(return_val)
|
||||
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
|
||||
|
||||
@@ -6,6 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_aim() -> Any:
|
||||
"""Import the aim python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import aim
|
||||
except ImportError:
|
||||
|
||||
@@ -17,6 +17,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_clearml() -> Any:
|
||||
"""Import the clearml python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import clearml # noqa: F401
|
||||
except ImportError:
|
||||
|
||||
@@ -672,66 +672,72 @@ class CallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForLLMRun:
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
managers = []
|
||||
for prompt in prompts:
|
||||
run_id_ = uuid4()
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
[prompt],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
managers.append(
|
||||
CallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
|
||||
return CallbackManagerForLLMRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
return managers
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForLLMRun:
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Re-use the LLM Run Manager since the outputs are treated
|
||||
# the same for now
|
||||
return CallbackManagerForLLMRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
managers = []
|
||||
for message_list in messages:
|
||||
run_id_ = uuid4()
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
[message_list],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
managers.append(
|
||||
CallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
|
||||
return managers
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
@@ -830,64 +836,84 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForLLMRun:
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = []
|
||||
managers = []
|
||||
|
||||
return AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
for prompt in prompts:
|
||||
run_id_ = uuid4()
|
||||
|
||||
tasks.append(
|
||||
_ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
[prompt],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
managers.append(
|
||||
AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return managers
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
tasks = []
|
||||
managers = []
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
for message_list in messages:
|
||||
run_id_ = uuid4()
|
||||
|
||||
return AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
tasks.append(
|
||||
_ahandle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
[message_list],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
managers.append(
|
||||
AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
return managers
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
|
||||
@@ -20,6 +20,7 @@ from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
def import_mlflow() -> Any:
|
||||
"""Import the mlflow python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError:
|
||||
@@ -117,7 +118,7 @@ class MlflowLogger:
|
||||
Parameters:
|
||||
name (str): Name of the run.
|
||||
experiment (str): Name of the experiment.
|
||||
tags (str): Tags to be attached for the run.
|
||||
tags (dict): Tags to be attached for the run.
|
||||
tracking_uri (str): MLflow tracking server uri.
|
||||
|
||||
This handler implements the helper functions to initialize,
|
||||
@@ -222,7 +223,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
Parameters:
|
||||
name (str): Name of the run.
|
||||
experiment (str): Name of the experiment.
|
||||
tags (str): Tags to be attached for the run.
|
||||
tags (dict): Tags to be attached for the run.
|
||||
tracking_uri (str): MLflow tracking server uri.
|
||||
|
||||
This handler will utilize the associated callback method called and formats
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
# GPT-4 input
|
||||
@@ -32,6 +32,7 @@ MODEL_COST_PER_1K_TOKENS = {
|
||||
"gpt-3.5-turbo-16k-completion": 0.004,
|
||||
"gpt-3.5-turbo-16k-0613-completion": 0.004,
|
||||
# Others
|
||||
"gpt-35-turbo": 0.002, # Azure OpenAI version of ChatGPT
|
||||
"text-ada-001": 0.0004,
|
||||
"ada": 0.0004,
|
||||
"text-babbage-001": 0.0005,
|
||||
@@ -52,6 +53,17 @@ def standardize_model_name(
|
||||
model_name: str,
|
||||
is_completion: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Standardize the model name to a format that can be used in the OpenAI API.
|
||||
Args:
|
||||
model_name: Model name to standardize.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Standardized model name.
|
||||
|
||||
"""
|
||||
model_name = model_name.lower()
|
||||
if "ft-" in model_name:
|
||||
return model_name.split(":")[0] + "-finetuned"
|
||||
@@ -66,6 +78,18 @@ def standardize_model_name(
|
||||
def get_openai_token_cost_for_model(
|
||||
model_name: str, num_tokens: int, is_completion: bool = False
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost in USD for a given model and number of tokens.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
num_tokens: Number of tokens.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Cost in USD.
|
||||
"""
|
||||
model_name = standardize_model_name(model_name, is_completion=is_completion)
|
||||
if model_name not in MODEL_COST_PER_1K_TOKENS:
|
||||
raise ValueError(
|
||||
@@ -129,64 +153,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
pass
|
||||
|
||||
def __copy__(self) -> "OpenAICallbackHandler":
|
||||
"""Return a copy of the callback handler."""
|
||||
return self
|
||||
|
||||
@@ -7,6 +7,16 @@ from langchain.input import get_bolded_text, get_colored_text
|
||||
|
||||
|
||||
def try_json_stringify(obj: Any, fallback: str) -> str:
|
||||
"""
|
||||
Try to stringify an object to JSON.
|
||||
Args:
|
||||
obj: Object to stringify.
|
||||
fallback: Fallback string to return if the object cannot be stringified.
|
||||
|
||||
Returns:
|
||||
A JSON string if the object can be stringified, otherwise the fallback string.
|
||||
|
||||
"""
|
||||
try:
|
||||
return json.dumps(obj, indent=2, ensure_ascii=False)
|
||||
except Exception:
|
||||
@@ -14,6 +24,16 @@ def try_json_stringify(obj: Any, fallback: str) -> str:
|
||||
|
||||
|
||||
def elapsed(run: Any) -> str:
|
||||
"""Get the elapsed time of a run.
|
||||
|
||||
Args:
|
||||
run: any object with a start_time and end_time attribute.
|
||||
|
||||
Returns:
|
||||
A string with the elapsed time in seconds or
|
||||
milliseconds if time is less than a second.
|
||||
|
||||
"""
|
||||
elapsed_time = run.end_time - run.start_time
|
||||
milliseconds = elapsed_time.total_seconds() * 1000
|
||||
if milliseconds < 1000:
|
||||
@@ -22,6 +42,8 @@ def elapsed(run: Any) -> str:
|
||||
|
||||
|
||||
class ConsoleCallbackHandler(BaseTracer):
|
||||
"""Tracer that prints to the console."""
|
||||
|
||||
name = "console_callback_handler"
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
|
||||
@@ -137,6 +137,8 @@ def _replace_type_with_kind(data: Any) -> Any:
|
||||
|
||||
|
||||
class WandbRunArgs(TypedDict):
|
||||
"""Arguments for the WandbTracer."""
|
||||
|
||||
job_type: Optional[str]
|
||||
dir: Optional[StrPath]
|
||||
config: Union[Dict, str, None]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, Iterable, Tuple, Union
|
||||
|
||||
|
||||
def import_spacy() -> Any:
|
||||
"""Import the spacy python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import spacy
|
||||
except ImportError:
|
||||
@@ -15,6 +16,7 @@ def import_spacy() -> Any:
|
||||
|
||||
|
||||
def import_pandas() -> Any:
|
||||
"""Import the pandas python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import pandas
|
||||
except ImportError:
|
||||
@@ -26,6 +28,7 @@ def import_pandas() -> Any:
|
||||
|
||||
|
||||
def import_textstat() -> Any:
|
||||
"""Import the textstat python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import textstat
|
||||
except ImportError:
|
||||
|
||||
@@ -17,6 +17,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
"""Import the wandb python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import wandb # noqa: F401
|
||||
except ImportError:
|
||||
|
||||
@@ -18,6 +18,16 @@ def import_langkit(
|
||||
toxicity: bool = False,
|
||||
themes: bool = False,
|
||||
) -> Any:
|
||||
"""Import the langkit python package and raise an error if it is not installed.
|
||||
|
||||
Args:
|
||||
sentiment: Whether to import the langkit.sentiment module. Defaults to False.
|
||||
toxicity: Whether to import the langkit.toxicity module. Defaults to False.
|
||||
themes: Whether to import the langkit.themes module. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The imported langkit module.
|
||||
"""
|
||||
try:
|
||||
import langkit # noqa: F401
|
||||
import langkit.regexes # noqa: F401
|
||||
|
||||
@@ -18,6 +18,14 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
|
||||
"""
|
||||
Extract Cypher code from a text.
|
||||
Args:
|
||||
text: Text to extract Cypher code from.
|
||||
|
||||
Returns:
|
||||
Cypher code extracted from the text.
|
||||
"""
|
||||
# The pattern to find Cypher code enclosed in triple backticks
|
||||
pattern = r"```(.*?)```"
|
||||
|
||||
|
||||
@@ -34,6 +34,8 @@ black_listed_elements: Set[str] = {
|
||||
|
||||
|
||||
class ElementInViewPort(TypedDict):
|
||||
"""A typed dictionary containing information about elements in the viewport."""
|
||||
|
||||
node_index: str
|
||||
backend_node_id: int
|
||||
node_name: Optional[str]
|
||||
@@ -51,7 +53,7 @@ class Crawler:
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Could not import playwright python package. "
|
||||
"Please install it with `pip install playwright`."
|
||||
)
|
||||
|
||||
@@ -64,6 +64,14 @@ class QuestionAnswer(BaseModel):
|
||||
|
||||
|
||||
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
|
||||
"""Create a citation fuzzy match chain.
|
||||
|
||||
Args:
|
||||
llm: Language model to use for the chain.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to answer questions with citations.
|
||||
"""
|
||||
output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
|
||||
schema = QuestionAnswer.schema()
|
||||
function = {
|
||||
|
||||
@@ -40,6 +40,15 @@ Passage:
|
||||
|
||||
|
||||
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage.
|
||||
|
||||
Args:
|
||||
schema: The schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
|
||||
Returns:
|
||||
Chain that can be used to extract information from a passage.
|
||||
"""
|
||||
function = _get_extraction_function(schema)
|
||||
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
|
||||
@@ -56,6 +65,16 @@ def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
def create_extraction_chain_pydantic(
|
||||
pydantic_schema: Any, llm: BaseLanguageModel
|
||||
) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage using pydantic schema.
|
||||
|
||||
Args:
|
||||
pydantic_schema: The pydantic schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
|
||||
Returns:
|
||||
Chain that can be used to extract information from a passage.
|
||||
"""
|
||||
|
||||
class PydanticSchema(BaseModel):
|
||||
info: List[pydantic_schema] # type: ignore
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.tools import APIOperation
|
||||
@@ -77,7 +78,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -
|
||||
schema = spec.get_schema(media_type_schema)
|
||||
if p.description and not schema.description:
|
||||
schema.description = p.description
|
||||
properties[p.name] = schema.dict(exclude_none=True)
|
||||
properties[p.name] = json.loads(schema.json(exclude_none=True))
|
||||
if p.required:
|
||||
required.append(p.name)
|
||||
return {"type": "object", "properties": properties, "required": required}
|
||||
@@ -127,15 +128,19 @@ def openapi_spec_to_openai_fn(
|
||||
request_body = spec.get_request_body_for_operation(op)
|
||||
# TODO: Support more MIME types.
|
||||
if request_body and request_body.content:
|
||||
media_types = []
|
||||
for media_type in request_body.content.values():
|
||||
if media_type.media_type_schema:
|
||||
schema = spec.get_schema(media_type.media_type_schema)
|
||||
media_types.append(schema.dict(exclude_none=True))
|
||||
media_types = {}
|
||||
for media_type, media_type_object in request_body.content.items():
|
||||
if media_type_object.media_type_schema:
|
||||
schema = spec.get_schema(media_type_object.media_type_schema)
|
||||
media_types[media_type] = json.loads(
|
||||
schema.json(exclude_none=True)
|
||||
)
|
||||
if len(media_types) == 1:
|
||||
request_args["data"] = media_types[0]
|
||||
media_type, schema_dict = list(media_types.items())[0]
|
||||
key = "json" if media_type == "application/json" else "data"
|
||||
request_args[key] = schema_dict
|
||||
elif len(media_types) > 1:
|
||||
request_args["data"] = {"anyOf": media_types}
|
||||
request_args["data"] = {"anyOf": list(media_types.values())}
|
||||
|
||||
api_op = APIOperation.from_openapi_spec(spec, path, method)
|
||||
fn = {
|
||||
@@ -184,8 +189,13 @@ class SimpleRequestChain(Chain):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
name = inputs["function"].pop("name")
|
||||
args = inputs["function"].pop("arguments")
|
||||
_pretty_name = get_colored_text(name, "green")
|
||||
_pretty_args = get_colored_text(json.dumps(args, indent=2), "green")
|
||||
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
|
||||
_run_manager.on_text(_text)
|
||||
api_response: Response = self.request_method(name, args)
|
||||
if api_response.status_code != 200:
|
||||
response = (
|
||||
@@ -206,6 +216,9 @@ def get_openapi_chain(
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
request_chain: Optional[Chain] = None,
|
||||
llm_kwargs: Optional[Dict] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> SequentialChain:
|
||||
"""Create a chain for querying an API from a OpenAPI spec.
|
||||
|
||||
@@ -242,10 +255,16 @@ def get_openapi_chain(
|
||||
llm_kwargs={"functions": openai_fns},
|
||||
output_parser=JsonOutputFunctionsParser(args_only=False),
|
||||
output_key="function",
|
||||
verbose=verbose,
|
||||
**(llm_kwargs or {}),
|
||||
)
|
||||
request_chain = request_chain or SimpleRequestChain(
|
||||
request_method=call_api_fn, verbose=verbose
|
||||
)
|
||||
request_chain = request_chain or SimpleRequestChain(request_method=call_api_fn)
|
||||
return SequentialChain(
|
||||
chains=[llm_chain, request_chain],
|
||||
input_variables=llm_chain.input_keys,
|
||||
output_variables=["response"],
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,18 @@ def create_qa_with_structure_chain(
|
||||
output_parser: str = "base",
|
||||
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
|
||||
) -> LLMChain:
|
||||
"""Create a question answering chain that returns an answer with sources.
|
||||
|
||||
Args:
|
||||
llm: Language model to use for the chain.
|
||||
schema: Pydantic schema to use for the output.
|
||||
output_parser: Output parser to use. Should be one of `pydantic` or `base`.
|
||||
Defaults to `base`.
|
||||
prompt: Optional prompt to use for the chain.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if output_parser == "pydantic":
|
||||
if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
|
||||
raise ValueError(
|
||||
@@ -79,4 +91,13 @@ def create_qa_with_structure_chain(
|
||||
|
||||
|
||||
def create_qa_with_sources_chain(llm: BaseLanguageModel, **kwargs: Any) -> LLMChain:
|
||||
"""Create a question answering chain that returns an answer with sources.
|
||||
|
||||
Args:
|
||||
llm: Language model to use for the chain.
|
||||
**kwargs: Keyword arguments to pass to `create_qa_with_structure_chain`.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to answer questions with citations.
|
||||
"""
|
||||
return create_qa_with_structure_chain(llm, AnswerWithSources, **kwargs)
|
||||
|
||||
@@ -27,6 +27,15 @@ Passage:
|
||||
|
||||
|
||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage.
|
||||
|
||||
Args:
|
||||
schema: The schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to extract information from a passage.
|
||||
"""
|
||||
function = _get_tagging_function(schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
output_parser = JsonOutputFunctionsParser()
|
||||
@@ -43,6 +52,15 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
def create_tagging_chain_pydantic(
|
||||
pydantic_schema: Any, llm: BaseLanguageModel
|
||||
) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage.
|
||||
|
||||
Args:
|
||||
pydantic_schema: The pydantic schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to extract information from a passage.
|
||||
"""
|
||||
openai_schema = pydantic_schema.schema()
|
||||
function = _get_tagging_function(openai_schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
|
||||
@@ -29,4 +29,12 @@ def _convert_schema(schema: dict) -> dict:
|
||||
|
||||
|
||||
def get_llm_kwargs(function: dict) -> dict:
|
||||
"""Returns the kwargs for the LLMChain constructor.
|
||||
|
||||
Args:
|
||||
function: The function to use.
|
||||
|
||||
Returns:
|
||||
The kwargs for the LLMChain constructor.
|
||||
"""
|
||||
return {"functions": [function], "function_call": {"name": function["name"]}}
|
||||
|
||||
@@ -31,8 +31,24 @@ class ConditionalPromptSelector(BasePromptSelector):
|
||||
|
||||
|
||||
def is_llm(llm: BaseLanguageModel) -> bool:
|
||||
"""Check if the language model is a LLM.
|
||||
|
||||
Args:
|
||||
llm: Language model to check.
|
||||
|
||||
Returns:
|
||||
True if the language model is a BaseLLM model, False otherwise.
|
||||
"""
|
||||
return isinstance(llm, BaseLLM)
|
||||
|
||||
|
||||
def is_chat_model(llm: BaseLanguageModel) -> bool:
|
||||
"""Check if the language model is a chat model.
|
||||
|
||||
Args:
|
||||
llm: Language model to check.
|
||||
|
||||
Returns:
|
||||
True if the language model is a BaseChatModel model, False otherwise.
|
||||
"""
|
||||
return isinstance(llm, BaseChatModel)
|
||||
|
||||
@@ -123,6 +123,22 @@ def load_query_constructor_chain(
|
||||
enable_limit: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
"""
|
||||
Load a query constructor chain.
|
||||
Args:
|
||||
llm: BaseLanguageModel to use for the chain.
|
||||
document_contents: The contents of the document to be queried.
|
||||
attribute_info: A list of AttributeInfo objects describing
|
||||
the attributes of the document.
|
||||
examples: Optional list of examples to use for the chain.
|
||||
allowed_comparators: An optional list of allowed comparators.
|
||||
allowed_operators: An optional list of allowed operators.
|
||||
enable_limit: Whether to enable the limit operator. Defaults to False.
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
A LLMChain that can be used to construct queries.
|
||||
"""
|
||||
prompt = _get_prompt(
|
||||
document_contents,
|
||||
attribute_info,
|
||||
|
||||
@@ -60,12 +60,16 @@ class Expr(BaseModel):
|
||||
|
||||
|
||||
class Operator(str, Enum):
|
||||
"""Enumerator of the operations."""
|
||||
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
NOT = "not"
|
||||
|
||||
|
||||
class Comparator(str, Enum):
|
||||
"""Enumerator of the comparison operators."""
|
||||
|
||||
EQ = "eq"
|
||||
GT = "gt"
|
||||
GTE = "gte"
|
||||
|
||||
@@ -57,6 +57,9 @@ GRAMMAR = """
|
||||
|
||||
@v_args(inline=True)
|
||||
class QueryTransformer(Transformer):
|
||||
"""Transforms a query string into an IR representation
|
||||
(intermediate representation)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -136,6 +139,16 @@ def get_parser(
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
) -> Lark:
|
||||
"""
|
||||
Returns a parser for the query language.
|
||||
|
||||
Args:
|
||||
allowed_comparators: Optional[Sequence[Comparator]]
|
||||
allowed_operators: Optional[Sequence[Operator]]
|
||||
|
||||
Returns:
|
||||
Lark parser for the query language.
|
||||
"""
|
||||
transformer = QueryTransformer(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from langchain.chat_models.anthropic import ChatAnthropic
|
||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||
from langchain.chat_models.fake import FakeListChatModel
|
||||
from langchain.chat_models.google_palm import ChatGooglePalm
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||
@@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
|
||||
__all__ = [
|
||||
"ChatOpenAI",
|
||||
"AzureChatOpenAI",
|
||||
"FakeListChatModel",
|
||||
"PromptLayerChatOpenAI",
|
||||
"ChatAnthropic",
|
||||
"ChatGooglePalm",
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.dump import dumpd, dumps
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@@ -35,6 +35,7 @@ def _get_verbosity() -> bool:
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel, ABC):
|
||||
cache: Optional[bool] = None
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
@@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
return params
|
||||
|
||||
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
||||
if self.lc_serializable:
|
||||
params = {**kwargs, **{"stop": stop}}
|
||||
param_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
llm_string = dumps(self)
|
||||
return llm_string + "---" + param_string
|
||||
else:
|
||||
params = self._get_invocation_params(stop=stop)
|
||||
params = {**params, **kwargs}
|
||||
return str(sorted([(k, v) for k, v in params.items()]))
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
params = self._get_invocation_params(stop=stop)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
@@ -83,29 +101,37 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
tags,
|
||||
self.tags,
|
||||
)
|
||||
run_manager = callback_manager.on_chat_model_start(
|
||||
run_managers = callback_manager.on_chat_model_start(
|
||||
dumpd(self), messages, invocation_params=params, options=options
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
try:
|
||||
results = [
|
||||
self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
|
||||
if new_arg_supported
|
||||
else self._generate(m, stop=stop)
|
||||
for m in messages
|
||||
]
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
results = []
|
||||
for i, m in enumerate(messages):
|
||||
try:
|
||||
results.append(
|
||||
self._generate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if run_managers:
|
||||
run_managers[i].on_llm_error(e)
|
||||
raise e
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
run_manager.on_llm_end(output)
|
||||
if run_manager:
|
||||
output.run = RunInfo(run_id=run_manager.run_id)
|
||||
if run_managers:
|
||||
run_infos = []
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
manager.on_llm_end(flattened_output)
|
||||
run_infos.append(RunInfo(run_id=manager.run_id))
|
||||
output.run = run_infos
|
||||
return output
|
||||
|
||||
async def agenerate(
|
||||
@@ -118,8 +144,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
params = self._get_invocation_params(stop=stop)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
@@ -129,31 +154,62 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
tags,
|
||||
self.tags,
|
||||
)
|
||||
run_manager = await callback_manager.on_chat_model_start(
|
||||
|
||||
run_managers = await callback_manager.on_chat_model_start(
|
||||
dumpd(self), messages, invocation_params=params, options=options
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
for i, m in enumerate(messages)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
try:
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
|
||||
if new_arg_supported
|
||||
else self._agenerate(m, stop=stop)
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
exceptions = []
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
if run_managers:
|
||||
await run_managers[i].on_llm_error(res)
|
||||
exceptions.append(res)
|
||||
if exceptions:
|
||||
if run_managers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[res.generations], llm_output=res.llm_output
|
||||
)
|
||||
)
|
||||
for run_manager, res in zip(run_managers, results)
|
||||
if not isinstance(res, Exception)
|
||||
]
|
||||
)
|
||||
raise exceptions[0]
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
await run_manager.on_llm_end(output)
|
||||
if run_manager:
|
||||
output.run = RunInfo(run_id=run_manager.run_id)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
for run_manager, flattened_output in zip(
|
||||
run_managers, flattened_outputs
|
||||
)
|
||||
]
|
||||
)
|
||||
if run_managers:
|
||||
output.run = [
|
||||
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
|
||||
]
|
||||
return output
|
||||
|
||||
def generate_prompt(
|
||||
@@ -178,6 +234,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
        )

    def _generate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if new_arg_supported:
                return self._generate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                return self._generate(messages, stop=stop, **kwargs)
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
                if new_arg_supported:
                    result = self._generate(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                    )
                else:
                    result = self._generate(messages, stop=stop, **kwargs)
                langchain.llm_cache.update(prompt, llm_string, result.generations)
                return result

    async def _agenerate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if new_arg_supported:
                return await self._agenerate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                return await self._agenerate(messages, stop=stop, **kwargs)
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
                if new_arg_supported:
                    result = await self._agenerate(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                    )
                else:
                    result = await self._agenerate(messages, stop=stop, **kwargs)
                langchain.llm_cache.update(prompt, llm_string, result.generations)
                return result

    @abstractmethod
    def _generate(
        self,
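The two cache helpers above only kick in when a global cache is configured on `langchain.llm_cache`. A minimal sketch of exercising that path, assuming the bundled `InMemoryCache` and the `ChatOpenAI` wrapper:

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

# Enable the global cache that _generate_with_cache / _agenerate_with_cache consult.
langchain.llm_cache = InMemoryCache()

chat = ChatOpenAI()
chat.predict("Tell me a joke")  # first call hits the provider and populates the cache
chat.predict("Tell me a joke")  # identical second call is served from the cache
```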
langchain/chat_models/fake.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    responses: List
    i: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-list-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next canned response from `responses`."""
        response = self.responses[self.i]
        self.i += 1
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"responses": self.responses}
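A short, hedged sketch of using the new fake model in a test that should never call a real provider; the import path follows the file added above:

```python
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema import HumanMessage

# Responses are returned in order, one per call.
fake_chat = FakeListChatModel(responses=["first canned reply", "second canned reply"])
print(fake_chat([HumanMessage(content="hello")]))        # -> AIMessage("first canned reply")
print(fake_chat([HumanMessage(content="hello again")]))  # -> AIMessage("second canned reply")
```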
@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGooglePalmError(Exception):
|
||||
"""Error raised when there is an issue with the Google PaLM API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -184,6 +184,16 @@ class ChatOpenAI(BaseChatModel):
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -448,15 +458,18 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
|
||||
tiktoken_ = _import_tiktoken()
|
||||
model = self.model_name
|
||||
if model == "gpt-3.5-turbo":
|
||||
# gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||
model = "gpt-3.5-turbo-0301"
|
||||
elif model == "gpt-4":
|
||||
# gpt-4 may change over time.
|
||||
# Returning num tokens assuming gpt-4-0314.
|
||||
model = "gpt-4-0314"
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
model = self.model_name
|
||||
if model == "gpt-3.5-turbo":
|
||||
# gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||
model = "gpt-3.5-turbo-0301"
|
||||
elif model == "gpt-4":
|
||||
# gpt-4 may change over time.
|
||||
# Returning num tokens assuming gpt-4-0314.
|
||||
model = "gpt-4-0314"
|
||||
# Returns the number of tokens used by a list of messages.
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
|
||||
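The new `tiktoken_model_name` field lets token counting keep working when the configured model name is unknown to tiktoken (for example, an OpenAI-compatible proxy). A hedged sketch; the model name below is made up:

```python
from langchain.chat_models import ChatOpenAI

chat = ChatOpenAI(
    model_name="my-proxy-gpt",            # hypothetical model behind an OpenAI-like API
    tiktoken_model_name="gpt-3.5-turbo",  # tokenizer name tiktoken does recognise
)
print(chat.get_num_tokens("How many tokens is this?"))
```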
@@ -68,6 +68,7 @@ from langchain.document_loaders.mastodon import MastodonTootsLoader
|
||||
from langchain.document_loaders.max_compute import MaxComputeLoader
|
||||
from langchain.document_loaders.mediawikidump import MWDumpLoader
|
||||
from langchain.document_loaders.merge import MergedDataLoader
|
||||
from langchain.document_loaders.mhtml import MHTMLLoader
|
||||
from langchain.document_loaders.modern_treasury import ModernTreasuryLoader
|
||||
from langchain.document_loaders.notebook import NotebookLoader
|
||||
from langchain.document_loaders.notion import NotionDirectoryLoader
|
||||
@@ -97,6 +98,7 @@ from langchain.document_loaders.readthedocs import ReadTheDocsLoader
|
||||
from langchain.document_loaders.recursive_url_loader import RecusiveUrlLoader
|
||||
from langchain.document_loaders.reddit import RedditPostsLoader
|
||||
from langchain.document_loaders.roam import RoamLoader
|
||||
from langchain.document_loaders.rst import UnstructuredRSTLoader
|
||||
from langchain.document_loaders.rtf import UnstructuredRTFLoader
|
||||
from langchain.document_loaders.s3_directory import S3DirectoryLoader
|
||||
from langchain.document_loaders.s3_file import S3FileLoader
|
||||
@@ -204,6 +206,7 @@ __all__ = [
|
||||
"MathpixPDFLoader",
|
||||
"MaxComputeLoader",
|
||||
"MergedDataLoader",
|
||||
"MHTMLLoader",
|
||||
"ModernTreasuryLoader",
|
||||
"NotebookLoader",
|
||||
"NotionDBLoader",
|
||||
@@ -261,6 +264,7 @@ __all__ = [
|
||||
"UnstructuredODTLoader",
|
||||
"UnstructuredPDFLoader",
|
||||
"UnstructuredPowerPointLoader",
|
||||
"UnstructuredRSTLoader",
|
||||
"UnstructuredRTFLoader",
|
||||
"UnstructuredURLLoader",
|
||||
"UnstructuredWordDocumentLoader",
|
||||
|
||||
@@ -11,6 +11,8 @@ from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class BlockchainType(Enum):
|
||||
"""Enumerator of the supported blockchains."""
|
||||
|
||||
ETH_MAINNET = "eth-mainnet"
|
||||
ETH_GOERLI = "eth-goerli"
|
||||
POLYGON_MAINNET = "polygon-mainnet"
|
||||
|
||||
@@ -8,6 +8,15 @@ from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def concatenate_rows(message: dict, title: str) -> str:
|
||||
"""
|
||||
Combine message information in a readable format ready to be used.
|
||||
Args:
|
||||
message: Message to be concatenated
|
||||
title: Title of the conversation
|
||||
|
||||
Returns:
|
||||
Concatenated message
|
||||
"""
|
||||
if not message:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentFormat(str, Enum):
|
||||
"""Enumerator of the content formats of Confluence page."""
|
||||
|
||||
STORAGE = "body.storage"
|
||||
VIEW = "body.view"
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ class EmbaasDocumentExtractionParameters(TypedDict):
|
||||
|
||||
|
||||
class EmbaasDocumentExtractionPayload(EmbaasDocumentExtractionParameters):
|
||||
"""Payload for the Embaas document extraction API."""
|
||||
|
||||
bytes: str
|
||||
"""The base64 encoded bytes of the document to extract text from."""
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class FaunaLoader(BaseLoader):
|
||||
"""
|
||||
"""FaunaDB Loader.
|
||||
|
||||
Attributes:
|
||||
query (str): The FQL query string to execute.
|
||||
page_content_field (str): The field that contains the content of each page.
|
||||
|
||||
@@ -17,6 +17,8 @@ IUGU_ENDPOINTS = {
|
||||
|
||||
|
||||
class IuguLoader(BaseLoader):
|
||||
"""Loader that fetches data from IUGU."""
|
||||
|
||||
def __init__(self, resource: str, api_token: Optional[str] = None) -> None:
|
||||
self.resource = resource
|
||||
api_token = api_token or get_from_env("api_token", "IUGU_API_TOKEN")
|
||||
|
||||
langchain/document_loaders/mhtml.py (new file, 69 lines)
@@ -0,0 +1,69 @@
"""Loader to load MHTML files, enriching metadata with page title."""

import email
import logging
from typing import Dict, List, Union

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

logger = logging.getLogger(__name__)


class MHTMLLoader(BaseLoader):
    """Loader that uses beautiful soup to parse HTML files."""

    def __init__(
        self,
        file_path: str,
        open_encoding: Union[str, None] = None,
        bs_kwargs: Union[dict, None] = None,
        get_text_separator: str = "",
    ) -> None:
        """Initialise with path, and optionally, file encoding to use, and any kwargs
        to pass to the BeautifulSoup object."""
        try:
            import bs4  # noqa:F401
        except ImportError:
            raise ValueError(
                "beautifulsoup4 package not found, please install it with "
                "`pip install beautifulsoup4`"
            )

        self.file_path = file_path
        self.open_encoding = open_encoding
        if bs_kwargs is None:
            bs_kwargs = {"features": "lxml"}
        self.bs_kwargs = bs_kwargs
        self.get_text_separator = get_text_separator

    def load(self) -> List[Document]:
        """Load MHTML document into document objects."""
        from bs4 import BeautifulSoup

        with open(self.file_path, "r", encoding=self.open_encoding) as f:
            message = email.message_from_string(f.read())
            parts = message.get_payload()

            if type(parts) is not list:
                parts = [message]

            for part in parts:
                if part.get_content_type() == "text/html":
                    html = part.get_payload(decode=True).decode()

                    soup = BeautifulSoup(html, **self.bs_kwargs)
                    text = soup.get_text(self.get_text_separator)

                    if soup.title:
                        title = str(soup.title.string)
                    else:
                        title = ""

                    metadata: Dict[str, Union[str, None]] = {
                        "source": self.file_path,
                        "title": title,
                    }
                    return [Document(page_content=text, metadata=metadata)]
        return []
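An assumed usage sketch for the new loader; the path below is a placeholder for a saved MHTML archive:

```python
from langchain.document_loaders import MHTMLLoader

loader = MHTMLLoader("saved_page.mht")  # placeholder path
docs = loader.load()
print(docs[0].metadata["title"], len(docs[0].page_content))
```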
@@ -27,6 +27,8 @@ incoming_payment_details",
|
||||
|
||||
|
||||
class ModernTreasuryLoader(BaseLoader):
|
||||
"""Loader that fetches data from Modern Treasury."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resource: str,
|
||||
|
||||
@@ -48,13 +48,13 @@ class NotionDBLoader(BaseLoader):
|
||||
Returns:
|
||||
List[Document]: List of documents.
|
||||
"""
|
||||
page_ids = self._retrieve_page_ids()
|
||||
page_summaries = self._retrieve_page_summaries()
|
||||
|
||||
return list(self.load_page(page_id) for page_id in page_ids)
|
||||
return list(self.load_page(page_summary) for page_summary in page_summaries)
|
||||
|
||||
def _retrieve_page_ids(
|
||||
def _retrieve_page_summaries(
|
||||
self, query_dict: Dict[str, Any] = {"page_size": 100}
|
||||
) -> List[str]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
pages: List[Dict[str, Any]] = []
|
||||
|
||||
@@ -72,18 +72,16 @@ class NotionDBLoader(BaseLoader):
|
||||
|
||||
query_dict["start_cursor"] = data.get("next_cursor")
|
||||
|
||||
page_ids = [page["id"] for page in pages]
|
||||
return pages
|
||||
|
||||
return page_ids
|
||||
|
||||
def load_page(self, page_id: str) -> Document:
|
||||
def load_page(self, page_summary: Dict[str, Any]) -> Document:
|
||||
"""Read a page."""
|
||||
data = self._request(PAGE_URL.format(page_id=page_id))
|
||||
page_id = page_summary["id"]
|
||||
|
||||
# load properties as metadata
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
for prop_name, prop_data in data["properties"].items():
|
||||
for prop_name, prop_data in page_summary["properties"].items():
|
||||
prop_type = prop_data["type"]
|
||||
|
||||
if prop_type == "rich_text":
|
||||
|
||||
@@ -7,6 +7,8 @@ from langchain.schema import Document
|
||||
|
||||
|
||||
class TextParser(BaseBlobParser):
|
||||
"""Parser for text blobs."""
|
||||
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
yield Document(page_content=blob.as_string(), metadata={"source": blob.source})
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Iterator, List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
@@ -21,6 +20,13 @@ class RecusiveUrlLoader(BaseLoader):
|
||||
) -> Set[str]:
|
||||
"""Recursively get all child links starting with the path of the input URL."""
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The BeautifulSoup package is required for the RecusiveUrlLoader."
|
||||
)
|
||||
|
||||
# Construct the base and parent URLs
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
langchain/document_loaders/rst.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Loader that loads RST files."""
from typing import Any, List

from langchain.document_loaders.unstructured import (
    UnstructuredFileLoader,
    validate_unstructured_version,
)


class UnstructuredRSTLoader(UnstructuredFileLoader):
    """Loader that uses unstructured to load RST files."""

    def __init__(
        self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
    ):
        validate_unstructured_version(min_unstructured_version="0.7.5")
        super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)

    def _get_elements(self) -> List:
        from unstructured.partition.rst import partition_rst

        return partition_rst(filename=self.file_path, **self.unstructured_kwargs)
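An assumed usage sketch; `unstructured>=0.7.5` is required by the version check above, and the file path is a placeholder:

```python
from langchain.document_loaders import UnstructuredRSTLoader

loader = UnstructuredRSTLoader("README.rst", mode="elements")
docs = loader.load()
```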
@@ -20,6 +20,8 @@ SPREEDLY_ENDPOINTS = {
|
||||
|
||||
|
||||
class SpreedlyLoader(BaseLoader):
|
||||
"""Loader that fetches data from Spreedly API."""
|
||||
|
||||
def __init__(self, access_token: str, resource: str) -> None:
|
||||
self.access_token = access_token
|
||||
self.resource = resource
|
||||
|
||||
@@ -18,6 +18,8 @@ STRIPE_ENDPOINTS = {
|
||||
|
||||
|
||||
class StripeLoader(BaseLoader):
|
||||
"""Loader that fetches data from Stripe."""
|
||||
|
||||
def __init__(self, resource: str, access_token: Optional[str] = None) -> None:
|
||||
self.resource = resource
|
||||
access_token = access_token or get_from_env(
|
||||
|
||||
@@ -16,6 +16,7 @@ class UnstructuredURLLoader(BaseLoader):
|
||||
urls: List[str],
|
||||
continue_on_failure: bool = True,
|
||||
mode: str = "single",
|
||||
show_progress_bar: bool = False,
|
||||
**unstructured_kwargs: Any,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
@@ -51,6 +52,7 @@ class UnstructuredURLLoader(BaseLoader):
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.headers = headers
|
||||
self.unstructured_kwargs = unstructured_kwargs
|
||||
self.show_progress_bar = show_progress_bar
|
||||
|
||||
def _validate_mode(self, mode: str) -> None:
|
||||
_valid_modes = {"single", "elements"}
|
||||
@@ -83,7 +85,21 @@ class UnstructuredURLLoader(BaseLoader):
|
||||
from unstructured.partition.html import partition_html
|
||||
|
||||
docs: List[Document] = list()
|
||||
for url in self.urls:
|
||||
if self.show_progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Package tqdm must be installed if show_progress_bar=True. "
|
||||
"Please install with 'pip install tqdm' or set "
|
||||
"show_progress_bar=False."
|
||||
) from e
|
||||
|
||||
urls = tqdm(self.urls)
|
||||
else:
|
||||
urls = self.urls
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
if self.__is_non_html_available():
|
||||
if self.__is_headers_available_for_non_html():
|
||||
|
||||
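A hedged sketch of the new `show_progress_bar` flag; it needs tqdm installed and the URLs are placeholders:

```python
from langchain.document_loaders import UnstructuredURLLoader

loader = UnstructuredURLLoader(
    urls=["https://example.com/a", "https://example.com/b"],  # placeholder URLs
    show_progress_bar=True,  # wraps the URL list in tqdm while loading
)
docs = loader.load()
```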
@@ -50,6 +50,9 @@ class WebBaseLoader(BaseLoader):
|
||||
requests_kwargs: Dict[str, Any] = {}
|
||||
"""kwargs for requests"""
|
||||
|
||||
bs_get_text_kwargs: Dict[str, Any] = {}
|
||||
"""kwargs for beatifulsoup4 get_text"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_path: Union[str, List[str]],
|
||||
@@ -201,7 +204,7 @@ class WebBaseLoader(BaseLoader):
|
||||
"""Lazy load text from the url(s) in web_path."""
|
||||
for path in self.web_paths:
|
||||
soup = self._scrape(path)
|
||||
text = soup.get_text()
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
metadata = _build_metadata(soup, path)
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
@@ -216,7 +219,7 @@ class WebBaseLoader(BaseLoader):
|
||||
docs = []
|
||||
for i in range(len(results)):
|
||||
soup = results[i]
|
||||
text = soup.get_text()
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
metadata = _build_metadata(soup, self.web_paths[i])
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
|
||||
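The new `bs_get_text_kwargs` attribute is forwarded to BeautifulSoup's `get_text()`. A sketch under the assumption that it is set on the instance after construction:

```python
from langchain.document_loaders import WebBaseLoader

loader = WebBaseLoader("https://example.com")  # placeholder URL
loader.bs_get_text_kwargs = {"separator": " ", "strip": True}
docs = loader.load()
```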
@@ -40,7 +40,7 @@ class WhatsAppChatLoader(BaseLoader):
|
||||
(?:
|
||||
:\d{2}
|
||||
)?
|
||||
(?:[ _](?:AM|PM))?
|
||||
(?:[\s_](?:AM|PM))?
|
||||
)
|
||||
\]?
|
||||
[\s-]*
|
||||
@@ -50,7 +50,9 @@ class WhatsAppChatLoader(BaseLoader):
|
||||
(.+)
|
||||
"""
|
||||
for line in lines:
|
||||
result = re.match(message_line_regex, line.strip(), flags=re.VERBOSE)
|
||||
result = re.match(
|
||||
message_line_regex, line.strip(), flags=re.VERBOSE | re.IGNORECASE
|
||||
)
|
||||
if result:
|
||||
date, sender, text = result.groups()
|
||||
text_content += concatenate_rows(date, sender, text)
|
||||
|
||||
@@ -18,17 +18,41 @@ class WikipediaLoader(BaseLoader):
|
||||
lang: str = "en",
|
||||
load_max_docs: Optional[int] = 100,
|
||||
load_all_available_meta: Optional[bool] = False,
|
||||
doc_content_chars_max: Optional[int] = 4000,
|
||||
):
|
||||
"""
|
||||
Initializes a new instance of the WikipediaLoader class.
|
||||
|
||||
Args:
|
||||
query (str): The query string to search on Wikipedia.
|
||||
lang (str, optional): The language code for the Wikipedia language edition.
|
||||
Defaults to "en".
|
||||
load_max_docs (int, optional): The maximum number of documents to load.
|
||||
Defaults to 100.
|
||||
load_all_available_meta (bool, optional): Indicates whether to load all
|
||||
available metadata for each document. Defaults to False.
|
||||
doc_content_chars_max (int, optional): The maximum number of characters
|
||||
for the document content. Defaults to 4000.
|
||||
"""
|
||||
self.query = query
|
||||
self.lang = lang
|
||||
self.load_max_docs = load_max_docs
|
||||
self.load_all_available_meta = load_all_available_meta
|
||||
self.doc_content_chars_max = doc_content_chars_max
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Loads the query result from Wikipedia into a list of Documents.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects representing the loaded
|
||||
Wikipedia pages.
|
||||
"""
|
||||
client = WikipediaAPIWrapper(
|
||||
lang=self.lang,
|
||||
top_k_results=self.load_max_docs,
|
||||
load_all_available_meta=self.load_all_available_meta,
|
||||
doc_content_chars_max=self.doc_content_chars_max,
|
||||
)
|
||||
docs = client.load(self.query)
|
||||
return docs
|
||||
|
||||
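A usage sketch for the parameters documented above; the query and limits are arbitrary examples:

```python
from langchain.document_loaders import WikipediaLoader

loader = WikipediaLoader(
    query="LangChain",
    lang="en",
    load_max_docs=2,
    doc_content_chars_max=1000,
)
docs = loader.load()
```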
@@ -30,6 +30,14 @@ class _DocumentWithState(Document):
|
||||
def get_stateful_documents(
|
||||
documents: Sequence[Document],
|
||||
) -> Sequence[_DocumentWithState]:
|
||||
"""Convert a list of documents to a list of documents with state.
|
||||
|
||||
Args:
|
||||
documents: The documents to convert.
|
||||
|
||||
Returns:
|
||||
A list of documents with state.
|
||||
"""
|
||||
return [_DocumentWithState.from_document(doc) for doc in documents]
|
||||
|
||||
|
||||
|
||||
@@ -170,6 +170,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
"""Timeout in seconds for the OpenAPI request."""
|
||||
headers: Any = None
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -265,7 +275,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
encoding = tiktoken.model.encoding_for_model(self.model)
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
@@ -329,7 +345,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
encoding = tiktoken.model.encoding_for_model(self.model)
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
|
||||
@@ -18,7 +18,7 @@ These questions should be detailed and be based explicitly on information in the
|
||||
{doc}
|
||||
<End Document>"""
|
||||
output_parser = RegexParser(
|
||||
regex=r"QUESTION: (.*?)\nANSWER: (.*)", output_keys=["query", "answer"]
|
||||
regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", output_keys=["query", "answer"]
|
||||
)
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["doc"], template=template, output_parser=output_parser
|
||||
|
||||
@@ -18,8 +18,17 @@ class BaseAutoGPTOutputParser(BaseOutputParser):
|
||||
|
||||
|
||||
def preprocess_json_input(input_str: str) -> str:
|
||||
# Replace single backslashes with double backslashes,
|
||||
# while leaving already escaped ones intact
|
||||
"""Preprocesses a string to be parsed as json.
|
||||
|
||||
Replace single backslashes with double backslashes,
|
||||
while leaving already escaped ones intact.
|
||||
|
||||
Args:
|
||||
input_str: String to be preprocessed
|
||||
|
||||
Returns:
|
||||
Preprocessed string
|
||||
"""
|
||||
corrected_str = re.sub(
|
||||
r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
|
||||
)
|
||||
|
||||
@@ -23,6 +23,18 @@ def load_agent_executor(
|
||||
verbose: bool = False,
|
||||
include_task_in_prompt: bool = False,
|
||||
) -> ChainExecutor:
|
||||
"""
|
||||
Load an agent executor.
|
||||
|
||||
Args:
|
||||
llm: BaseLanguageModel
|
||||
tools: List[BaseTool]
|
||||
verbose: bool. Defaults to False.
|
||||
include_task_in_prompt: bool. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ChainExecutor
|
||||
"""
|
||||
input_variables = ["previous_steps", "current_step", "agent_scratchpad"]
|
||||
template = HUMAN_MESSAGE_TEMPLATE
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ from langchain.experimental.plan_and_execute.schema import Plan, PlanOutputParse
|
||||
class BasePlanner(BaseModel):
|
||||
@abstractmethod
|
||||
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
|
||||
"""Given input, decided what to do."""
|
||||
"""Given input, decide what to do."""
|
||||
|
||||
@abstractmethod
|
||||
async def aplan(
|
||||
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Plan:
|
||||
"""Given input, decided what to do."""
|
||||
"""Given input, decide what to do."""
|
||||
|
||||
|
||||
class LLMPlanner(BasePlanner):
|
||||
@@ -26,14 +26,14 @@ class LLMPlanner(BasePlanner):
|
||||
stop: Optional[List] = None
|
||||
|
||||
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
|
||||
"""Given input, decided what to do."""
|
||||
"""Given input, decide what to do."""
|
||||
llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
|
||||
return self.output_parser.parse(llm_response)
|
||||
|
||||
async def aplan(
|
||||
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Plan:
|
||||
"""Given input, decided what to do."""
|
||||
"""Given input, decide what to do."""
|
||||
llm_response = await self.llm_chain.arun(
|
||||
**inputs, stop=self.stop, callbacks=callbacks
|
||||
)
|
||||
|
||||
@@ -32,6 +32,15 @@ class PlanningOutputParser(PlanOutputParser):
|
||||
def load_chat_planner(
|
||||
llm: BaseLanguageModel, system_prompt: str = SYSTEM_PROMPT
|
||||
) -> LLMPlanner:
|
||||
"""
|
||||
Load a chat planner.
|
||||
Args:
|
||||
llm: Language model.
|
||||
system_prompt: System prompt.
|
||||
|
||||
Returns:
|
||||
LLMPlanner
|
||||
"""
|
||||
prompt_template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessage(content=system_prompt),
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Dict, Type
|
||||
|
||||
from langchain.llms.ai21 import AI21
|
||||
from langchain.llms.aleph_alpha import AlephAlpha
|
||||
from langchain.llms.amazon_api_gateway import AmazonAPIGateway
|
||||
from langchain.llms.anthropic import Anthropic
|
||||
from langchain.llms.anyscale import Anyscale
|
||||
from langchain.llms.aviary import Aviary
|
||||
@@ -53,6 +54,7 @@ from langchain.llms.writer import Writer
|
||||
__all__ = [
|
||||
"AI21",
|
||||
"AlephAlpha",
|
||||
"AmazonAPIGateway",
|
||||
"Anthropic",
|
||||
"Anyscale",
|
||||
"Aviary",
|
||||
@@ -106,6 +108,8 @@ __all__ = [
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"ai21": AI21,
|
||||
"aleph_alpha": AlephAlpha,
|
||||
"amazon_api_gateway": AmazonAPIGateway,
|
||||
"amazon_bedrock": Bedrock,
|
||||
"anthropic": Anthropic,
|
||||
"anyscale": Anyscale,
|
||||
"aviary": Aviary,
|
||||
|
||||
langchain/llms/amazon_api_gateway.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import Extra

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens


class ContentHandlerAmazonAPIGateway:
    """Adapter class to prepare the inputs from LangChain to a format
    that the LLM model expects. Also provides a helper function to extract
    the generated text from the model response."""

    @classmethod
    def transform_input(
        cls, prompt: str, model_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"inputs": prompt, "parameters": model_kwargs}

    @classmethod
    def transform_output(cls, response: Any) -> str:
        return response.json()[0]["generated_text"]


class AmazonAPIGateway(LLM):
    """Wrapper around a custom Amazon API Gateway endpoint."""

    api_url: str
    """API Gateway URL"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    content_handler: ContentHandlerAmazonAPIGateway = ContentHandlerAmazonAPIGateway()
    """The content handler class that provides the input and
    output transform functions to handle formats between the LLM
    and the endpoint.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_name": self.api_url},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_api_gateway"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to the Amazon API Gateway model.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}
        payload = self.content_handler.transform_input(prompt, _model_kwargs)

        try:
            response = requests.post(
                self.api_url,
                json=payload,
            )
            text = self.content_handler.transform_output(response)

        except Exception as error:
            raise ValueError(f"Error raised by the service: {error}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
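A hedged sketch of swapping in a custom content handler when the backing model returns a different JSON shape; the endpoint URL and the field names below are assumptions, not part of the diff:

```python
from typing import Any, Dict

from langchain.llms.amazon_api_gateway import (
    AmazonAPIGateway,
    ContentHandlerAmazonAPIGateway,
)


class MyContentHandler(ContentHandlerAmazonAPIGateway):
    @classmethod
    def transform_input(cls, prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        return {"prompt": prompt, **model_kwargs}  # assumed request shape

    @classmethod
    def transform_output(cls, response: Any) -> str:
        return response.json()["completion"]  # assumed response field


llm = AmazonAPIGateway(
    api_url="https://<api_gateway_id>.execute-api.<region>.amazonaws.com/prod/model",
    content_handler=MyContentHandler(),
)
print(llm("Tell me a joke."))
```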
@@ -1,8 +1,10 @@
|
||||
"""Wrapper around Aviary"""
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union, cast
|
||||
|
||||
import requests
|
||||
from pydantic import Extra, Field, root_validator
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
@@ -12,6 +14,68 @@ from langchain.utils import get_from_dict_or_env
|
||||
TIMEOUT = 60
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AviaryBackend:
|
||||
backend_url: str
|
||||
bearer: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.header = {"Authorization": self.bearer}
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "AviaryBackend":
|
||||
aviary_url = os.getenv("AVIARY_URL")
|
||||
assert aviary_url, "AVIARY_URL must be set"
|
||||
|
||||
aviary_token = os.getenv("AVIARY_TOKEN", "")
|
||||
|
||||
bearer = f"Bearer {aviary_token}" if aviary_token else ""
|
||||
aviary_url += "/" if not aviary_url.endswith("/") else ""
|
||||
|
||||
return cls(aviary_url, bearer)
|
||||
|
||||
|
||||
def get_models() -> List[str]:
|
||||
"""List available models"""
|
||||
backend = AviaryBackend.from_env()
|
||||
request_url = backend.backend_url + "-/routes"
|
||||
response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
|
||||
try:
|
||||
result = response.json()
|
||||
except requests.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f"Error decoding JSON from {request_url}. Text response: {response.text}"
|
||||
) from e
|
||||
result = sorted(
|
||||
[k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_completions(
|
||||
model: str,
|
||||
prompt: str,
|
||||
use_prompt_format: bool = True,
|
||||
version: str = "",
|
||||
) -> Dict[str, Union[str, float, int]]:
|
||||
"""Get completions from Aviary models."""
|
||||
|
||||
backend = AviaryBackend.from_env()
|
||||
url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=backend.header,
|
||||
json={"prompt": prompt, "use_prompt_format": use_prompt_format},
|
||||
timeout=TIMEOUT,
|
||||
)
|
||||
try:
|
||||
return response.json()
|
||||
except requests.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f"Error decoding JSON from {url}. Text response: {response.text}"
|
||||
) from e
|
||||
|
||||
|
||||
class Aviary(LLM):
|
||||
"""Allow you to use an Aviary.
|
||||
|
||||
@@ -19,33 +83,30 @@ class Aviary(LLM):
|
||||
find out more about aviary at
|
||||
http://github.com/ray-project/aviary
|
||||
|
||||
Has no dependencies, since it connects to backend
|
||||
directly.
|
||||
|
||||
To get a list of the models supported on an
|
||||
aviary, follow the instructions on the web site to
|
||||
install the aviary CLI and then use:
|
||||
`aviary models`
|
||||
|
||||
You must at least specify the environment
|
||||
variable or parameter AVIARY_URL.
|
||||
|
||||
You may optionally specify the environment variable
|
||||
or parameter AVIARY_TOKEN.
|
||||
AVIARY_URL and AVIARY_TOKEN environment variables must be set.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Aviary
|
||||
light = Aviary(aviary_url='AVIARY_URL',
|
||||
model='amazon/LightGPT')
|
||||
|
||||
result = light.predict('How do you make fried rice?')
|
||||
os.environ["AVIARY_URL"] = "<URL>"
|
||||
os.environ["AVIARY_TOKEN"] = "<TOKEN>"
|
||||
light = Aviary(model='amazon/LightGPT')
|
||||
output = light('How do you make fried rice?')
|
||||
"""
|
||||
|
||||
model: str
|
||||
aviary_url: str
|
||||
aviary_token: str = Field("", exclude=True)
|
||||
model: str = "amazon/LightGPT"
|
||||
aviary_url: Optional[str] = None
|
||||
aviary_token: Optional[str] = None
|
||||
# If True the prompt template for the model will be ignored.
|
||||
use_prompt_format: bool = True
|
||||
# API version to use for Aviary
|
||||
version: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -56,49 +117,35 @@ class Aviary(LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
|
||||
if not aviary_url.endswith("/"):
|
||||
aviary_url += "/"
|
||||
values["aviary_url"] = aviary_url
|
||||
aviary_token = get_from_dict_or_env(
|
||||
values, "aviary_token", "AVIARY_TOKEN", default=""
|
||||
)
|
||||
values["aviary_token"] = aviary_token
|
||||
aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
|
||||
|
||||
# Set env variables for aviary sdk
|
||||
os.environ["AVIARY_URL"] = aviary_url
|
||||
os.environ["AVIARY_TOKEN"] = aviary_token
|
||||
|
||||
aviary_endpoint = aviary_url + "models"
|
||||
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
|
||||
try:
|
||||
response = requests.get(aviary_endpoint, headers=headers)
|
||||
result = response.json()
|
||||
# Confirm model is available
|
||||
if values["model"] not in result:
|
||||
raise ValueError(
|
||||
f"{aviary_url} does not support model {values['model']}."
|
||||
)
|
||||
|
||||
aviary_models = get_models()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(e)
|
||||
|
||||
model = values.get("model")
|
||||
if model and model not in aviary_models:
|
||||
raise ValueError(f"{aviary_url} does not support model {values['model']}.")
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model_name": self.model,
|
||||
"aviary_url": self.aviary_url,
|
||||
"aviary_token": self.aviary_token,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "aviary"
|
||||
|
||||
@property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
if self.aviary_token:
|
||||
return {"Authorization": f"Bearer {self.aviary_token}"}
|
||||
else:
|
||||
return {}
|
||||
return f"aviary-{self.model.replace('/', '-')}"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@@ -119,19 +166,18 @@ class Aviary(LLM):
|
||||
|
||||
response = aviary("Tell me a joke.")
|
||||
"""
|
||||
url = self.aviary_url + "query/" + self.model.replace("/", "--")
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers,
|
||||
json={"prompt": prompt},
|
||||
timeout=TIMEOUT,
|
||||
kwargs = {"use_prompt_format": self.use_prompt_format}
|
||||
if self.version:
|
||||
kwargs["version"] = self.version
|
||||
|
||||
output = get_completions(
|
||||
model=self.model,
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
text = response.json()[self.model]["generated_text"]
|
||||
except requests.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error decoding JSON from {url}. Text response: {response.text}",
|
||||
) from e
|
||||
|
||||
text = cast(str, output["generated_text"])
|
||||
if stop:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
|
||||
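A hedged sketch of the new module-level helpers; it assumes `AVIARY_URL` (and optionally `AVIARY_TOKEN`) are already set and that the backend serves the listed model:

```python
import os

from langchain.llms.aviary import get_completions, get_models

os.environ.setdefault("AVIARY_URL", "http://localhost:8000/")  # placeholder backend URL

print(get_models())  # e.g. ["amazon/LightGPT", ...]
out = get_completions(model="amazon/LightGPT", prompt="How do you make fried rice?")
print(out["generated_text"])
```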
@@ -1,4 +1,5 @@
|
||||
"""Base interface for large language models to expose."""
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
@@ -151,6 +152,39 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
prompt_strings, stop=stop, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
def _generate_helper(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]],
|
||||
run_managers: List[CallbackManagerForLLMRun],
|
||||
new_arg_supported: bool,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
try:
|
||||
output = (
|
||||
self._generate(
|
||||
prompts,
|
||||
stop=stop,
|
||||
# TODO: support multiple run managers
|
||||
run_manager=run_managers[0] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_arg_supported
|
||||
else self._generate(prompts, stop=stop)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
flattened_outputs = output.flatten()
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
manager.on_llm_end(flattened_output)
|
||||
if run_managers:
|
||||
output.run = [
|
||||
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
|
||||
]
|
||||
return output
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -161,8 +195,6 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
# not make sense.
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
@@ -185,60 +217,77 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
"run_manager"
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager = callback_manager.on_llm_start(
|
||||
run_managers = callback_manager.on_llm_start(
|
||||
dumpd(self), prompts, invocation_params=params, options=options
|
||||
)
|
||||
try:
|
||||
output = (
|
||||
self._generate(
|
||||
prompts, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else self._generate(prompts, stop=stop, **kwargs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
run_manager.on_llm_end(output)
|
||||
if run_manager:
|
||||
output.run = RunInfo(run_id=run_manager.run_id)
|
||||
output = self._generate_helper(
|
||||
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
return output
|
||||
if len(missing_prompts) > 0:
|
||||
run_manager = callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
missing_prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
run_managers = callback_manager.on_llm_start(
|
||||
dumpd(self), missing_prompts, invocation_params=params, options=options
|
||||
)
|
||||
new_results = self._generate_helper(
|
||||
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
try:
|
||||
new_results = (
|
||||
self._generate(
|
||||
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else self._generate(missing_prompts, stop=stop, **kwargs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
run_manager.on_llm_end(new_results)
|
||||
llm_output = update_cache(
|
||||
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
|
||||
)
|
||||
run_info = None
|
||||
if run_manager:
|
||||
run_info = RunInfo(run_id=run_manager.run_id)
|
||||
run_info = (
|
||||
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
|
||||
if run_managers
|
||||
else None
|
||||
)
|
||||
else:
|
||||
llm_output = {}
|
||||
run_info = None
|
||||
generations = [existing_prompts[i] for i in range(len(prompts))]
|
||||
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
|
||||
|
||||
async def _agenerate_helper(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]],
|
||||
run_managers: List[AsyncCallbackManagerForLLMRun],
|
||||
new_arg_supported: bool,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
try:
|
||||
output = (
|
||||
await self._agenerate(
|
||||
prompts,
|
||||
stop=stop,
|
||||
run_manager=run_managers[0] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_arg_supported
|
||||
else await self._agenerate(prompts, stop=stop)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await asyncio.gather(
|
||||
*[run_manager.on_llm_error(e) for run_manager in run_managers]
|
||||
)
|
||||
raise e
|
||||
flattened_outputs = output.flatten()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
for run_manager, flattened_output in zip(
|
||||
run_managers, flattened_outputs
|
||||
)
|
||||
]
|
||||
)
|
||||
if run_managers:
|
||||
output.run = [
|
||||
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
|
||||
]
|
||||
return output
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -266,54 +315,32 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
"run_manager"
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager = await callback_manager.on_llm_start(
|
||||
run_managers = await callback_manager.on_llm_start(
|
||||
dumpd(self), prompts, invocation_params=params, options=options
|
||||
)
|
||||
try:
|
||||
output = (
|
||||
await self._agenerate(
|
||||
prompts, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else await self._agenerate(prompts, stop=stop, **kwargs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
await run_manager.on_llm_end(output, verbose=self.verbose)
|
||||
if run_manager:
|
||||
output.run = RunInfo(run_id=run_manager.run_id)
|
||||
output = await self._agenerate_helper(
|
||||
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
return output
|
||||
if len(missing_prompts) > 0:
|
||||
run_manager = await callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
missing_prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
run_managers = await callback_manager.on_llm_start(
|
||||
dumpd(self), missing_prompts, invocation_params=params, options=options
|
||||
)
|
||||
new_results = await self._agenerate_helper(
|
||||
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
try:
|
||||
new_results = (
|
||||
await self._agenerate(
|
||||
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else await self._agenerate(missing_prompts, stop=stop, **kwargs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
await run_manager.on_llm_end(new_results)
|
||||
llm_output = update_cache(
|
||||
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
|
||||
)
|
||||
run_info = None
|
||||
if run_manager:
|
||||
run_info = RunInfo(run_id=run_manager.run_id)
|
||||
run_info = (
|
||||
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
|
||||
if run_managers
|
||||
else None
|
||||
)
|
||||
else:
|
||||
llm_output = {}
|
||||
run_info = None
|
||||
|
||||
@@ -171,6 +171,16 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
"""Initialize the OpenAI object."""
|
||||
@@ -491,7 +501,13 @@ class BaseOpenAI(BaseLLM):
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.model_name)
|
||||
model_name = self.tiktoken_model_name or self.model_name
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
enc = tiktoken.get_encoding(model)
|
||||
|
||||
return enc.encode(
|
||||
text,
|
||||
|
||||
@@ -5,6 +5,8 @@ from langchain.load.serializable import Serializable, to_json_not_implemented
|
||||
|
||||
|
||||
def default(obj: Any) -> Any:
|
||||
"""Return a default value for a Serializable object or
|
||||
a SerializedNotImplemented object."""
|
||||
if isinstance(obj, Serializable):
|
||||
return obj.to_json()
|
||||
else:
|
||||
@@ -12,6 +14,7 @@ def default(obj: Any) -> Any:
|
||||
|
||||
|
||||
def dumps(obj: Any, *, pretty: bool = False) -> str:
|
||||
"""Return a json string representation of an object."""
|
||||
if pretty:
|
||||
return json.dumps(obj, default=default, indent=2)
|
||||
else:
|
||||
@@ -19,4 +22,5 @@ def dumps(obj: Any, *, pretty: bool = False) -> str:
|
||||
|
||||
|
||||
def dumpd(obj: Any) -> Dict[str, Any]:
|
||||
"""Return a json dict representation of an object."""
|
||||
return json.loads(dumps(obj))
|
||||
|
||||
@@ -5,24 +5,34 @@ from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
"""Base class for serialized objects."""
|
||||
|
||||
lc: int
|
||||
id: List[str]
|
||||
|
||||
|
||||
class SerializedConstructor(BaseSerialized):
|
||||
"""Serialized constructor."""
|
||||
|
||||
type: Literal["constructor"]
|
||||
kwargs: Dict[str, Any]
|
||||
|
||||
|
||||
class SerializedSecret(BaseSerialized):
|
||||
"""Serialized secret."""
|
||||
|
||||
type: Literal["secret"]
|
||||
|
||||
|
||||
class SerializedNotImplemented(BaseSerialized):
|
||||
"""Serialized not implemented."""
|
||||
|
||||
type: Literal["not_implemented"]
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
"""Serializable base class."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""
|
||||
@@ -130,6 +140,14 @@ def _replace_secrets(
|
||||
|
||||
|
||||
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
"""Serialize a "not implemented" object.
|
||||
|
||||
Args:
|
||||
obj: object to serialize
|
||||
|
||||
Returns:
|
||||
SerializedNotImplemented
|
||||
"""
|
||||
_id: List[str] = []
|
||||
try:
|
||||
if hasattr(obj, "__name__"):
|
||||
|
||||
@@ -15,6 +15,8 @@ DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_his
|
||||
|
||||
|
||||
class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Postgres database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -13,6 +13,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Redis database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -21,6 +21,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
"""
|
||||
Create a message model for a given table name.
|
||||
Args:
|
||||
table_name: The name of the table to use.
|
||||
DynamicBase: The base class to use for the model.
|
||||
|
||||
Returns:
|
||||
The model class.
|
||||
|
||||
"""
|
||||
|
||||
# Model declared inside a function to have a dynamic table name
|
||||
class Message(DynamicBase):
|
||||
__tablename__ = table_name
|
||||
@@ -32,6 +43,8 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
|
||||
|
||||
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in an SQL database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -86,3 +86,7 @@ class MotorheadMemory(BaseChatMemory):
|
||||
headers=self.__get_headers(),
|
||||
)
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
def delete_session(self) -> None:
|
||||
"""Delete a session"""
|
||||
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
|
||||
|
||||
@@ -4,6 +4,16 @@ from langchain.schema import get_buffer_string # noqa: 401
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
|
||||
"""
|
||||
Get the prompt input key.
|
||||
|
||||
Args:
|
||||
inputs: Dict[str, Any]
|
||||
memory_variables: List[str]
|
||||
|
||||
Returns:
|
||||
A prompt input key.
|
||||
"""
|
||||
# "stop" is a special key that can be passed as input but is not used to
|
||||
# format the prompt.
|
||||
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
|
||||
|
||||
@@ -8,6 +8,15 @@ from langchain.schema import OutputParserException
|
||||
|
||||
|
||||
def parse_json_markdown(json_string: str) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string.
|
||||
|
||||
Args:
|
||||
json_string: The Markdown string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
# Try to find JSON string within triple backticks
|
||||
match = re.search(r"```(json)?(.*?)```", json_string, re.DOTALL)
|
||||
|
||||
@@ -28,6 +37,17 @@ def parse_json_markdown(json_string: str) -> dict:
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string and check that it
|
||||
contains the expected keys.
|
||||
|
||||
Args:
|
||||
text: The Markdown string.
|
||||
expected_keys: The expected keys in the JSON string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
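A usage sketch for the helpers documented above; the fenced string is a made-up model reply and the import path is assumed:

```python
from langchain.output_parsers.json import parse_and_check_json_markdown

reply = '```json\n{"action": "search", "action_input": "weather in Paris"}\n```'
parsed = parse_and_check_json_markdown(reply, expected_keys=["action", "action_input"])
print(parsed["action"])  # -> "search"
```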
@@ -28,6 +28,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
||||
|
||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
"""
|
||||
Validate that the input variables are valid for the template.
|
||||
Raise an exception if missing or extra variables are found.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
input_variables: The input variables.
|
||||
"""
|
||||
input_variables_set = set(input_variables)
|
||||
valid_variables = _get_jinja2_variables_from_template(template)
|
||||
missing_variables = valid_variables - input_variables_set
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from langchain.retrievers.arxiv import ArxivRetriever
|
||||
from langchain.retrievers.aws_kendra_index_retriever import AwsKendraIndexRetriever
|
||||
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.databerry import DataberryRetriever
|
||||
from langchain.retrievers.docarray import DocArrayRetriever
|
||||
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
|
||||
from langchain.retrievers.kendra import AmazonKendraRetriever
|
||||
from langchain.retrievers.knn import KNNRetriever
|
||||
from langchain.retrievers.llama_index import (
|
||||
LlamaIndexGraphRetriever,
|
||||
@@ -30,8 +30,8 @@ from langchain.retrievers.zep import ZepRetriever
|
||||
from langchain.retrievers.zilliz import ZillizRetriever
|
||||
|
||||
__all__ = [
|
||||
"AmazonKendraRetriever",
|
||||
"ArxivRetriever",
|
||||
"AwsKendraIndexRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
|
||||
@@ -1,95 +0,0 @@
"""Retriever wrapper for AWS Kendra."""
import re
from typing import Any, Dict, List

from langchain.schema import BaseRetriever, Document


class AwsKendraIndexRetriever(BaseRetriever):
    """Wrapper around AWS Kendra."""

    kendraindex: str
    """Kendra index id"""
    k: int
    """Number of documents to query for."""
    languagecode: str
    """Languagecode used for querying."""
    kclient: Any
    """ boto3 client for Kendra. """

    def __init__(
        self, kclient: Any, kendraindex: str, k: int = 3, languagecode: str = "en"
    ):
        self.kendraindex = kendraindex
        self.k = k
        self.languagecode = languagecode
        self.kclient = kclient

    def _clean_result(self, res_text: str) -> str:
        return re.sub("\s+", " ", res_text).replace("...", "")

    def _get_top_n_results(self, resp: Dict, count: int) -> Document:
        r = resp["ResultItems"][count]
        doc_title = r["DocumentTitle"]["Text"]
        doc_uri = r["DocumentURI"]
        r_type = r["Type"]

        if (
            r["AdditionalAttributes"]
            and r["AdditionalAttributes"][0]["Key"] == "AnswerText"
        ):
            res_text = r["AdditionalAttributes"][0]["Value"]["TextWithHighlightsValue"][
                "Text"
            ]
        else:
            res_text = r["DocumentExcerpt"]["Text"]

        doc_excerpt = self._clean_result(res_text)
        combined_text = f"""Document Title: {doc_title}
Document Excerpt: {doc_excerpt}
"""

        return Document(
            page_content=combined_text,
            metadata={
                "source": doc_uri,
                "title": doc_title,
                "excerpt": doc_excerpt,
                "type": r_type,
            },
        )

    def _kendra_query(self, kquery: str) -> List[Document]:
        response = self.kclient.query(
            IndexId=self.kendraindex,
            QueryText=kquery.strip(),
            AttributeFilter={
                "AndAllFilters": [
                    {
                        "EqualsTo": {
                            "Key": "_language_code",
                            "Value": {
                                "StringValue": self.languagecode,
                            },
                        }
                    }
                ]
            },
        )

        if len(response["ResultItems"]) > self.k:
            r_count = self.k
        else:
            r_count = len(response["ResultItems"])

        return [self._get_top_n_results(response, i) for i in range(0, r_count)]

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Run search on Kendra index and get top k documents

        docs = get_relevant_documents('This is my query')
        """
        return self._kendra_query(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("AwsKendraIndexRetriever does not support async")
@@ -7,6 +7,8 @@ from langchain.schema import BaseRetriever, Document

class DataberryRetriever(BaseRetriever):
    """Retriever that uses the Databerry API."""

    datastore_url: str
    top_k: Optional[int]
    api_key: Optional[str]
@@ -10,6 +10,8 @@ from langchain.vectorstores.utils import maximal_marginal_relevance

class SearchType(str, Enum):
    """Enumerator of the types of search to perform."""

    similarity = "similarity"
    mmr = "mmr"
272 langchain/retrievers/kendra.py Normal file
@@ -0,0 +1,272 @@
import re
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Extra

from langchain.docstore.document import Document
from langchain.schema import BaseRetriever


def clean_excerpt(excerpt: str) -> str:
    if not excerpt:
        return excerpt
    res = re.sub("\s+", " ", excerpt).replace("...", "")
    return res


def combined_text(title: str, excerpt: str) -> str:
    if not title or not excerpt:
        return ""
    return f"Document Title: {title} \nDocument Excerpt: \n{excerpt}\n"


class Highlight(BaseModel, extra=Extra.allow):
    BeginOffset: int
    EndOffset: int
    TopAnswer: Optional[bool]
    Type: Optional[str]


class TextWithHighLights(BaseModel, extra=Extra.allow):
    Text: str
    Highlights: Optional[Any]


class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
    Key: str
    ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
    Value: Optional[TextWithHighLights]

    def get_value_text(self) -> str:
        if not self.Value:
            return ""
        else:
            return self.Value.Text


class QueryResultItem(BaseModel, extra=Extra.allow):
    DocumentId: str
    DocumentTitle: TextWithHighLights
    DocumentURI: Optional[str]
    FeedbackToken: Optional[str]
    Format: Optional[str]
    Id: Optional[str]
    Type: Optional[str]
    AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
    DocumentExcerpt: Optional[TextWithHighLights]

    def get_attribute_value(self) -> str:
        if not self.AdditionalAttributes:
            return ""
        if not self.AdditionalAttributes[0]:
            return ""
        else:
            return self.AdditionalAttributes[0].get_value_text()

    def get_excerpt(self) -> str:
        if (
            self.AdditionalAttributes
            and self.AdditionalAttributes[0].Key == "AnswerText"
        ):
            excerpt = self.get_attribute_value()
        elif self.DocumentExcerpt:
            excerpt = self.DocumentExcerpt.Text
        else:
            excerpt = ""

        return clean_excerpt(excerpt)

    def to_doc(self) -> Document:
        title = self.DocumentTitle.Text
        source = self.DocumentURI
        excerpt = self.get_excerpt()
        type = self.Type
        page_content = combined_text(title, excerpt)
        metadata = {"source": source, "title": title, "excerpt": excerpt, "type": type}
        return Document(page_content=page_content, metadata=metadata)


class QueryResult(BaseModel, extra=Extra.allow):
    ResultItems: List[QueryResultItem]

    def get_top_k_docs(self, top_n: int) -> List[Document]:
        items_len = len(self.ResultItems)
        count = items_len if items_len < top_n else top_n
        docs = [self.ResultItems[i].to_doc() for i in range(0, count)]

        return docs


class DocumentAttributeValue(BaseModel, extra=Extra.allow):
    DateValue: Optional[str]
    LongValue: Optional[int]
    StringListValue: Optional[List[str]]
    StringValue: Optional[str]


class DocumentAttribute(BaseModel, extra=Extra.allow):
    Key: str
    Value: DocumentAttributeValue


class RetrieveResultItem(BaseModel, extra=Extra.allow):
    Content: Optional[str]
    DocumentAttributes: Optional[List[DocumentAttribute]] = []
    DocumentId: Optional[str]
    DocumentTitle: Optional[str]
    DocumentURI: Optional[str]
    Id: Optional[str]

    def get_excerpt(self) -> str:
        if not self.Content:
            return ""
        return clean_excerpt(self.Content)

    def to_doc(self) -> Document:
        title = self.DocumentTitle if self.DocumentTitle else ""
        source = self.DocumentURI
        excerpt = self.get_excerpt()
        page_content = combined_text(title, excerpt)
        metadata = {"source": source, "title": title, "excerpt": excerpt}
        return Document(page_content=page_content, metadata=metadata)


class RetrieveResult(BaseModel, extra=Extra.allow):
    QueryId: str
    ResultItems: List[RetrieveResultItem]

    def get_top_k_docs(self, top_n: int) -> List[Document]:
        items_len = len(self.ResultItems)
        count = items_len if items_len < top_n else top_n
        docs = [self.ResultItems[i].to_doc() for i in range(0, count)]

        return docs


class AmazonKendraRetriever(BaseRetriever):
    """Retriever class to query documents from Amazon Kendra Index.

    Args:
        index_id: Kendra index id

        region_name: The aws region e.g., `us-west-2`.
            Falls back to AWS_DEFAULT_REGION env variable
            or region specified in ~/.aws/config.

        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.

        top_k: Number of results to return

        attribute_filter: Additional filtering of results based on metadata
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        client: boto3 client for Kendra

    Example:
        .. code-block:: python

            retriever = AmazonKendraRetriever(
                index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
            )

    """

    def __init__(
        self,
        index_id: str,
        region_name: Optional[str] = None,
        credentials_profile_name: Optional[str] = None,
        top_k: int = 3,
        attribute_filter: Optional[Dict] = None,
        client: Optional[Any] = None,
    ):
        self.index_id = index_id
        self.top_k = top_k
        self.attribute_filter = attribute_filter

        if client is not None:
            self.client = client
            return

        try:
            import boto3

            if credentials_profile_name is not None:
                session = boto3.Session(profile_name=credentials_profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            client_params = {}
            if region_name is not None:
                client_params["region_name"] = region_name

            self.client = session.client("kendra", **client_params)
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _kendra_query(
        self,
        query: str,
        top_k: int,
        attribute_filter: Optional[Dict] = None,
    ) -> List[Document]:
        if attribute_filter is not None:
            response = self.client.retrieve(
                IndexId=self.index_id,
                QueryText=query.strip(),
                PageSize=top_k,
                AttributeFilter=attribute_filter,
            )
        else:
            response = self.client.retrieve(
                IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
            )
        r_result = RetrieveResult.parse_obj(response)
        result_len = len(r_result.ResultItems)

        if result_len == 0:
            # retrieve API returned 0 results, call query API
            if attribute_filter is not None:
                response = self.client.query(
                    IndexId=self.index_id,
                    QueryText=query.strip(),
                    PageSize=top_k,
                    AttributeFilter=attribute_filter,
                )
            else:
                response = self.client.query(
                    IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
                )
            q_result = QueryResult.parse_obj(response)
            docs = q_result.get_top_k_docs(top_k)
        else:
            docs = r_result.get_top_k_docs(top_k)
        return docs

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Run search on Kendra index and get top k documents

        Example:
        .. code-block:: python

            docs = retriever.get_relevant_documents('This is my query')

        """
        docs = self._kendra_query(query, self.top_k, self.attribute_filter)
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("Async version is not implemented for Kendra yet.")
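For context, an end-to-end sketch of the new retriever. The index id, region, and query are placeholders; credentials resolve through the normal boto3 chain (profile, environment variables, or instance metadata):

```python
from langchain.retrievers import AmazonKendraRetriever

retriever = AmazonKendraRetriever(
    index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03",  # placeholder index id
    region_name="us-west-2",
    top_k=3,
)

# Per _kendra_query above, the Retrieve API is called first and the Query API
# is used as a fallback when Retrieve returns no results.
docs = retriever.get_relevant_documents("How do I configure an index?")
for doc in docs:
    print(doc.metadata["title"], doc.metadata["source"])
```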
@@ -15,11 +15,23 @@ from langchain.schema import BaseRetriever, Document

def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """
    Create an index of embeddings for a list of contexts.

    Args:
        contexts: List of contexts to embed.
        embeddings: Embeddings model to use.

    Returns:
        Index of embeddings.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return np.array(list(executor.map(embeddings.embed_query, contexts)))


class KNNRetriever(BaseRetriever, BaseModel):
    """KNN Retriever."""

    embeddings: Embeddings
    index: Any
    texts: List[str]
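`create_index` is what backs the KNN retriever's embedding matrix. A short usage sketch, assuming `KNNRetriever` keeps its existing `from_texts` constructor and that OpenAI credentials are configured:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import KNNRetriever

texts = ["foo", "bar", "world", "hello", "foo bar"]

# from_texts embeds every text (using create_index under the hood) and keeps
# the raw texts alongside the resulting numpy index.
retriever = KNNRetriever.from_texts(texts, OpenAIEmbeddings())
docs = retriever.get_relevant_documents("foo")
```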
@@ -4,6 +4,8 @@ from langchain.schema import BaseRetriever, Document

class MetalRetriever(BaseRetriever):
    """Retriever that uses the Metal API."""

    def __init__(self, client: Any, params: Optional[dict] = None):
        from metal_sdk.metal import Metal
@@ -10,6 +10,8 @@ from langchain.vectorstores.milvus import Milvus

class MilvusRetriever(BaseRetriever):
    """Retriever that uses the Milvus API."""

    def __init__(
        self,
        embedding_function: Embeddings,
@@ -45,6 +47,15 @@ class MilvusRetriever(BaseRetriever):

def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
    """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.

    Args:
        *args:
        **kwargs:

    Returns:
        MilvusRetriever
    """
    warnings.warn(
        "MilvusRetreiver will be deprecated in the future. "
        "Please use MilvusRetriever ('i' before 'e') instead.",
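The hunk is cut off above. As a general illustration of the deprecation-alias pattern it uses (hypothetical names, not the library's code), the misspelled callable warns and forwards to the correctly spelled class:

```python
import warnings
from typing import Any


class NewRetriever:
    """Stand-in for the correctly spelled class."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs


def OldRetreiver(*args: Any, **kwargs: Any) -> NewRetriever:
    """Deprecated alias: emit a warning, then delegate to the new class."""
    warnings.warn(
        "OldRetreiver will be deprecated in the future. "
        "Please use NewRetriever instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return NewRetriever(*args, **kwargs)
```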