Compare commits

..

48 Commits

Author SHA1 Message Date
Harrison Chase
d1bcc58beb bump version to 216 (#6770) 2023-06-26 09:46:19 -07:00
Zander Chase
6d30acffcb Fix breaking tags (#6765)
Fix tags change that broke old way of initializing agent

Closes #6756
2023-06-26 09:28:11 -07:00
James Croft
ba622764cb Improve performance when retrieving Notion DB pages (#6710) 2023-06-26 05:46:09 -07:00
Richy Wang
ec8247ec59 Fixed bug in AnalyticDB Vector Store caused by upgrade SQLAlchemy version (#6736) 2023-06-26 05:35:25 -07:00
Santiago Delgado
d84a3bcf7a Office365 Tool (#6306)
#### Background
With the development of [structured
tools](https://blog.langchain.dev/structured-tools/), the LangChain team
expanded the platform's functionality to meet the needs of new
applications. The GMail tool, empowered by structured tools, now
supports multiple arguments and powerful search capabilities,
demonstrating LangChain's ability to interact with dynamic data sources
like email servers.

#### Challenge
The current GMail tool only supports GMail, while users often utilize
other email services like Outlook in Office365. Additionally, the
proposed calendar tool in PR
https://github.com/hwchase17/langchain/pull/652 only works with Google
Calendar, not Outlook.

#### Changes
This PR implements an Office365 integration for LangChain, enabling
seamless email and calendar functionality with a single authentication
process.

#### Future Work
With the core Office365 integration complete, future work could include
integrating other Office365 tools such as Tasks and Address Book.

#### Who can review?
@hwchase17 or @vowelparrot can review this PR

#### Appendix
@janscas, I utilized your [O365](https://github.com/O365/python-o365)
library extensively. Given the rising popularity of LangChain and
similar AI frameworks, the convergence of libraries like O365 and tools
like this one is likely. So, I wanted to keep you updated on our
progress.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-06-26 02:59:09 -07:00
Xiaochao Dong
a15afc102c Relax the action input check for actions that require no input (#6357)
When the tool requires no input, the LLM often gives something like
this:
```json
{
    "action": "just_do_it"
}
```
I have attempted to enhance the prompt, but it doesn't appear to be
functioning effectively. Therefore, I believe we should consider easing
the check a little bit.
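
For illustration, a minimal sketch of what easing the check could look like (a hypothetical helper, not the actual patch): treat a missing `action_input` as an empty string rather than failing.

```python
import json

def parse_action(text: str):
    """Hypothetical relaxed parser: tolerate a missing "action_input"."""
    data = json.loads(text)
    action = data["action"]                      # still required
    action_input = data.get("action_input", "")  # relaxed: default to "" when absent
    return action, action_input

print(parse_action('{"action": "just_do_it"}'))  # ('just_do_it', '')
```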



Signed-off-by: Xiaochao Dong (@damnever) <the.xcdong@gmail.com>
2023-06-26 02:30:17 -07:00
Ethan Bowen
cc33bde74f Confluence added (#6432)
Adds Confluence support to the Jira tool. With this PR, you can create a page
in Confluence. If accepted, functionality will be extended to Bitbucket and
additional Confluence features.



---------

Co-authored-by: Ethan Bowen <ethan.bowen@slalom.com>
2023-06-26 02:28:04 -07:00
Surya Nudurupati
2aeb8e7dbc Improved Documentation: Eliminating Redundancy in the Introduction.mdx (#6360)
When the documentation was originally written, the phrase "using the" was
redundantly repeated.
2023-06-26 02:27:36 -07:00
rajib
0f6ef048d2 The openai_info.py does not have gpt-35-turbo which is the underlying Azure Open AI model name (#6321)
Since this model name is not in the MODEL_COST_PER_1K_TOKENS mapping,
get_openai_callback() does not report token cost when using the GPT-3.5 model
on Azure OpenAI. This PR fixes that issue.
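
As an illustration of the idea (not the actual patch; the prices below are placeholders), the fix amounts to adding the Azure-style model name to the per-1K-token cost mapping so the callback can look it up:

```python
# Placeholder prices for illustration only.
MODEL_COST_PER_1K_TOKENS = {
    "gpt-3.5-turbo": 0.002,
    "gpt-35-turbo": 0.002,  # Azure OpenAI's name for the same model
}

def completion_cost(model_name: str, tokens: int) -> float:
    # Without the "gpt-35-turbo" entry, Azure usage falls through to 0.0.
    return MODEL_COST_PER_1K_TOKENS.get(model_name, 0.0) * tokens / 1000

print(completion_cost("gpt-35-turbo", 1500))  # 0.003
```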


#### Who can review?
 @hwchase17
 @agola11

Co-authored-by: rajib76 <rajib76@yahoo.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-06-26 02:16:39 -07:00
ArchimedesFTW
fe941cb54a Change tags(str) to tags(dict) in mlflow_callback.py docs (#6473)
Fixes #6472

#### Who can review?

@agola11
2023-06-26 02:12:23 -07:00
0xcrusher
9187d2f3a9 Fixed caching bug for Multiple Caching types by correctly checking types (#6746)
- Fixed an issue where some caching types checked the wrong types, which
prevented caching from working


Maintainer responsibilities:
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
2023-06-26 01:14:32 -07:00
Harrison Chase
e9877ea8b1 Tiktoken override (#6697) 2023-06-26 00:49:32 -07:00
Gabriel Altay
f9771700e4 prevent DuckDuckGoSearchAPIWrapper from consuming top result (#6727)
remove the `next` call that checks for None on the results generator
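
A small standalone illustration of why the removed call mattered: calling `next` on a generator consumes its first item, so the top search result was silently dropped.

```python
# Not the wrapper's actual code, just the underlying Python behavior.
results = (r for r in ["top result", "second", "third"])

next(results)          # intended only as a sanity check, but it eats "top result"
print(list(results))   # ['second', 'third'] -- the top result is gone
```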
2023-06-25 19:54:15 -07:00
Pau Ramon Revilla
87802c86d9 Added a MHTML document loader (#6311)
MHTML is a very interesting format since it's used both for emails and for
archived webpages. Some scraping projects want to store pages on disk to
process them later, and MHTML is perfect for that use case.

This is heavily inspired by the BeautifulSoup HTML loader, but it extracts
the HTML part from the MHTML file.

---------

Co-authored-by: rlm <pexpresss31@gmail.com>
2023-06-25 13:12:08 -07:00
Janos Tolgyesi
05eec99269 beautifulsoup get_text kwargs in WebBaseLoader (#6591)
# beautifulsoup get_text kwargs in WebBaseLoader

- Description: this PR introduces an optional `bs_get_text_kwargs`
parameter to `WebBaseLoader` constructor. It can be used to pass kwargs
to the downstream BeautifulSoup.get_text call. The most common usage
might be to pass a custom text separator, as seen also in
`BSHTMLLoader`.
  - Tag maintainer: @rlancemartin, @eyurtsev
  - Twitter handle: jtolgyesi
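
A possible usage sketch for the new parameter (the URL is only an example):

```python
from langchain.document_loaders import WebBaseLoader

# Entries in bs_get_text_kwargs are forwarded to BeautifulSoup.get_text().
loader = WebBaseLoader(
    "https://python.langchain.com/",
    bs_get_text_kwargs={"separator": " ", "strip": True},
)
docs = loader.load()
```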
2023-06-25 12:42:27 -07:00
Matt Robinson
be68f6f8ce feat: Add UnstructuredRSTLoader (#6594)
### Summary

Adds an `UnstructuredRSTLoader` for loading
[reStructuredText](https://en.wikipedia.org/wiki/ReStructuredText) files.

### Testing

```python
from langchain.document_loaders import UnstructuredRSTLoader

loader = UnstructuredRSTLoader(
    file_path="example_data/README.rst", mode="elements"
)
docs = loader.load()
print(docs[0])
```

### Reviewers

- @hwchase17 
- @rlancemartin 
- @eyurtsev
2023-06-25 12:41:57 -07:00
Chip Davis
b32cc01c9f feat: added tqdm progress bar to UnstructuredURLLoader (#6600)
- Description: Adds a simple progress bar with tqdm when using
UnstructuredURLLoader. Exposes a new parameter `show_progress_bar` (see the
usage sketch below). Very simple PR.
- Issue: N/A
- Dependencies: N/A
- Tag maintainer: @rlancemartin @eyurtsev
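
A possible usage sketch (the URLs are only examples; `tqdm` must be installed for the bar to render):

```python
from langchain.document_loaders import UnstructuredURLLoader

loader = UnstructuredURLLoader(
    urls=["https://www.langchain.com/", "https://blog.langchain.dev/"],
    show_progress_bar=True,  # the new parameter exposed by this PR
)
docs = loader.load()
```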

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-06-25 12:41:25 -07:00
Augustine Theodore
afc292e58d Fix WhatsAppChatLoader : Enable parsing additional formats (#6663)
- Description: Updated the regex to support a new format that was observed
when a WhatsApp chat was exported.
  - Issue: #6654
  - Dependencies: No new dependencies
  - Tag maintainer: @rlancemartin, @eyurtsev
2023-06-25 12:08:43 -07:00
Sumanth Donthula
3e30a5d967 updated sql_database.py for returning sorted table names. (#6692)
Added code to return the table info in sorted order in the methods
`get_usable_table_names` and `get_table_info`.

Linked to Issue: #6640
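
A standalone sketch of the idea (not the actual `SQLDatabase` code): sort the names before returning them so callers always see a deterministic order.

```python
def get_usable_table_names(all_tables, ignore_tables=()):
    # Return table names in sorted order, skipping ignored tables.
    return sorted(name for name in all_tables if name not in set(ignore_tables))

print(get_usable_table_names(["users", "orders", "accounts"]))
# ['accounts', 'orders', 'users']
```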
2023-06-25 12:04:24 -07:00
刘 方瑞
9d1b3bab76 Fix Typo in LangChain MyScale Integration Doc (#6705)

- Description: Fix Typo in LangChain MyScale Integration Doc

@hwchase17
2023-06-25 11:54:00 -07:00
sudolong
408c8d0178 fix chroma _similarity_search_with_relevance_scores missing kwargs … (#6708)
Issue: https://github.com/hwchase17/langchain/issues/6707
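
Conceptually (a sketch under assumed method names, not the exact patch), the fix is to forward `**kwargs` instead of dropping them, so options such as `filter` reach the underlying search:

```python
# Sketch only -- not the actual Chroma vector store class.
class VectorStoreSketch:
    def similarity_search_with_score(self, query, k=4, **kwargs):
        print(f"searching {query!r} with k={k}, extra kwargs={kwargs}")
        return []

    def _similarity_search_with_relevance_scores(self, query, k=4, **kwargs):
        # The fix: pass **kwargs (e.g. `filter`) through instead of dropping them.
        return self.similarity_search_with_score(query, k, **kwargs)

VectorStoreSketch()._similarity_search_with_relevance_scores("cats", filter={"source": "a.txt"})
```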
2023-06-25 11:53:42 -07:00
Zander Chase
d89e10d361 Fix Multi Functions Agent Tracing (#6702)
Confirmed it works now:
https://dev.langchain.plus/public/0dc32ce0-55af-432e-b09e-5a1a220842f5/r
2023-06-25 10:39:04 -07:00
Harrison Chase
1742db0c30 bump version to 215 (#6719) 2023-06-25 08:52:51 -07:00
Ankush Gola
e1b801be36 split up batch llm calls into separate runs (#5804) 2023-06-24 21:03:31 -07:00
Davis Chase
1da99ce013 bump v214 (#6694) 2023-06-24 14:23:11 -07:00
Lance Martin
dd36adc0f4 Make bs4 a local import in recursive_url_loader.py (#6693)
Resolve https://github.com/hwchase17/langchain/issues/6679
2023-06-24 13:54:10 -07:00
Harrison Chase
ef4c7b54ef bump to version 213 (#6688) 2023-06-24 11:56:37 -07:00
UmerHA
068142fce2 Add caching to BaseChatModel (issue #1644) (#5089)
#  Add caching to BaseChatModel
Fixes #1644

(Sidenote: While testing, I noticed we have multiple implementations of
Fake LLMs, used for testing. I consolidated them.)
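
A usage sketch of the new behavior, assuming the same global `llm_cache` mechanism that regular LLMs already use (requires OPENAI_API_KEY to actually run):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = InMemoryCache()

chat = ChatOpenAI()
chat.predict("Tell me a joke")  # first call hits the API and fills the cache
chat.predict("Tell me a joke")  # identical call is served from the cache
```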

## Who can review?
Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
Models
- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord:
RicChilligerDude#7589

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-06-24 11:45:09 -07:00
Harrison Chase
c289cc891a Harrison/optional ids opensearch (#6684)
Co-authored-by: taekimsmar <66041442+taekimsmar@users.noreply.github.com>
2023-06-24 09:19:57 -07:00
Hrag Balian
2518e6c95b Session deletion method in motorhead memory (#6609)
The Motorhead Memory module didn't support deleting a session. Added a
method to enable deletion.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-06-23 21:27:42 -07:00
Baichuan Sun
9fbe346860 Amazon API Gateway hosted LLM (#6673)
This PR adds a new LLM class for the Amazon API Gateway hosted LLM. The
PR also includes example notebooks for using the LLM class in an Agent
chain.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
2023-06-23 21:27:25 -07:00
Davis Chase
fa1bb873e2 Fix openapi parameter parsing (#6676)
Ensure parameters are json serializable, related to #6671
2023-06-23 21:19:12 -07:00
Akash
b7e1c54947 Just corrected a small inconsistency on a doc page (#6603)
### Just corrected a small inconsistency on a doc page (not exactly a
typo, per se)
- Description: There was an inconsistency due to the use of single quotes in
one place on the [Sequential
Chains](https://python.langchain.com/docs/modules/chains/foundational/sequential_chains)
page of the docs,
  - Issue: NA,
  - Dependencies: NA,
  - Tag maintainer: @dev2049,
  - Twitter handle: kambleakash0
2023-06-23 16:09:29 -07:00
Davis Chase
2da1aab50b Wiki loader lint (#6670) 2023-06-23 16:05:42 -07:00
Leonid Ganeline
1c81883d42 added docstrings where they missed (#6626)
This PR targets the `API Reference` documentation.
- Several classes and functions were missing `docstrings`. These docstrings
were added.
- In several places this

```
except ImportError:
        raise ValueError(
```

was replaced with

```
except ImportError:
        raise ImportError(
```
2023-06-23 15:49:44 -07:00
Shashank
3364e5818b Changed generate_prompt.py (#6644)
Modified the regex to fix `ValueError: Could not parse output`.
2023-06-23 15:48:33 -07:00
Davis Chase
f1e1ac2a01 chroma nb close img tag (#6669) 2023-06-23 15:41:54 -07:00
eLafo
db8b13df4c adds doc_content_chars_max argument to WikipediaLoader (#6645)
# Description
It adds a new initialization param in `WikipediaLoader` so we can
override the `doc_content_chars_max` param used in `WikipediaAPIWrapper`
under the hood, e.g:

```python
from langchain.document_loaders import WikipediaLoader

# doc_content_chars_max is the new init param
loader = WikipediaLoader(query="python", doc_content_chars_max=90000)
```

## Decisions
`doc_content_chars_max` will default to 4000, because that's the current
value.
I have added code comments.

# Issue
#6639

# Dependencies
None


# Twitter handle
[@elafo](https://twitter.com/elafo)
2023-06-23 15:22:09 -07:00
Davis Chase
5e5b30b74f openapi -> openai nit (#6667) 2023-06-23 15:09:02 -07:00
Jeff Huber
2acf109c4b update chroma notebook (#6664)
@rlancemartin I updated the notebook for Chroma to hopefully be a lot
easier for users.
2023-06-23 15:03:06 -07:00
Eduard van Valkenburg
48381f1f78 PowerBI: catch outdated token (#6634)
This adds just a small tweak to catch the error that says the token is
expired rather than retrying.
2023-06-23 15:01:08 -07:00
Piyush Jain
b1de927f1b Kendra retriever api (#6616)
## Description
Replaces [Kendra
Retriever](https://github.com/hwchase17/langchain/blob/master/langchain/retrievers/aws_kendra_index_retriever.py)
with an updated version that uses the new [retriever
API](https://docs.aws.amazon.com/kendra/latest/dg/searching-retrieve.html)
which is better suited for retrieval augmented generation (RAG) systems.

**Note**: This change requires the latest version (1.26.159) of boto3 to
work. `pip install -U boto3` to upgrade the boto3 version.

cc @hupe1980
cc @dev2049
2023-06-23 14:59:35 -07:00
ChrisLovejoy
4e5d78579b fix minor typo in vector_db_qa.mdx (#6604)
- Description: minor typo fixed - doesn't instead of does. No other
changes.
2023-06-23 14:57:37 -07:00
Ikko Eltociear Ashimine
73da193a4b Fix typo in myscale_self_query.ipynb (#6601) 2023-06-23 14:57:12 -07:00
Saarthak Maini
ba256b23f2 Fix Typo (#6595)
Resolves #6582
2023-06-23 14:56:54 -07:00
kourosh hakhamaneshi
f6fdabd20b Fix ray-project/Aviary integration (#6607)
- Description: The Aviary integration's URL has changed. This PR fixes the
integration for those changes and also makes providing the input URL optional
(since it can be set via environment variables).
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
2023-06-23 14:49:53 -07:00
northern-64bit
dbe1d029ec Fix grammar mistake in base.py in planners (#6611)
Fixes a typo in
`langchain/experimental/plan_and_execute/planners/base.py` by changing
"Given input, decided what to do." to "Given input, decide what to do."

This is in the docstring for functions that run LLM chains to create a plan;
"decided" does not make sense in this context.
2023-06-23 14:47:10 -07:00
Aaron Pham
082976d8d0 fix(docs): broken link for OpenLLM (#6622)
The link to the OpenLLM notebook was not migrated to the new format.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

2023-06-23 13:59:17 -07:00
165 changed files with 4428 additions and 968 deletions

View File

@@ -0,0 +1,9 @@
# Caching
LangChain provides an optional caching layer for Chat Models. This is useful for two reasons:
It can save you money by reducing the number of API calls you make to the LLM provider, if you're often requesting the same completion multiple times.
It can speed up your application by reducing the number of API calls you make to the LLM provider.
import CachingChat from "@snippets/modules/model_io/models/chat/how_to/chat_model_caching.mdx"
<CachingChat/>

View File

@@ -0,0 +1,73 @@
# Amazon API Gateway
[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any scale. APIs act as the "front door" for applications to access data, business logic, or functionality from your backend services. Using API Gateway, you can create RESTful APIs and WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.
API Gateway handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization and access control, throttling, monitoring, and API version management. API Gateway has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data transferred out and, with the API Gateway tiered pricing model, you can reduce your cost as your API usage scales.
## LLM
See a [usage example](/docs/modules/model_io/models/llms/integrations/amazon_api_gateway_example.html).
```python
from langchain.llms import AmazonAPIGateway
api_url = "https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF"
llm = AmazonAPIGateway(api_url=api_url)
# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart
parameters = {
"max_new_tokens": 100,
"num_return_sequences": 1,
"top_k": 50,
"top_p": 0.95,
"do_sample": False,
"return_full_text": True,
"temperature": 0.2,
}
prompt = "what day comes after Friday?"
llm.model_kwargs = parameters
llm(prompt)
>>> 'what day comes after Friday?\nSaturday'
```
## Agent
```python
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.llms import AmazonAPIGateway
api_url = "https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF"
llm = AmazonAPIGateway(api_url=api_url)
parameters = {
"max_new_tokens": 50,
"num_return_sequences": 1,
"top_k": 250,
"top_p": 0.25,
"do_sample": False,
"temperature": 0.1,
}
llm.model_kwargs = parameters
# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.
tools = load_tools(["python_repl", "llm-math"], llm=llm)
# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.
agent = initialize_agent(
tools,
llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
# Now let's test it out!
agent.run("""
Write a Python script that prints "Hello, world!"
""")
>>> 'Hello, world!'
```

View File

@@ -25,7 +25,7 @@ There are two ways to set up parameters for myscale index.
1. Environment Variables
Before you run the app, please set the environment variable with `export`:
`export MYSCALE_URL='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`
`export MYSCALE_HOST='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`
You can easily find your account, password and other info on our SaaS. For details please refer to [this document](https://docs.myscale.com/en/cluster-management/)
Every attribute under `MyScaleSettings` can be set with the prefix `MYSCALE_` and is case insensitive.

View File

@@ -67,4 +67,4 @@ llm("What is the difference between a duck and a goose? And why there are so man
### Usage
For a more detailed walkthrough of the OpenLLM Wrapper, see the
[example notebook](../modules/models/llms/integrations/openllm.ipynb)
[example notebook](/docs/modules/model_io/models/llms/integrations/openllm.html)

View File

@@ -0,0 +1,238 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Office365 Toolkit\n",
"\n",
"This notebook walks through connecting LangChain to Office365 email and calendar.\n",
"\n",
"To use this toolkit, you will need to set up your credentials explained in the [Microsoft Graph authentication and authorization overview](https://learn.microsoft.com/en-us/graph/auth/). Once you've received a CLIENT_ID and CLIENT_SECRET, you can input them as environmental variables below."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install --upgrade O365 > /dev/null\n",
"!pip install beautifulsoup4 > /dev/null # This is optional but is useful for parsing HTML messages"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Assign Environmental Variables\n",
"\n",
"The toolkit will read the CLIENT_ID and CLIENT_SECRET environmental variables to authenticate the user so you need to set them here. You will also need to set your OPENAI_API_KEY to use the agent later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set environmental variables here"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Toolkit and Get Tools\n",
"\n",
"To start, you need to create the toolkit, so you can access its tools later."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"[O365SearchEvents(name='events_search', description=\" Use this tool to search for the user's calendar events. The input must be the start and end datetimes for the search query. The output is a JSON list of all the events in the user's calendar between the start and end times. You can assume that the user can not schedule any meeting over existing meetings, and that the user is busy during meetings. Any times without events are free for the user. \", args_schema=<class 'langchain.tools.office365.events_search.SearchEventsInput'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
" O365CreateDraftMessage(name='create_email_draft', description='Use this tool to create a draft email with the provided message fields.', args_schema=<class 'langchain.tools.office365.create_draft_message.CreateDraftMessageSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
" O365SearchEmails(name='messages_search', description='Use this tool to search for email messages. The input must be a valid Microsoft Graph v1.0 $search query. The output is a JSON list of the requested resource.', args_schema=<class 'langchain.tools.office365.messages_search.SearchEmailsInput'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
" O365SendEvent(name='send_event', description='Use this tool to create and send an event with the provided event fields.', args_schema=<class 'langchain.tools.office365.send_event.SendEventSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302),\n",
" O365SendMessage(name='send_email', description='Use this tool to send an email with the provided message fields.', args_schema=<class 'langchain.tools.office365.send_message.SendMessageSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, handle_tool_error=False, account=Account Client Id: f32a022c-3c4c-4d10-a9d8-f6a9a9055302)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.agents.agent_toolkits import O365Toolkit\n",
"\n",
"toolkit = O365Toolkit()\n",
"tools = toolkit.get_tools()\n",
"tools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use within an Agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain import OpenAI\n",
"from langchain.agents import initialize_agent, AgentType"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent(\n",
" tools=toolkit.get_tools(),\n",
" llm=llm,\n",
" verbose=False,\n",
" agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'The draft email was created correctly.'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Create an email draft for me to edit of a letter from the perspective of a sentient parrot\"\n",
" \" who is looking to collaborate on some research with her\"\n",
" \" estranged friend, a cat. Under no circumstances may you send the message, however.\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"\"I found one draft in your drafts folder about collaboration. It was sent on 2023-06-16T18:22:17+0000 and the subject was 'Collaboration Request'.\""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Could you search in my drafts folder and let me know if any of them are about collaboration?\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/vscode/langchain-py-env/lib/python3.11/site-packages/O365/utils/windows_tz.py:639: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
" iana_tz.zone if isinstance(iana_tz, tzinfo) else iana_tz)\n",
"/home/vscode/langchain-py-env/lib/python3.11/site-packages/O365/utils/utils.py:463: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
" timezone = date_time.tzinfo.zone if date_time.tzinfo is not None else None\n"
]
},
{
"data": {
"text/plain": [
"'I have scheduled a meeting with a sentient parrot to discuss research collaborations on October 3, 2023 at 2 pm Easter Time. Please let me know if you need to make any changes.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Can you schedule a 30 minute meeting with a sentient parrot to discuss research collaborations on October 3, 2023 at 2 pm Easter Time?\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Yes, you have an event on October 3, 2023 with a sentient parrot. The event is titled 'Meeting with sentient parrot' and is scheduled from 6:00 PM to 6:30 PM.\""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Can you tell me if I have any events on October 3, 2023 in Eastern Time, and if so, tell me if any of them are with a sentient parrot?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -10,6 +10,16 @@
"In this notebook we'll show how to create a chain that automatically makes calls to an API based only on an OpenAPI spec. Under the hood, we're parsing the OpenAPI spec into a JSON schema that the OpenAI functions API can handle. This allows ChatGPT to automatically select and populate the relevant API call to make for any user input. Using the output of ChatGPT we then make the actual API call, and return the result."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "555661b5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.openai_functions.openapi import get_openapi_chain"
]
},
{
"cell_type": "markdown",
"id": "a95f510a",
@@ -25,14 +35,12 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.openai_functions.openapi import get_openapi_chain\n",
"\n",
"chain = get_openapi_chain(\"https://www.klarna.com/us/shopping/public/openai/v0/api-docs/\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "3959f866",
"metadata": {},
"outputs": [
@@ -76,7 +84,7 @@
" 'Size:S,XL,XS,L,M,XXL']}]}"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -90,7 +98,9 @@
"id": "6f648c77",
"metadata": {},
"source": [
"## Query a translation service"
"## Query a translation service\n",
"\n",
"Additionally, see the request payload by setting `verbose=True`"
]
},
{
@@ -100,23 +110,57 @@
"metadata": {},
"outputs": [],
"source": [
"chain = get_openapi_chain(\"https://api.speak.com/openapi.yaml\")"
"chain = get_openapi_chain(\"https://api.speak.com/openapi.yaml\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"id": "1ba51609",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: Use the provided API's to respond to this user query:\n",
"\n",
"How would you say no thanks in Russian\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Calling endpoint \u001b[32;1m\u001b[1;3mtranslate\u001b[0m with arguments:\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"json\": {\n",
" \"phrase_to_translate\": \"no thanks\",\n",
" \"learning_language\": \"russian\",\n",
" \"native_language\": \"english\",\n",
" \"additional_context\": \"\",\n",
" \"full_query\": \"How would you say no thanks in Russian\"\n",
" }\n",
"}\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'explanation': '<translation language=\"None\" context=\"None\">\\nNone\\n</translation>\\n\\n<alternatives context=\"None\">\\n1. \"N/A\" *(Formal - used in professional settings to indicate that the answer is not applicable)*\\n2. \"I don\\'t have an answer for that\" *(Neutral - commonly used when one does not know the answer to a question)*\\n3. \"I\\'m not sure\" *(Neutral - similar to the above alternative, used when one is unsure of the answer)*\\n</alternatives>\\n\\n<example-convo language=\"None\">\\n<context>None</context>\\n* Tom: \"Do you know what time the concert starts?\"\\n* Sarah: \"I\\'m sorry, I don\\'t have an answer for that.\"\\n</example-convo>\\n\\n*[Report an issue or leave feedback](https://speak.com/chatgpt?rid=p8i6p14duafpctg4ve7tm48z})*',\n",
"{'explanation': '<translation language=\"Russian\">\\nНет, спасибо. (Net, spasibo)\\n</translation>\\n\\n<alternatives>\\n1. \"Нет, я в порядке\" *(Neutral/Formal - Can be used in professional settings or formal situations.)*\\n2. \"Нет, спасибо, я откажусь\" *(Formal - Can be used in polite settings, such as a fancy dinner with colleagues or acquaintances.)*\\n3. \"Не надо\" *(Informal - Can be used in informal situations, such as declining an offer from a friend.)*\\n</alternatives>\\n\\n<example-convo language=\"Russian\">\\n<context>Max is being offered a cigarette at a party.</context>\\n* Sasha: \"Хочешь покурить?\"\\n* Max: \"Нет, спасибо. Я бросил.\"\\n* Sasha: \"Окей, понятно.\"\\n</example-convo>\\n\\n*[Report an issue or leave feedback](https://speak.com/chatgpt?rid=noczaa460do8yqs8xjun6zdm})*',\n",
" 'extra_response_instructions': 'Use all information in the API response and fully render all Markdown.\\nAlways end your response with a link to report an issue or leave feedback on the plugin.'}"
]
},
"execution_count": 7,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -137,7 +181,9 @@
"cell_type": "code",
"execution_count": null,
"id": "a9198f62",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"chain = get_openapi_chain(\"https://gist.githubusercontent.com/roaldnefs/053e505b2b7a807290908fe9aa3e1f00/raw/0a212622ebfef501163f91e23803552411ed00e4/openapi.yaml\")"
@@ -145,17 +191,10 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"id": "3110c398",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..\n"
]
},
{
"data": {
"text/plain": [
@@ -172,7 +211,7 @@
" 'day': '23'}"
]
},
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}

View File

@@ -0,0 +1,28 @@
Example Docs
------------
The sample docs directory contains the following files:
- ``example-10k.html`` - A 10-K SEC filing in HTML format
- ``layout-parser-paper.pdf`` - A PDF copy of the layout parser paper
- ``factbook.xml``/``factbook.xsl`` - Example XML/XSL files that you
can use to test stylesheets
These documents can be used to test out the parsers in the library. In
addition, here are instructions for pulling in some sample docs that are
too big to store in the repo.
XBRL 10-K
^^^^^^^^^
You can get an example 10-K in inline XBRL format using the following
``curl``. Note, you need to have the user agent set in the header or the
SEC site will reject your request.
.. code:: bash

   curl -O \
     -A '${organization} ${email}' \
     https://www.sec.gov/Archives/edgar/data/311094/000117184321001344/0001171843-21-001344.txt
You can parse this document using the HTML parser.

View File

@@ -0,0 +1,71 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "87067cdf",
"metadata": {},
"source": [
"# mhtml\n",
"\n",
"MHTML is a is used both for emails but also for archived webpages. MHTML, sometimes referred as MHT, stands for MIME HTML is a single file in which entire webpage is archived. When one saves a webpage as MHTML format, this file extension will contain HTML code, images, audio files, flash animation etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d4c6174",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import MHTMLLoader"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "12dcebc8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='LangChain\\nLANG CHAIN 🦜🔗Official Home Page\\xa0\\n\\n\\n\\n\\n\\n\\n\\nIntegrations\\n\\n\\n\\nFeatures\\n\\n\\n\\n\\nBlog\\n\\n\\n\\nConceptual Guide\\n\\n\\n\\n\\nPython Repo\\n\\n\\nJavaScript Repo\\n\\n\\n\\nPython Documentation \\n\\n\\nJavaScript Documentation\\n\\n\\n\\n\\nPython ChatLangChain \\n\\n\\nJavaScript ChatLangChain\\n\\n\\n\\n\\nDiscord \\n\\n\\nTwitter\\n\\n\\n\\n\\nIf you have any comments about our WEB page, you can \\nwrite us at the address shown above. However, due to \\nthe limited number of personnel in our corporate office, we are unable to \\nprovide a direct response.\\n\\nCopyright © 2023-2023 LangChain Inc.\\n\\n\\n' metadata={'source': '../../../../../../tests/integration_tests/examples/example.mht', 'title': 'LangChain'}\n"
]
}
],
"source": [
"# Create a new loader object for the MHTML file\n",
"loader = MHTMLLoader(file_path='../../../../../../tests/integration_tests/examples/example.mht')\n",
"\n",
"# Load the document from the file\n",
"documents = loader.load()\n",
"\n",
"# Print the documents to see the results\n",
"for doc in documents:\n",
" print(doc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RST\n",
"\n",
">A [reStructured Text (RST)](https://en.wikipedia.org/wiki/ReStructuredText) file is a file format for textual data used primarily in the Python programming language community for technical documentation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `UnstructuredRSTLoader`\n",
"\n",
"You can load data from RST files with `UnstructuredRSTLoader` using the following workflow."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import UnstructuredRSTLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"loader = UnstructuredRSTLoader(\n",
" file_path=\"example_data/README.rst\", mode=\"elements\"\n",
")\n",
"docs = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='Example Docs' metadata={'source': 'example_data/README.rst', 'filename': 'README.rst', 'file_directory': 'example_data', 'filetype': 'text/x-rst', 'page_number': 1, 'category': 'Title'}\n"
]
}
],
"source": [
"print(docs[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -148,7 +148,7 @@
" # This will teach the LLM to use it as a column when constructing filter.\n",
" AttributeInfo(\n",
" name=\"length(genre)\",\n",
" description=\"The lenth of genres of the movie\", \n",
" description=\"The length of genres of the movie\", \n",
" type=\"integer\", \n",
" ),\n",
" # Now you can define a column as timestamp. By simply set the type to timestamp.\n",

View File

@@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
"%pip install boto3"
]
},
{
@@ -36,7 +36,7 @@
"outputs": [],
"source": [
"import boto3\n",
"from langchain.retrievers import AwsKendraIndexRetriever"
"from langchain.retrievers import AmazonKendraRetriever"
]
},
{
@@ -53,11 +53,8 @@
"metadata": {},
"outputs": [],
"source": [
"kclient = boto3.client(\"kendra\", region_name=\"us-east-1\")\n",
"\n",
"retriever = AwsKendraIndexRetriever(\n",
" kclient=kclient,\n",
" kendraindex=\"kendraindex\",\n",
"retriever = AmazonKendraRetriever(\n",
" index_id=\"c0806df7-e76b-4bce-9b5c-d5582f6b1a03\"\n",
")"
]
},
@@ -66,7 +63,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now you can use retrieved documents from AWS Kendra Index"
"Now you can use retrieved documents from Kendra index"
]
},
{

View File

@@ -1,107 +1,80 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "683953b3",
"metadata": {},
"source": [
"# Chroma\n",
"\n",
">[Chroma](https://docs.trychroma.com/getting-started) is a database for building AI applications with embeddings.\n",
">[Chroma](https://docs.trychroma.com/getting-started) is a AI-native open-source vector database focused on developer productivity and happiness. Chroma is licensed under Apache 2.0.\n",
"\n",
"This notebook shows how to use functionality related to the `Chroma` vector database."
"<a href=\"https://discord.gg/MMeYNTmh3x\" target=\"_blank\">\n",
" <img src=\"https://img.shields.io/discord/1073293645303795742\" alt=\"Discord\" />\n",
"</a>&nbsp;&nbsp;\n",
"<a href=\"https://github.com/chroma-core/chroma/blob/master/LICENSE\" target=\"_blank\">\n",
" <img src=\"https://img.shields.io/static/v1?label=license&message=Apache 2.0&color=white\" alt=\"License\" />\n",
"</a>&nbsp;&nbsp;\n",
"<img src=\"https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml/badge.svg?branch=main\" alt=\"Integration Tests\" />\n",
"\n",
"- [Website](https://www.trychroma.com/)\n",
"- [Documentation](https://docs.trychroma.com/)\n",
"- [Twitter](https://twitter.com/trychroma)\n",
"- [Discord](https://discord.gg/MMeYNTmh3x)\n",
"\n",
"Chroma is fully-typed, fully-tested and fully-documented.\n",
"\n",
"Install Chroma with:\n",
"\n",
"```sh\n",
"pip install chromadb\n",
"```\n",
"\n",
"Chroma runs in various modes. See below for examples of each integrated with LangChain.\n",
"- `in-memory` - in a python script or jupyter notebook\n",
"- `in-memory with persistance` - in a script or notebook and save/load to disk\n",
"- `in a docker container` - as a server running your local machine or in the cloud\n",
"\n",
"Like any other database, you can: \n",
"- `.add` \n",
"- `.get` \n",
"- `.update`\n",
"- `.upsert`\n",
"- `.delete`\n",
"- `.peek`\n",
"- and `.query` runs the similarity search.\n",
"\n",
"View full docs at [docs](https://docs.trychroma.com/reference/Collection). To access these methods directly, you can do `._collection_.method()`\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0825fa4a-d950-4e78-8bba-20cfcc347765",
"metadata": {
"tags": []
},
"id": "12e83df7",
"metadata": {},
"outputs": [],
"source": [
"!pip install chromadb"
"# first install dependencies\n",
"!pip install langchain\n",
"!pip install langchainplus_sdk\n",
"!pip install chromadb\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "42080f37-8fd1-4cec-acd9-15d2b03b2f4d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"cell_type": "markdown",
"id": "2b5ffbf8",
"metadata": {},
"source": [
"# get a token: https://platform.openai.com/account/api-keys\n",
"## Basic Example\n",
"\n",
"from getpass import getpass\n",
"\n",
"OPENAI_API_KEY = getpass()"
"In this basic example, we take the most recent State of the Union Address, split it into chunks, embed it using an open-source embedding model, load it into Chroma, and then query it."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c7a94d6c-b4d4-4498-9bdd-eb50c92b85c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "aac9563e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.document_loaders import TextLoader"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a3c3999a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5eabdb75",
"metadata": {
"tags": []
},
"execution_count": 14,
"id": "ae9fcf3e",
"metadata": {},
"outputs": [
{
"name": "stderr",
@@ -109,21 +82,7 @@
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
}
],
"source": [
"db = Chroma.from_documents(docs, embeddings)\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4b172de8",
"metadata": {},
"outputs": [
},
{
"name": "stdout",
"output_type": "stream",
@@ -139,20 +98,312 @@
}
],
"source": [
"# import\n",
"from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.document_loaders import TextLoader\n",
"\n",
"# load the document and split it into chunks\n",
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"\n",
"# split it into chunks\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"# create the open-source embedding function\n",
"embedding_function = SentenceTransformerEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n",
"\n",
"# load it into Chroma\n",
"db = Chroma.from_documents(docs, embedding_function)\n",
"\n",
"# query it\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)\n",
"\n",
"# print results\n",
"print(docs[0].page_content)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5c9a11cc",
"metadata": {},
"source": [
"## Basic Example (including saving to disk)\n",
"\n",
"Extending the previous example, if you want to save to disk, simply initialize the Chroma client and pass the directory where you want the data to be saved to. \n",
"\n",
"`Caution`: Chroma makes a best-effort to automatically save data to disk, however multiple in-memory clients can stomp each other's work. As a best practice, only have one client per path running at any given time.\n",
"\n",
"`Protip`: Sometimes you can call `db.persist()` to force a save. "
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "49f9bd49",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB with persistence: data will be stored in: ./chroma_db\n",
"Using embedded DuckDB with persistence: data will be stored in: ./chroma_db\n",
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.\n"
]
}
],
"source": [
"# save to disk\n",
"db2 = Chroma.from_documents(docs, embedding_function, persist_directory=\"./chroma_db\")\n",
"db2.persist()\n",
"docs = db.similarity_search(query)\n",
"\n",
"# load from disk\n",
"db3 = Chroma(persist_directory=\"./chroma_db\")\n",
"docs = db.similarity_search(query)\n",
"print(docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "e9cf6d70",
"metadata": {},
"source": [
"## Basic Example (using the Docker Container)\n",
"\n",
"You can also run the Chroma Server in a Docker container separately, create a Client to connect to it, and then pass that to LangChain. \n",
"\n",
"Chroma has the ability to handle multiple `Collections` of documents, but the LangChain interface expects one, so we need to specify the collection name. The default collection name used by LangChain is \"langchain\".\n",
"\n",
"Here is how to clone, build, and run the Docker Image:\n",
"```\n",
"git clone git@github.com:chroma-core/chroma.git\n",
"docker-compose up -d --build\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "74aee70e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n",
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.\n"
]
}
],
"source": [
"# create the chroma client\n",
"import chromadb\n",
"import uuid\n",
"from chromadb.config import Settings\n",
"client = chromadb.Client(Settings(chroma_api_impl=\"rest\",\n",
" chroma_server_host=\"localhost\",\n",
" chroma_server_http_port=\"8000\"\n",
" ))\n",
"client.reset() # resets the database\n",
"collection = client.create_collection(\"my_collection\")\n",
"for doc in docs:\n",
" collection.add(ids=[str(uuid.uuid1())], metadatas=doc.metadata, documents=doc.page_content)\n",
"\n",
"# tell LangChain to use our client and collection name\n",
"db4 = Chroma(client=client, collection_name=\"my_collection\")\n",
"docs = db.similarity_search(query)\n",
"print(docs[0].page_content)\n"
]
},
{
"cell_type": "markdown",
"id": "9ed3ec50",
"metadata": {},
"source": [
"## Update and Delete\n",
"\n",
"While building toward a real application, you want to go beyond adding data, and also update and delete data. \n",
"\n",
"Chroma has users provide `ids` to simplify the bookkeeping here. `ids` can be the name of the file, or a combined has like `filename_paragraphNumber`, etc.\n",
"\n",
"Chroma supports all these operations - though some of them are still being integrated all the way through the LangChain interface. Additional workflow improvements will be added soon.\n",
"\n",
"Here is a basic example showing how to do various operations:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "81a02810",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}\n",
"{'ids': ['1'], 'embeddings': None, 'documents': ['Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.'], 'metadatas': [{'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}]}\n",
"count before 4\n",
"count after 3\n"
]
}
],
"source": [
"# create simple ids\n",
"ids = [str(i) for i in range(1, len(docs)+1)]\n",
"\n",
"# add data\n",
"example_db = Chroma.from_documents(docs, embedding_function, ids=ids)\n",
"docs = example_db.similarity_search(query)\n",
"print(docs[0].metadata)\n",
"\n",
"# update the metadata for a document\n",
"docs[0].metadata = {'source': '../../../state_of_the_union.txt', 'new_value': 'hello world'}\n",
"example_db.update_document(ids[0], docs[0])\n",
"print(example_db._collection.get(ids=[ids[0]]))\n",
"\n",
"# delete the last document\n",
"print(\"count before\", example_db._collection.count())\n",
"example_db._collection.delete(ids=[ids[-1]])\n",
"print(\"count after\", example_db._collection.count())\n"
]
},
{
"cell_type": "markdown",
"id": "ac6bc71a",
"metadata": {},
"source": [
"## Use OpenAI Embeddings\n",
"\n",
"Many people like to use OpenAIEmbeddings, here is how to set that up."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "42080f37-8fd1-4cec-acd9-15d2b03b2f4d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# get a token: https://platform.openai.com/account/api-keys\n",
"\n",
"from getpass import getpass\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"\n",
"OPENAI_API_KEY = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "c7a94d6c-b4d4-4498-9bdd-eb50c92b85c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "5eabdb75",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.\n"
]
}
],
"source": [
"embeddings = OpenAIEmbeddings()\n",
"db5 = Chroma.from_documents(docs, embeddings)\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)\n",
"print(docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "6d9c28ad",
"metadata": {},
"source": [
"***\n",
"\n",
"## Other Information"
]
},
{
"cell_type": "markdown",
"id": "18152965",
"metadata": {},
"source": [
"## Similarity search with score"
"### Similarity search with score"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "346347d7",
"metadata": {},
@@ -197,127 +448,15 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8061454b",
"metadata": {},
"source": [
"## Persistance\n",
"\n",
"The below steps cover how to persist a ChromaDB instance"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2b76db26",
"metadata": {},
"source": [
"### Initialize PeristedChromaDB\n",
"Create embeddings for each chunk and insert into the Chroma vector database. The persist_directory argument tells ChromaDB where to store the database when it's persisted.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cdb86e0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"No existing DB found in db, skipping load\n",
"No existing DB found in db, skipping load\n"
]
}
],
"source": [
"# Embed and store the texts\n",
"# Supplying a persist_directory will store the embeddings on disk\n",
"persist_directory = \"db\"\n",
"\n",
"embedding = OpenAIEmbeddings()\n",
"vectordb = Chroma.from_documents(\n",
" documents=docs, embedding=embedding, persist_directory=persist_directory\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f568a322",
"metadata": {},
"source": [
"### Persist the Database\n",
"We should call persist() to ensure the embeddings are written to disk."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "74b08cb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Persisting DB to disk, putting it in the save folder db\n",
"PersistentDuckDB del, about to run persist\n",
"Persisting DB to disk, putting it in the save folder db\n"
]
}
],
"source": [
"vectordb.persist()\n",
"vectordb = None"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "cc9ed900",
"metadata": {},
"source": [
"### Load the Database from disk, and create the chain\n",
"Be sure to pass the same persist_directory and embedding_function as you did when you instantiated the database. Initialize the chain we will use for question answering."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "31fecfe9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"loaded in 4 embeddings\n",
"loaded in 1 collections\n"
]
}
],
"source": [
"# Now we can load the persisted database from disk, and use it as normal.\n",
"vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)"
]
},
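{
"cell_type": "code",
"execution_count": null,
"id": "f1b2c3d4",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the question-answering chain mentioned above (not executed here).\n",
"# Assumes an OpenAI API key is configured in the environment.\n",
"from langchain.chains import RetrievalQA\n",
"from langchain.llms import OpenAI\n",
"\n",
"qa = RetrievalQA.from_chain_type(\n",
"    llm=OpenAI(), chain_type=\"stuff\", retriever=vectordb.as_retriever()\n",
")"
]
},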
{
"attachments": {},
"cell_type": "markdown",
"id": "794a7552",
"metadata": {},
"source": [
"## Retriever options\n",
"### Retriever options\n",
"\n",
"This section goes over different options for how to use Chroma as a retriever.\n",
"\n",
"### MMR\n",
"#### MMR\n",
"\n",
"In addition to using similarity search in the retriever object, you can also use `mmr`."
]
@@ -352,82 +491,6 @@
"source": [
"retriever.get_relevant_documents(query)[0]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2a877f08",
"metadata": {},
"source": [
"## Updating a Document\n",
"The `update_document` function allows you to modify the content of a document in the Chroma instance after it has been added. Let's see an example of how to use this function."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a559c3f1",
"metadata": {},
"outputs": [],
"source": [
"# Import Document class\n",
"from langchain.docstore.document import Document\n",
"\n",
"# Initial document content and id\n",
"initial_content = \"This is an initial document content\"\n",
"document_id = \"doc1\"\n",
"\n",
"# Create an instance of Document with initial content and metadata\n",
"original_doc = Document(page_content=initial_content, metadata={\"page\": \"0\"})\n",
"\n",
"# Initialize a Chroma instance with the original document\n",
"new_db = Chroma.from_documents(\n",
" collection_name=\"test_collection\",\n",
" documents=[original_doc],\n",
" embedding=OpenAIEmbeddings(), # using the same embeddings as before\n",
" ids=[document_id],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "60a7c273",
"metadata": {},
"source": [
"At this point, we have a new Chroma instance with a single document \"This is an initial document content\" with id \"doc1\". Now, let's update the content of the document."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "55e48056",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This is the updated document content {'page': '1'}\n"
]
}
],
"source": [
"# Updated document content\n",
"updated_content = \"This is the updated document content\"\n",
"\n",
"# Create a new Document instance with the updated content\n",
"updated_doc = Document(page_content=updated_content, metadata={\"page\": \"1\"})\n",
"\n",
"# Update the document in the Chroma instance by passing the document id and the updated document\n",
"new_db.update_document(document_id=document_id, document=updated_doc)\n",
"\n",
"# Now, let's retrieve the updated document using similarity search\n",
"output = new_db.similarity_search(updated_content, k=1)\n",
"\n",
"# Print the content of the retrieved document\n",
"print(output[0].page_content, output[0].metadata)"
]
}
],
"metadata": {
@@ -446,7 +509,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.3"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,227 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Amazon API Gateway"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any scale. APIs act as the \"front door\" for applications to access data, business logic, or functionality from your backend services. Using API Gateway, you can create RESTful APIs and WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.\n",
"\n",
"API Gateway handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization and access control, throttling, monitoring, and API version management. API Gateway has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data transferred out and, with the API Gateway tiered pricing model, you can reduce your cost as your API usage scales."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LLM"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import AmazonAPIGateway"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"api_url = \"https://<api_gateway_id>.execute-api.<region>.amazonaws.com/LATEST/HF\"\n",
"llm = AmazonAPIGateway(api_url=api_url)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'what day comes after Friday?\\nSaturday'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart\n",
"parameters = {\n",
" \"max_new_tokens\": 100,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 50,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"return_full_text\": True,\n",
" \"temperature\": 0.2,\n",
"}\n",
"\n",
"prompt = \"what day comes after Friday?\"\n",
"llm.model_kwargs = parameters\n",
"llm(prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m\n",
"I need to use the print function to output the string \"Hello, world!\"\n",
"Action: Python_REPL\n",
"Action Input: `print(\"Hello, world!\")`\u001B[0m\n",
"Observation: \u001B[36;1m\u001B[1;3mHello, world!\n",
"\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3m\n",
"I now know how to print a string in Python\n",
"Final Answer:\n",
"Hello, world!\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello, world!'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.agents import load_tools\n",
"from langchain.agents import initialize_agent\n",
"from langchain.agents import AgentType\n",
"\n",
"\n",
"parameters = {\n",
" \"max_new_tokens\": 50,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.25,\n",
" \"do_sample\": False,\n",
" \"temperature\": 0.1,\n",
"}\n",
"\n",
"llm.model_kwargs = parameters\n",
"\n",
"# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.\n",
"tools = load_tools([\"python_repl\", \"llm-math\"], llm=llm)\n",
"\n",
"# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
" verbose=True,\n",
")\n",
"\n",
"# Now let's test it out!\n",
"agent.run(\"\"\"\n",
"Write a Python script that prints \"Hello, world!\"\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m I need to use the calculator to find the answer\n",
"Action: Calculator\n",
"Action Input: 2.3 ^ 4.5\u001B[0m\n",
"Observation: \u001B[33;1m\u001B[1;3mAnswer: 42.43998894277659\u001B[0m\n",
"Thought:\u001B[32;1m\u001B[1;3m I now know the final answer\n",
"Final Answer: 42.43998894277659\n",
"\n",
"Question: \n",
"What is the square root of 144?\n",
"\n",
"Thought: I need to use the calculator to find the answer\n",
"Action:\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"'42.43998894277659'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = agent.run(\n",
" \"\"\"\n",
"What is 2.3 ^ 4.5?\n",
"\"\"\"\n",
")\n",
"\n",
"result.split(\"\\n\")[0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View File

@@ -100,7 +100,7 @@ template = """You are a playwright. Given the title of play and the era it is se
Title: {title}
Era: {era}
Playwright: This is a synopsis for the above play:"""
prompt_template = PromptTemplate(input_variables=["title", 'era'], template=template)
prompt_template = PromptTemplate(input_variables=["title", "era"], template=template)
synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key="synopsis")
```

View File

@@ -58,7 +58,7 @@ qa.run(query)
</CodeOutputBlock>
The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](/docs/modules/chains/additional/question_answering.html)) and then pass that directly to the the RetrievalQA chain with the `combine_documents_chain` parameter. For example:
The above way allows you to really simply change the chain_type, but it doesn't provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](/docs/modules/chains/additional/question_answering.html)) and then pass that directly to the RetrievalQA chain with the `combine_documents_chain` parameter. For example:
```python
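# A minimal sketch of the pattern described above; `llm` and `docsearch` are
# assumed to be defined earlier in this guide.
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain

qa_chain = load_qa_chain(llm, chain_type="stuff")
qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=docsearch.as_retriever())
```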

View File

@@ -0,0 +1,97 @@
```python
import langchain
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI()
```
## In Memory Cache
```python
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 35.9 ms, sys: 28.6 ms, total: 64.6 ms
Wall time: 4.83 s
"\n\nWhy couldn't the bicycle stand up by itself? It was...two tired!"
```
</CodeOutputBlock>
```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 238 µs, sys: 143 µs, total: 381 µs
Wall time: 1.76 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>
## SQLite Cache
```bash
rm .langchain.db
```
```python
# We can do the same thing with a SQLite cache
from langchain.cache import SQLiteCache
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```
```python
# The first time, it is not yet in cache, so it should take longer
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 17 ms, sys: 9.76 ms, total: 26.7 ms
Wall time: 825 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>
```python
# The second time it is, so it goes faster
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
```
CPU times: user 2.46 ms, sys: 1.23 ms, total: 3.7 ms
Wall time: 2.67 ms
'\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'
```
</CodeOutputBlock>
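The `BaseChatModel` changes elsewhere in this comparison also add an optional per-model `cache` field, so a single chat model can opt out of the global cache. A minimal sketch, assuming the in-memory cache configured above:
```python
import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI

langchain.llm_cache = InMemoryCache()

cached_llm = ChatOpenAI()               # participates in the global cache
uncached_llm = ChatOpenAI(cache=False)  # bypasses the global cache

cached_llm.predict("Tell me a joke")    # a repeat call is served from the cache
uncached_llm.predict("Tell me a joke")  # always hits the API
```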

View File

@@ -14,7 +14,7 @@ from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
# The first time, it is not yet in cache, so it should take longer
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@@ -32,7 +32,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@@ -64,7 +64,7 @@ langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
```python
# The first time, it is not yet in cache, so it should take longer
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">
@@ -82,7 +82,7 @@ llm("Tell me a joke")
```python
# The second time it is, so it goes faster
llm("Tell me a joke")
llm.predict("Tell me a joke")
```
<CodeOutputBlock lang="python">

View File

@@ -0,0 +1 @@
"""Gmail toolkit."""

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
from langchain.tools.office365.create_draft_message import O365CreateDraftMessage
from langchain.tools.office365.events_search import O365SearchEvents
from langchain.tools.office365.messages_search import O365SearchEmails
from langchain.tools.office365.send_event import O365SendEvent
from langchain.tools.office365.send_message import O365SendMessage
from langchain.tools.office365.utils import authenticate
if TYPE_CHECKING:
from O365 import Account
class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office365."""
account: Account = Field(default_factory=authenticate)
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
O365SearchEvents(account=self.account),
O365CreateDraftMessage(account=self.account),
O365SearchEmails(account=self.account),
O365SendEvent(account=self.account),
O365SendMessage(account=self.account),
]
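A minimal usage sketch for the new toolkit. The import path, agent type, and model below are assumptions for illustration; Office365 credentials must already be configured for the underlying O365 `Account`:
```python
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent_toolkits import O365Toolkit  # import path assumed
from langchain.chat_models import ChatOpenAI

toolkit = O365Toolkit()  # authenticates against Office365 on construction
agent = initialize_agent(
    tools=toolkit.get_tools(),
    llm=ChatOpenAI(temperature=0),
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
agent.run("Do I have any calendar events this week?")
```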

View File

@@ -2,6 +2,8 @@ from enum import Enum
class AgentType(str, Enum):
"""Enumerator with the Agent types."""
ZERO_SHOT_REACT_DESCRIPTION = "zero-shot-react-description"
REACT_DOCSTORE = "react-docstore"
SELF_ASK_WITH_SEARCH = "self-ask-with-search"

View File

@@ -17,13 +17,15 @@ class ChatOutputParser(AgentOutputParser):
try:
action = text.split("```")[1]
response = json.loads(action.strip())
includes_action = "action" in response and "action_input" in response
includes_action = "action" in response
if includes_answer and includes_action:
raise OutputParserException(
"Parsing LLM output produced a final answer "
f"and a parse-able action: {text}"
)
return AgentAction(response["action"], response["action_input"], text)
return AgentAction(
response["action"], response.get("action_input", {}), text
)
except Exception:
if not includes_answer:

View File

@@ -51,7 +51,7 @@ def initialize_agent(
f"Got unknown agent type: {agent}. "
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
tags_.append(agent.value)
tags_.append(agent.value if isinstance(agent, AgentType) else agent)
agent_cls = AGENT_TO_CLASS[agent]
agent_kwargs = agent_kwargs or {}
agent_obj = agent_cls.from_llm_and_tools(

View File

@@ -352,10 +352,23 @@ def load_huggingface_tool(
remote: bool = False,
**kwargs: Any,
) -> BaseTool:
"""Loads a tool from the HuggingFace Hub.
Args:
task_or_repo_id: Task or model repo id.
model_repo_id: Optional model repo id.
token: Optional token.
remote: Optional remote. Defaults to False.
**kwargs:
Returns:
A tool.
"""
try:
from transformers import load_tool
except ImportError:
raise ValueError(
raise ImportError(
"HuggingFace tools require the libraries `transformers>=4.29.0`"
" and `huggingface_hub>=0.14.1` to be installed."
" Please install it with"

View File

@@ -296,7 +296,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib
import inspect
import json
import logging
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import (
@@ -11,8 +12,8 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
@@ -31,13 +32,17 @@ except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__)
if TYPE_CHECKING:
import momento
RETURN_VAL_TYPE = List[Generation]
RETURN_VAL_TYPE = Sequence[Generation]
def _hash(_input: str) -> str:
@@ -147,13 +152,24 @@ class SQLAlchemyCache(BaseCache):
with Session(self.engine) as session:
rows = session.execute(stmt).fetchall()
if rows:
return [Generation(text=row[0]) for row in rows]
try:
return [loads(row[0]) for row in rows]
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
return [Generation(text=row[0]) for row in rows]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update based on prompt and llm_string."""
items = [
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
self.cache_schema(prompt=prompt, llm=llm_string, response=dumps(gen), idx=i)
for i, gen in enumerate(return_val)
]
with Session(self.engine) as session, session.begin():
@@ -163,7 +179,7 @@ class SQLAlchemyCache(BaseCache):
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
with Session(self.engine) as session:
session.execute(self.cache_schema.delete())
session.query(self.cache_schema).delete()
class SQLiteCache(SQLAlchemyCache):
@@ -209,6 +225,12 @@ class RedisCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
# Write to a Redis HASH
key = self._key(prompt, llm_string)
self.redis.hset(
@@ -314,6 +336,12 @@ class RedisSemanticCache(BaseCache):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisSemanticCache only supports caching of "
f"normal LLM generations, got {type(gen)}"
)
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
@@ -426,6 +454,12 @@ class GPTCache(BaseCache):
First, retrieve the corresponding cache object using the `llm_string` parameter,
and then store the `prompt` and `return_val` in the cache object.
"""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"GPTCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
from gptcache.adapter.api import put
_gptcache = self._get_gptcache(llm_string)
@@ -567,7 +601,7 @@ class MomentoCache(BaseCache):
"""
from momento.responses import CacheGet
generations = []
generations: RETURN_VAL_TYPE = []
get_response = self.cache_client.get(
self.cache_name, self.__key(prompt, llm_string)
@@ -593,6 +627,12 @@ class MomentoCache(BaseCache):
SdkException: Momento service or network error
Exception: Unexpected response
"""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"Momento only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
key = self.__key(prompt, llm_string)
value = _dump_generations_to_json(return_val)
set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)

View File

@@ -6,6 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
def import_aim() -> Any:
"""Import the aim python package and raise an error if it is not installed."""
try:
import aim
except ImportError:

View File

@@ -17,6 +17,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
try:
import clearml # noqa: F401
except ImportError:

View File

@@ -672,66 +672,72 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
managers = []
for prompt in prompts:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
return CallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
return managers
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
managers = []
for message_list in messages:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
return managers
def on_chain_start(
self,
@@ -830,64 +836,84 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
tasks = []
managers = []
return AsyncCallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
for prompt in prompts:
run_id_ = uuid4()
tasks.append(
_ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if run_id is None:
run_id = uuid4()
tasks = []
managers = []
await _ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
for message_list in messages:
run_id_ = uuid4()
return AsyncCallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
tasks.append(
_ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chain_start(
self,

View File

@@ -20,6 +20,7 @@ from langchain.utils import get_from_dict_or_env
def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
try:
import mlflow
except ImportError:
@@ -117,7 +118,7 @@ class MlflowLogger:
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (str): Tags to be attached for the run.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler implements the helper functions to initialize,
@@ -222,7 +223,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (str): Tags to be attached for the run.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler will utilize the associated callback method called and formats

View File

@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
MODEL_COST_PER_1K_TOKENS = {
# GPT-4 input
@@ -32,6 +32,7 @@ MODEL_COST_PER_1K_TOKENS = {
"gpt-3.5-turbo-16k-completion": 0.004,
"gpt-3.5-turbo-16k-0613-completion": 0.004,
# Others
"gpt-35-turbo": 0.002, # Azure OpenAI version of ChatGPT
"text-ada-001": 0.0004,
"ada": 0.0004,
"text-babbage-001": 0.0005,
@@ -52,6 +53,17 @@ def standardize_model_name(
model_name: str,
is_completion: bool = False,
) -> str:
"""
Standardize the model name to a format that can be used in the OpenAI API.
Args:
model_name: Model name to standardize.
is_completion: Whether the model is used for completion or not.
Defaults to False.
Returns:
Standardized model name.
"""
model_name = model_name.lower()
if "ft-" in model_name:
return model_name.split(":")[0] + "-finetuned"
@@ -66,6 +78,18 @@ def standardize_model_name(
def get_openai_token_cost_for_model(
model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
Args:
model_name: Name of the model
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Defaults to False.
Returns:
Cost in USD.
"""
model_name = standardize_model_name(model_name, is_completion=is_completion)
if model_name not in MODEL_COST_PER_1K_TOKENS:
raise ValueError(
@@ -129,64 +153,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
pass
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self

View File

@@ -7,6 +7,16 @@ from langchain.input import get_bolded_text, get_colored_text
def try_json_stringify(obj: Any, fallback: str) -> str:
"""
Try to stringify an object to JSON.
Args:
obj: Object to stringify.
fallback: Fallback string to return if the object cannot be stringified.
Returns:
A JSON string if the object can be stringified, otherwise the fallback string.
"""
try:
return json.dumps(obj, indent=2, ensure_ascii=False)
except Exception:
@@ -14,6 +24,16 @@ def try_json_stringify(obj: Any, fallback: str) -> str:
def elapsed(run: Any) -> str:
"""Get the elapsed time of a run.
Args:
run: any object with a start_time and end_time attribute.
Returns:
A string with the elapsed time in seconds or
milliseconds if time is less than a second.
"""
elapsed_time = run.end_time - run.start_time
milliseconds = elapsed_time.total_seconds() * 1000
if milliseconds < 1000:
@@ -22,6 +42,8 @@ def elapsed(run: Any) -> str:
class ConsoleCallbackHandler(BaseTracer):
"""Tracer that prints to the console."""
name = "console_callback_handler"
def _persist_run(self, run: Run) -> None:

View File

@@ -137,6 +137,8 @@ def _replace_type_with_kind(data: Any) -> Any:
class WandbRunArgs(TypedDict):
"""Arguments for the WandbTracer."""
job_type: Optional[str]
dir: Optional[StrPath]
config: Union[Dict, str, None]

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, Iterable, Tuple, Union
def import_spacy() -> Any:
"""Import the spacy python package and raise an error if it is not installed."""
try:
import spacy
except ImportError:
@@ -15,6 +16,7 @@ def import_spacy() -> Any:
def import_pandas() -> Any:
"""Import the pandas python package and raise an error if it is not installed."""
try:
import pandas
except ImportError:
@@ -26,6 +28,7 @@ def import_pandas() -> Any:
def import_textstat() -> Any:
"""Import the textstat python package and raise an error if it is not installed."""
try:
import textstat
except ImportError:

View File

@@ -17,6 +17,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
def import_wandb() -> Any:
"""Import the wandb python package and raise an error if it is not installed."""
try:
import wandb # noqa: F401
except ImportError:

View File

@@ -18,6 +18,16 @@ def import_langkit(
toxicity: bool = False,
themes: bool = False,
) -> Any:
"""Import the langkit python package and raise an error if it is not installed.
Args:
sentiment: Whether to import the langkit.sentiment module. Defaults to False.
toxicity: Whether to import the langkit.toxicity module. Defaults to False.
themes: Whether to import the langkit.themes module. Defaults to False.
Returns:
The imported langkit module.
"""
try:
import langkit # noqa: F401
import langkit.regexes # noqa: F401

View File

@@ -18,6 +18,14 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""
Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"

View File

@@ -34,6 +34,8 @@ black_listed_elements: Set[str] = {
class ElementInViewPort(TypedDict):
"""A typed dictionary containing information about elements in the viewport."""
node_index: str
backend_node_id: int
node_name: Optional[str]
@@ -51,7 +53,7 @@ class Crawler:
try:
from playwright.sync_api import sync_playwright
except ImportError:
raise ValueError(
raise ImportError(
"Could not import playwright python package. "
"Please install it with `pip install playwright`."
)

View File

@@ -64,6 +64,14 @@ class QuestionAnswer(BaseModel):
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
"""Create a citation fuzzy match chain.
Args:
llm: Language model to use for the chain.
Returns:
Chain (LLMChain) that can be used to answer questions with citations.
"""
output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
schema = QuestionAnswer.schema()
function = {

View File

@@ -40,6 +40,15 @@ Passage:
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
"""Creates a chain that extracts information from a passage.
Args:
schema: The schema of the entities to extract.
llm: The language model to use.
Returns:
Chain that can be used to extract information from a passage.
"""
function = _get_extraction_function(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
@@ -56,6 +65,16 @@ def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
def create_extraction_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
"""Creates a chain that extracts information from a passage using pydantic schema.
Args:
pydantic_schema: The pydantic schema of the entities to extract.
llm: The language model to use.
Returns:
Chain that can be used to extract information from a passage.
"""
class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore

View File

@@ -13,6 +13,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.tools import APIOperation
@@ -77,7 +78,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -
schema = spec.get_schema(media_type_schema)
if p.description and not schema.description:
schema.description = p.description
properties[p.name] = schema.dict(exclude_none=True)
properties[p.name] = json.loads(schema.json(exclude_none=True))
if p.required:
required.append(p.name)
return {"type": "object", "properties": properties, "required": required}
@@ -127,15 +128,19 @@ def openapi_spec_to_openai_fn(
request_body = spec.get_request_body_for_operation(op)
# TODO: Support more MIME types.
if request_body and request_body.content:
media_types = []
for media_type in request_body.content.values():
if media_type.media_type_schema:
schema = spec.get_schema(media_type.media_type_schema)
media_types.append(schema.dict(exclude_none=True))
media_types = {}
for media_type, media_type_object in request_body.content.items():
if media_type_object.media_type_schema:
schema = spec.get_schema(media_type_object.media_type_schema)
media_types[media_type] = json.loads(
schema.json(exclude_none=True)
)
if len(media_types) == 1:
request_args["data"] = media_types[0]
media_type, schema_dict = list(media_types.items())[0]
key = "json" if media_type == "application/json" else "data"
request_args[key] = schema_dict
elif len(media_types) > 1:
request_args["data"] = {"anyOf": media_types}
request_args["data"] = {"anyOf": list(media_types.values())}
api_op = APIOperation.from_openapi_spec(spec, path, method)
fn = {
@@ -184,8 +189,13 @@ class SimpleRequestChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
name = inputs["function"].pop("name")
args = inputs["function"].pop("arguments")
_pretty_name = get_colored_text(name, "green")
_pretty_args = get_colored_text(json.dumps(args, indent=2), "green")
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
_run_manager.on_text(_text)
api_response: Response = self.request_method(name, args)
if api_response.status_code != 200:
response = (
@@ -206,6 +216,9 @@ def get_openapi_chain(
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
request_chain: Optional[Chain] = None,
llm_kwargs: Optional[Dict] = None,
verbose: bool = False,
**kwargs: Any,
) -> SequentialChain:
"""Create a chain for querying an API from a OpenAPI spec.
@@ -242,10 +255,16 @@ def get_openapi_chain(
llm_kwargs={"functions": openai_fns},
output_parser=JsonOutputFunctionsParser(args_only=False),
output_key="function",
verbose=verbose,
**(llm_kwargs or {}),
)
request_chain = request_chain or SimpleRequestChain(
request_method=call_api_fn, verbose=verbose
)
request_chain = request_chain or SimpleRequestChain(request_method=call_api_fn)
return SequentialChain(
chains=[llm_chain, request_chain],
input_variables=llm_chain.input_keys,
output_variables=["response"],
verbose=verbose,
**kwargs,
)

View File

@@ -29,6 +29,18 @@ def create_qa_with_structure_chain(
output_parser: str = "base",
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
) -> LLMChain:
"""Create a question answering chain that returns an answer matching the given schema.
Args:
llm: Language model to use for the chain.
schema: Pydantic schema to use for the output.
output_parser: Output parser to use. Should be one of `pydantic` or `base`.
Defaults to `base`.
prompt: Optional prompt to use for the chain.
Returns:
Chain (LLMChain) that can be used to answer questions matching the given schema.
"""
if output_parser == "pydantic":
if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
raise ValueError(
@@ -79,4 +91,13 @@ def create_qa_with_structure_chain(
def create_qa_with_sources_chain(llm: BaseLanguageModel, **kwargs: Any) -> LLMChain:
"""Create a question answering chain that returns an answer with sources.
Args:
llm: Language model to use for the chain.
**kwargs: Keyword arguments to pass to `create_qa_with_structure_chain`.
Returns:
Chain (LLMChain) that can be used to answer questions with citations.
"""
return create_qa_with_structure_chain(llm, AnswerWithSources, **kwargs)

View File

@@ -27,6 +27,15 @@ Passage:
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
"""Creates a chain that extracts information from a passage.
Args:
schema: The schema of the entities to extract.
llm: The language model to use.
Returns:
Chain (LLMChain) that can be used to extract information from a passage.
"""
function = _get_tagging_function(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
@@ -43,6 +52,15 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
"""Creates a chain that extracts information from a passage.
Args:
pydantic_schema: The pydantic schema of the entities to extract.
llm: The language model to use.
Returns:
Chain (LLMChain) that can be used to extract information from a passage.
"""
openai_schema = pydantic_schema.schema()
function = _get_tagging_function(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)

View File

@@ -29,4 +29,12 @@ def _convert_schema(schema: dict) -> dict:
def get_llm_kwargs(function: dict) -> dict:
"""Returns the kwargs for the LLMChain constructor.
Args:
function: The function to use.
Returns:
The kwargs for the LLMChain constructor.
"""
return {"functions": [function], "function_call": {"name": function["name"]}}

View File

@@ -31,8 +31,24 @@ class ConditionalPromptSelector(BasePromptSelector):
def is_llm(llm: BaseLanguageModel) -> bool:
"""Check if the language model is a LLM.
Args:
llm: Language model to check.
Returns:
True if the language model is a BaseLLM model, False otherwise.
"""
return isinstance(llm, BaseLLM)
def is_chat_model(llm: BaseLanguageModel) -> bool:
"""Check if the language model is a chat model.
Args:
llm: Language model to check.
Returns:
True if the language model is a BaseChatModel model, False otherwise.
"""
return isinstance(llm, BaseChatModel)

View File

@@ -123,6 +123,22 @@ def load_query_constructor_chain(
enable_limit: bool = False,
**kwargs: Any,
) -> LLMChain:
"""
Load a query constructor chain.
Args:
llm: BaseLanguageModel to use for the chain.
document_contents: The contents of the document to be queried.
attribute_info: A list of AttributeInfo objects describing
the attributes of the document.
examples: Optional list of examples to use for the chain.
allowed_comparators: An optional list of allowed comparators.
allowed_operators: An optional list of allowed operators.
enable_limit: Whether to enable the limit operator. Defaults to False.
**kwargs:
Returns:
A LLMChain that can be used to construct queries.
"""
prompt = _get_prompt(
document_contents,
attribute_info,

View File

@@ -60,12 +60,16 @@ class Expr(BaseModel):
class Operator(str, Enum):
"""Enumerator of the operations."""
AND = "and"
OR = "or"
NOT = "not"
class Comparator(str, Enum):
"""Enumerator of the comparison operators."""
EQ = "eq"
GT = "gt"
GTE = "gte"

View File

@@ -57,6 +57,9 @@ GRAMMAR = """
@v_args(inline=True)
class QueryTransformer(Transformer):
"""Transforms a query string into an IR representation
(intermediate representation)."""
def __init__(
self,
*args: Any,
@@ -136,6 +139,16 @@ def get_parser(
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
) -> Lark:
"""
Returns a parser for the query language.
Args:
allowed_comparators: Optional[Sequence[Comparator]]
allowed_operators: Optional[Sequence[Operator]]
Returns:
Lark parser for the query language.
"""
transformer = QueryTransformer(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)

View File

@@ -1,5 +1,6 @@
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
@@ -8,6 +9,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",
"ChatGooglePalm",

View File

@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.dump import dumpd, dumps
from langchain.schema import (
AIMessage,
BaseMessage,
@@ -35,6 +35,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
@@ -61,6 +62,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
) -> dict:
params = self.dict()
params["stop"] = stop
return params
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.lc_serializable:
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
def generate(
self,
messages: List[List[BaseMessage]],
@@ -71,9 +91,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self.dict()
params["stop"] = stop
params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
@@ -83,29 +101,37 @@ class BaseChatModel(BaseLanguageModel, ABC):
tags,
self.tags,
)
run_manager = callback_manager.on_chat_model_start(
run_managers = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._generate(m, stop=stop)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
results = []
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)
except (KeyboardInterrupt, Exception) as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
async def agenerate(
@@ -118,8 +144,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self.dict()
params["stop"] = stop
params = self._get_invocation_params(stop=stop)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
@@ -129,31 +154,62 @@ class BaseChatModel(BaseLanguageModel, ABC):
tags,
self.tags,
)
run_manager = await callback_manager.on_chat_model_start(
run_managers = await callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
results = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
for i, m in enumerate(messages)
],
return_exceptions=True,
)
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._agenerate(m, stop=stop)
for m in messages
]
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
exceptions = []
for i, res in enumerate(results):
if isinstance(res, Exception):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate_prompt(
@@ -178,6 +234,84 @@ class BaseChatModel(BaseLanguageModel, ABC):
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
self,

View File

@@ -0,0 +1,33 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
responses: List
i: int = 0
@property
def _llm_type(self) -> str:
return "fake-list-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return the next canned response from the list."""
response = self.responses[self.i]
self.i += 1
return response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"responses": self.responses}
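A short sketch of how the fake model can be used in tests; it simply replays the supplied responses in order:
```python
from langchain.chat_models import FakeListChatModel
from langchain.schema import HumanMessage

fake = FakeListChatModel(responses=["first canned reply", "second canned reply"])
print(fake([HumanMessage(content="hello")]).content)        # -> first canned reply
print(fake([HumanMessage(content="hello again")]).content)  # -> second canned reply
```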

View File

@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
class ChatGooglePalmError(Exception):
"""Error raised when there is an issue with the Google PaLM API."""
pass

View File

@@ -184,6 +184,16 @@ class ChatOpenAI(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
class Config:
"""Configuration for this pydantic object."""
@@ -448,15 +458,18 @@ class ChatOpenAI(BaseChatModel):
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
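The `tiktoken_model_name` field documented above can be exercised roughly as follows; the endpoint and model name are made up for illustration:
```python
from langchain.chat_models import ChatOpenAI

# An OpenAI-compatible proxy serving a custom model name that tiktoken does not
# recognize; point token counting at a known encoding instead.
chat = ChatOpenAI(
    openai_api_base="https://my-proxy.example.com/v1",  # hypothetical endpoint
    model_name="my-custom-model",                       # hypothetical model name
    tiktoken_model_name="gpt-3.5-turbo",
)
chat.get_num_tokens("How many tokens is this?")
```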

View File

@@ -68,6 +68,7 @@ from langchain.document_loaders.mastodon import MastodonTootsLoader
from langchain.document_loaders.max_compute import MaxComputeLoader
from langchain.document_loaders.mediawikidump import MWDumpLoader
from langchain.document_loaders.merge import MergedDataLoader
from langchain.document_loaders.mhtml import MHTMLLoader
from langchain.document_loaders.modern_treasury import ModernTreasuryLoader
from langchain.document_loaders.notebook import NotebookLoader
from langchain.document_loaders.notion import NotionDirectoryLoader
@@ -97,6 +98,7 @@ from langchain.document_loaders.readthedocs import ReadTheDocsLoader
from langchain.document_loaders.recursive_url_loader import RecusiveUrlLoader
from langchain.document_loaders.reddit import RedditPostsLoader
from langchain.document_loaders.roam import RoamLoader
from langchain.document_loaders.rst import UnstructuredRSTLoader
from langchain.document_loaders.rtf import UnstructuredRTFLoader
from langchain.document_loaders.s3_directory import S3DirectoryLoader
from langchain.document_loaders.s3_file import S3FileLoader
@@ -204,6 +206,7 @@ __all__ = [
"MathpixPDFLoader",
"MaxComputeLoader",
"MergedDataLoader",
"MHTMLLoader",
"ModernTreasuryLoader",
"NotebookLoader",
"NotionDBLoader",
@@ -261,6 +264,7 @@ __all__ = [
"UnstructuredODTLoader",
"UnstructuredPDFLoader",
"UnstructuredPowerPointLoader",
"UnstructuredRSTLoader",
"UnstructuredRTFLoader",
"UnstructuredURLLoader",
"UnstructuredWordDocumentLoader",

View File

@@ -11,6 +11,8 @@ from langchain.document_loaders.base import BaseLoader
class BlockchainType(Enum):
"""Enumerator of the supported blockchains."""
ETH_MAINNET = "eth-mainnet"
ETH_GOERLI = "eth-goerli"
POLYGON_MAINNET = "polygon-mainnet"


@@ -8,6 +8,15 @@ from langchain.document_loaders.base import BaseLoader
def concatenate_rows(message: dict, title: str) -> str:
"""
Combine message information in a readable format ready to be used.
Args:
message: Message to be concatenated
title: Title of the conversation
Returns:
Concatenated message
"""
if not message:
return ""


@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
class ContentFormat(str, Enum):
"""Enumerator of the content formats of Confluence page."""
STORAGE = "body.storage"
VIEW = "body.view"


@@ -45,6 +45,8 @@ class EmbaasDocumentExtractionParameters(TypedDict):
class EmbaasDocumentExtractionPayload(EmbaasDocumentExtractionParameters):
"""Payload for the Embaas document extraction API."""
bytes: str
"""The base64 encoded bytes of the document to extract text from."""


@@ -5,7 +5,8 @@ from langchain.document_loaders.base import BaseLoader
class FaunaLoader(BaseLoader):
"""
"""FaunaDB Loader.
Attributes:
query (str): The FQL query string to execute.
page_content_field (str): The field that contains the content of each page.


@@ -17,6 +17,8 @@ IUGU_ENDPOINTS = {
class IuguLoader(BaseLoader):
"""Loader that fetches data from IUGU."""
def __init__(self, resource: str, api_token: Optional[str] = None) -> None:
self.resource = resource
api_token = api_token or get_from_env("api_token", "IUGU_API_TOKEN")


@@ -0,0 +1,69 @@
"""Loader to load MHTML files, enriching metadata with page title."""
import email
import logging
from typing import Dict, List, Union
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
class MHTMLLoader(BaseLoader):
"""Loader that uses beautiful soup to parse HTML files."""
def __init__(
self,
file_path: str,
open_encoding: Union[str, None] = None,
bs_kwargs: Union[dict, None] = None,
get_text_separator: str = "",
) -> None:
"""Initialise with path, and optionally, file encoding to use, and any kwargs
to pass to the BeautifulSoup object."""
try:
import bs4 # noqa:F401
except ImportError:
raise ValueError(
"beautifulsoup4 package not found, please install it with "
"`pip install beautifulsoup4`"
)
self.file_path = file_path
self.open_encoding = open_encoding
if bs_kwargs is None:
bs_kwargs = {"features": "lxml"}
self.bs_kwargs = bs_kwargs
self.get_text_separator = get_text_separator
def load(self) -> List[Document]:
from bs4 import BeautifulSoup
"""Load MHTML document into document objects."""
with open(self.file_path, "r", encoding=self.open_encoding) as f:
message = email.message_from_string(f.read())
parts = message.get_payload()
if type(parts) is not list:
parts = [message]
for part in parts:
if part.get_content_type() == "text/html":
html = part.get_payload(decode=True).decode()
soup = BeautifulSoup(html, **self.bs_kwargs)
text = soup.get_text(self.get_text_separator)
if soup.title:
title = str(soup.title.string)
else:
title = ""
metadata: Dict[str, Union[str, None]] = {
"source": self.file_path,
"title": title,
}
return [Document(page_content=text, metadata=metadata)]
return []
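
A small usage sketch for the new loader; the file path is a placeholder, and `beautifulsoup4` (plus `lxml` for the default parser) must be installed:

```python
from langchain.document_loaders import MHTMLLoader

# "page.mhtml" is a placeholder path to a browser-saved MHTML snapshot.
loader = MHTMLLoader(file_path="page.mhtml")
docs = loader.load()
print(docs[0].metadata["title"], len(docs[0].page_content))
```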


@@ -27,6 +27,8 @@ incoming_payment_details",
class ModernTreasuryLoader(BaseLoader):
"""Loader that fetches data from Modern Treasury."""
def __init__(
self,
resource: str,


@@ -48,13 +48,13 @@ class NotionDBLoader(BaseLoader):
Returns:
List[Document]: List of documents.
"""
page_ids = self._retrieve_page_ids()
page_summaries = self._retrieve_page_summaries()
return list(self.load_page(page_id) for page_id in page_ids)
return list(self.load_page(page_summary) for page_summary in page_summaries)
def _retrieve_page_ids(
def _retrieve_page_summaries(
self, query_dict: Dict[str, Any] = {"page_size": 100}
) -> List[str]:
) -> List[Dict[str, Any]]:
"""Get all the pages from a Notion database."""
pages: List[Dict[str, Any]] = []
@@ -72,18 +72,16 @@ class NotionDBLoader(BaseLoader):
query_dict["start_cursor"] = data.get("next_cursor")
page_ids = [page["id"] for page in pages]
return pages
return page_ids
def load_page(self, page_id: str) -> Document:
def load_page(self, page_summary: Dict[str, Any]) -> Document:
"""Read a page."""
data = self._request(PAGE_URL.format(page_id=page_id))
page_id = page_summary["id"]
# load properties as metadata
metadata: Dict[str, Any] = {}
for prop_name, prop_data in data["properties"].items():
for prop_name, prop_data in page_summary["properties"].items():
prop_type = prop_data["type"]
if prop_type == "rich_text":
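
The loader's public interface is unchanged by this refactor; loading still looks roughly like this (token and database id are placeholders):

```python
from langchain.document_loaders import NotionDBLoader

loader = NotionDBLoader(
    integration_token="secret_...",                    # placeholder Notion integration token
    database_id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",    # placeholder database id
)
# Per the diff above, page properties now come from the query response,
# avoiding a separate metadata request per page.
docs = loader.load()
```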


@@ -7,6 +7,8 @@ from langchain.schema import Document
class TextParser(BaseBlobParser):
"""Parser for text blobs."""
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
yield Document(page_content=blob.as_string(), metadata={"source": blob.source})


@@ -2,7 +2,6 @@ from typing import Iterator, List, Optional, Set
from urllib.parse import urlparse
import requests
from bs4 import BeautifulSoup
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
@@ -21,6 +20,13 @@ class RecusiveUrlLoader(BaseLoader):
) -> Set[str]:
"""Recursively get all child links starting with the path of the input URL."""
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"The BeautifulSoup package is required for the RecusiveUrlLoader."
)
# Construct the base and parent URLs
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
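
Usage is unaffected by moving the `bs4` import inside the method; a minimal sketch (the URL is illustrative):

```python
from langchain.document_loaders.recursive_url_loader import RecusiveUrlLoader

# Note: the class name really is spelled "Recusive" in this release.
loader = RecusiveUrlLoader(url="https://docs.example.com/")  # illustrative URL
docs = loader.load()
```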


@@ -0,0 +1,22 @@
"""Loader that loads RST files."""
from typing import Any, List
from langchain.document_loaders.unstructured import (
UnstructuredFileLoader,
validate_unstructured_version,
)
class UnstructuredRSTLoader(UnstructuredFileLoader):
"""Loader that uses unstructured to load RST files."""
def __init__(
self, file_path: str, mode: str = "single", **unstructured_kwargs: Any
):
validate_unstructured_version(min_unstructured_version="0.7.5")
super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)
def _get_elements(self) -> List:
from unstructured.partition.rst import partition_rst
return partition_rst(filename=self.file_path, **self.unstructured_kwargs)
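
A usage sketch for the new RST loader; it requires `unstructured>=0.7.5`, and the file path is a placeholder:

```python
from langchain.document_loaders import UnstructuredRSTLoader

loader = UnstructuredRSTLoader(file_path="example.rst", mode="elements")  # placeholder path
docs = loader.load()
```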


@@ -20,6 +20,8 @@ SPREEDLY_ENDPOINTS = {
class SpreedlyLoader(BaseLoader):
"""Loader that fetches data from Spreedly API."""
def __init__(self, access_token: str, resource: str) -> None:
self.access_token = access_token
self.resource = resource


@@ -18,6 +18,8 @@ STRIPE_ENDPOINTS = {
class StripeLoader(BaseLoader):
"""Loader that fetches data from Stripe."""
def __init__(self, resource: str, access_token: Optional[str] = None) -> None:
self.resource = resource
access_token = access_token or get_from_env(


@@ -16,6 +16,7 @@ class UnstructuredURLLoader(BaseLoader):
urls: List[str],
continue_on_failure: bool = True,
mode: str = "single",
show_progress_bar: bool = False,
**unstructured_kwargs: Any,
):
"""Initialize with file path."""
@@ -51,6 +52,7 @@ class UnstructuredURLLoader(BaseLoader):
self.continue_on_failure = continue_on_failure
self.headers = headers
self.unstructured_kwargs = unstructured_kwargs
self.show_progress_bar = show_progress_bar
def _validate_mode(self, mode: str) -> None:
_valid_modes = {"single", "elements"}
@@ -83,7 +85,21 @@ class UnstructuredURLLoader(BaseLoader):
from unstructured.partition.html import partition_html
docs: List[Document] = list()
for url in self.urls:
if self.show_progress_bar:
try:
from tqdm import tqdm
except ImportError as e:
raise ImportError(
"Package tqdm must be installed if show_progress_bar=True. "
"Please install with 'pip install tqdm' or set "
"show_progress_bar=False."
) from e
urls = tqdm(self.urls)
else:
urls = self.urls
for url in urls:
try:
if self.__is_non_html_available():
if self.__is_headers_available_for_non_html():
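
With the new flag, loading a batch of URLs reports progress through tqdm; the URLs below are placeholders, and `tqdm` must be installed when the flag is enabled:

```python
from langchain.document_loaders import UnstructuredURLLoader

urls = ["https://example.com/a", "https://example.com/b"]  # placeholder URLs
loader = UnstructuredURLLoader(urls=urls, show_progress_bar=True)
docs = loader.load()
```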


@@ -50,6 +50,9 @@ class WebBaseLoader(BaseLoader):
requests_kwargs: Dict[str, Any] = {}
"""kwargs for requests"""
bs_get_text_kwargs: Dict[str, Any] = {}
"""kwargs for beatifulsoup4 get_text"""
def __init__(
self,
web_path: Union[str, List[str]],
@@ -201,7 +204,7 @@ class WebBaseLoader(BaseLoader):
"""Lazy load text from the url(s) in web_path."""
for path in self.web_paths:
soup = self._scrape(path)
text = soup.get_text()
text = soup.get_text(**self.bs_get_text_kwargs)
metadata = _build_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)
@@ -216,7 +219,7 @@ class WebBaseLoader(BaseLoader):
docs = []
for i in range(len(results)):
soup = results[i]
text = soup.get_text()
text = soup.get_text(**self.bs_get_text_kwargs)
metadata = _build_metadata(soup, self.web_paths[i])
docs.append(Document(page_content=text, metadata=metadata))
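
A sketch of the new `bs_get_text_kwargs` hook; `separator` and `strip` are standard BeautifulSoup `get_text` arguments, and the URL is illustrative. The attribute is set after construction here, since the diff shows it as a class-level field:

```python
from langchain.document_loaders import WebBaseLoader

loader = WebBaseLoader("https://example.com")  # illustrative URL
loader.bs_get_text_kwargs = {"separator": " ", "strip": True}
docs = loader.load()
```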


@@ -40,7 +40,7 @@ class WhatsAppChatLoader(BaseLoader):
(?:
:\d{2}
)?
(?:[ _](?:AM|PM))?
(?:[\s_](?:AM|PM))?
)
\]?
[\s-]*
@@ -50,7 +50,9 @@ class WhatsAppChatLoader(BaseLoader):
(.+)
"""
for line in lines:
result = re.match(message_line_regex, line.strip(), flags=re.VERBOSE)
result = re.match(
message_line_regex, line.strip(), flags=re.VERBOSE | re.IGNORECASE
)
if result:
date, sender, text = result.groups()
text_content += concatenate_rows(date, sender, text)


@@ -18,17 +18,41 @@ class WikipediaLoader(BaseLoader):
lang: str = "en",
load_max_docs: Optional[int] = 100,
load_all_available_meta: Optional[bool] = False,
doc_content_chars_max: Optional[int] = 4000,
):
"""
Initializes a new instance of the WikipediaLoader class.
Args:
query (str): The query string to search on Wikipedia.
lang (str, optional): The language code for the Wikipedia language edition.
Defaults to "en".
load_max_docs (int, optional): The maximum number of documents to load.
Defaults to 100.
load_all_available_meta (bool, optional): Indicates whether to load all
available metadata for each document. Defaults to False.
doc_content_chars_max (int, optional): The maximum number of characters
for the document content. Defaults to 4000.
"""
self.query = query
self.lang = lang
self.load_max_docs = load_max_docs
self.load_all_available_meta = load_all_available_meta
self.doc_content_chars_max = doc_content_chars_max
def load(self) -> List[Document]:
"""
Loads the query result from Wikipedia into a list of Documents.
Returns:
List[Document]: A list of Document objects representing the loaded
Wikipedia pages.
"""
client = WikipediaAPIWrapper(
lang=self.lang,
top_k_results=self.load_max_docs,
load_all_available_meta=self.load_all_available_meta,
doc_content_chars_max=self.doc_content_chars_max,
)
docs = client.load(self.query)
return docs


@@ -30,6 +30,14 @@ class _DocumentWithState(Document):
def get_stateful_documents(
documents: Sequence[Document],
) -> Sequence[_DocumentWithState]:
"""Convert a list of documents to a list of documents with state.
Args:
documents: The documents to convert.
Returns:
A list of documents with state.
"""
return [_DocumentWithState.from_document(doc) for doc in documents]


@@ -170,6 +170,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout in seconds for the OpenAPI request."""
headers: Any = None
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
class Config:
"""Configuration for this pydantic object."""
@@ -265,7 +275,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
@@ -329,7 +345,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
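
A sketch of the same option on the embeddings side, e.g. for an Azure deployment whose name tiktoken does not recognise; the deployment and model names are placeholders:

```python
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    deployment="my-azure-deployment",              # placeholder Azure deployment name
    model="my-azure-deployment",                   # placeholder name unknown to tiktoken
    tiktoken_model_name="text-embedding-ada-002",  # known encoding for token counting
)
vectors = embeddings.embed_documents(["some text to embed"])
```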


@@ -18,7 +18,7 @@ These questions should be detailed and be based explicitly on information in the
{doc}
<End Document>"""
output_parser = RegexParser(
regex=r"QUESTION: (.*?)\nANSWER: (.*)", output_keys=["query", "answer"]
regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", output_keys=["query", "answer"]
)
PROMPT = PromptTemplate(
input_variables=["doc"], template=template, output_parser=output_parser
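
The extra `+` lets a blank line sit between the question and the answer. A quick check with the parser used here:

```python
from langchain.output_parsers.regex import RegexParser

parser = RegexParser(
    regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", output_keys=["query", "answer"]
)

# A blank line between QUESTION and ANSWER no longer breaks parsing.
text = "QUESTION: What is LangChain?\n\nANSWER: A framework for LLM applications."
print(parser.parse(text))  # {'query': 'What is LangChain?', 'answer': '...'}
```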


@@ -18,8 +18,17 @@ class BaseAutoGPTOutputParser(BaseOutputParser):
def preprocess_json_input(input_str: str) -> str:
# Replace single backslashes with double backslashes,
# while leaving already escaped ones intact
"""Preprocesses a string to be parsed as json.
Replace single backslashes with double backslashes,
while leaving already escaped ones intact.
Args:
input_str: String to be preprocessed
Returns:
Preprocessed string
"""
corrected_str = re.sub(
r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
)


@@ -23,6 +23,18 @@ def load_agent_executor(
verbose: bool = False,
include_task_in_prompt: bool = False,
) -> ChainExecutor:
"""
Load an agent executor.
Args:
llm: BaseLanguageModel
tools: List[BaseTool]
verbose: bool. Defaults to False.
include_task_in_prompt: bool. Defaults to False.
Returns:
ChainExecutor
"""
input_variables = ["previous_steps", "current_step", "agent_scratchpad"]
template = HUMAN_MESSAGE_TEMPLATE


@@ -11,13 +11,13 @@ from langchain.experimental.plan_and_execute.schema import Plan, PlanOutputParse
class BasePlanner(BaseModel):
@abstractmethod
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""Given input, decided what to do."""
"""Given input, decide what to do."""
@abstractmethod
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Given input, decided what to do."""
"""Given input, decide what to do."""
class LLMPlanner(BasePlanner):
@@ -26,14 +26,14 @@ class LLMPlanner(BasePlanner):
stop: Optional[List] = None
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""Given input, decided what to do."""
"""Given input, decide what to do."""
llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
return self.output_parser.parse(llm_response)
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Given input, decided what to do."""
"""Given input, decide what to do."""
llm_response = await self.llm_chain.arun(
**inputs, stop=self.stop, callbacks=callbacks
)


@@ -32,6 +32,15 @@ class PlanningOutputParser(PlanOutputParser):
def load_chat_planner(
llm: BaseLanguageModel, system_prompt: str = SYSTEM_PROMPT
) -> LLMPlanner:
"""
Load a chat planner.
Args:
llm: Language model.
system_prompt: System prompt.
Returns:
LLMPlanner
"""
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessage(content=system_prompt),


@@ -3,6 +3,7 @@ from typing import Dict, Type
from langchain.llms.ai21 import AI21
from langchain.llms.aleph_alpha import AlephAlpha
from langchain.llms.amazon_api_gateway import AmazonAPIGateway
from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
@@ -53,6 +54,7 @@ from langchain.llms.writer import Writer
__all__ = [
"AI21",
"AlephAlpha",
"AmazonAPIGateway",
"Anthropic",
"Anyscale",
"Aviary",
@@ -106,6 +108,8 @@ __all__ = [
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"ai21": AI21,
"aleph_alpha": AlephAlpha,
"amazon_api_gateway": AmazonAPIGateway,
"amazon_bedrock": Bedrock,
"anthropic": Anthropic,
"anyscale": Anyscale,
"aviary": Aviary,


@@ -0,0 +1,98 @@
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
class ContentHandlerAmazonAPIGateway:
"""Adapter class to prepare the inputs from Langchain to a format
that LLM model expects. Also, provides helper function to extract
the generated text from the model response."""
@classmethod
def transform_input(
cls, prompt: str, model_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
return {"inputs": prompt, "parameters": model_kwargs}
@classmethod
def transform_output(cls, response: Any) -> str:
return response.json()[0]["generated_text"]
class AmazonAPIGateway(LLM):
"""Wrapper around custom Amazon API Gateway"""
api_url: str
"""API Gateway URL"""
model_kwargs: Optional[Dict] = None
"""Key word arguments to pass to the model."""
content_handler: ContentHandlerAmazonAPIGateway = ContentHandlerAmazonAPIGateway()
"""The content handler class that provides an input and
output transform functions to handle formats between LLM
and the endpoint.
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_name": self.api_url},
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "amazon_api_gateway"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Amazon API Gateway model.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = se("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
payload = self.content_handler.transform_input(prompt, _model_kwargs)
try:
response = requests.post(
self.api_url,
json=payload,
)
text = self.content_handler.transform_output(response)
except Exception as error:
raise ValueError(f"Error raised by the service: {error}")
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
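
A usage sketch for the new wrapper; the endpoint URL is a placeholder, and the gateway is assumed to front a model that accepts the `{"inputs": ..., "parameters": ...}` payload produced by the default content handler:

```python
from langchain.llms import AmazonAPIGateway

llm = AmazonAPIGateway(
    api_url="https://<api-id>.execute-api.<region>.amazonaws.com/prod",  # placeholder URL
    model_kwargs={"max_new_tokens": 100, "temperature": 0.7},            # illustrative kwargs
)
print(llm("Tell me a joke."))
```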


@@ -1,8 +1,10 @@
"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional
import dataclasses
import os
from typing import Any, Dict, List, Mapping, Optional, Union, cast
import requests
from pydantic import Extra, Field, root_validator
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@@ -12,6 +14,68 @@ from langchain.utils import get_from_dict_or_env
TIMEOUT = 60
@dataclasses.dataclass
class AviaryBackend:
backend_url: str
bearer: str
def __post_init__(self) -> None:
self.header = {"Authorization": self.bearer}
@classmethod
def from_env(cls) -> "AviaryBackend":
aviary_url = os.getenv("AVIARY_URL")
assert aviary_url, "AVIARY_URL must be set"
aviary_token = os.getenv("AVIARY_TOKEN", "")
bearer = f"Bearer {aviary_token}" if aviary_token else ""
aviary_url += "/" if not aviary_url.endswith("/") else ""
return cls(aviary_url, bearer)
def get_models() -> List[str]:
"""List available models"""
backend = AviaryBackend.from_env()
request_url = backend.backend_url + "-/routes"
response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
try:
result = response.json()
except requests.JSONDecodeError as e:
raise RuntimeError(
f"Error decoding JSON from {request_url}. Text response: {response.text}"
) from e
result = sorted(
[k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
)
return result
def get_completions(
model: str,
prompt: str,
use_prompt_format: bool = True,
version: str = "",
) -> Dict[str, Union[str, float, int]]:
"""Get completions from Aviary models."""
backend = AviaryBackend.from_env()
url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
response = requests.post(
url,
headers=backend.header,
json={"prompt": prompt, "use_prompt_format": use_prompt_format},
timeout=TIMEOUT,
)
try:
return response.json()
except requests.JSONDecodeError as e:
raise RuntimeError(
f"Error decoding JSON from {url}. Text response: {response.text}"
) from e
class Aviary(LLM):
"""Allow you to use an Aviary.
@@ -19,33 +83,30 @@ class Aviary(LLM):
find out more about aviary at
http://github.com/ray-project/aviary
Has no dependencies, since it connects to backend
directly.
To get a list of the models supported on an
aviary, follow the instructions on the web site to
install the aviary CLI and then use:
`aviary models`
You must at least specify the environment
variable or parameter AVIARY_URL.
You may optionally specify the environment variable
or parameter AVIARY_TOKEN.
AVIARY_URL and AVIARY_TOKEN environment variables must be set.
Example:
.. code-block:: python
from langchain.llms import Aviary
light = Aviary(aviary_url='AVIARY_URL',
model='amazon/LightGPT')
result = light.predict('How do you make fried rice?')
os.environ["AVIARY_URL"] = "<URL>"
os.environ["AVIARY_TOKEN"] = "<TOKEN>"
light = Aviary(model='amazon/LightGPT')
output = light('How do you make fried rice?')
"""
model: str
aviary_url: str
aviary_token: str = Field("", exclude=True)
model: str = "amazon/LightGPT"
aviary_url: Optional[str] = None
aviary_token: Optional[str] = None
# If True the prompt template for the model will be ignored.
use_prompt_format: bool = True
# API version to use for Aviary
version: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
@@ -56,49 +117,35 @@ class Aviary(LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
if not aviary_url.endswith("/"):
aviary_url += "/"
values["aviary_url"] = aviary_url
aviary_token = get_from_dict_or_env(
values, "aviary_token", "AVIARY_TOKEN", default=""
)
values["aviary_token"] = aviary_token
aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
# Set env variables for aviary sdk
os.environ["AVIARY_URL"] = aviary_url
os.environ["AVIARY_TOKEN"] = aviary_token
aviary_endpoint = aviary_url + "models"
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
try:
response = requests.get(aviary_endpoint, headers=headers)
result = response.json()
# Confirm model is available
if values["model"] not in result:
raise ValueError(
f"{aviary_url} does not support model {values['model']}."
)
aviary_models = get_models()
except requests.exceptions.RequestException as e:
raise ValueError(e)
model = values.get("model")
if model and model not in aviary_models:
raise ValueError(f"{aviary_url} does not support model {values['model']}.")
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model,
"aviary_url": self.aviary_url,
"aviary_token": self.aviary_token,
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "aviary"
@property
def headers(self) -> Dict[str, str]:
if self.aviary_token:
return {"Authorization": f"Bearer {self.aviary_token}"}
else:
return {}
return f"aviary-{self.model.replace('/', '-')}"
def _call(
self,
@@ -119,19 +166,18 @@ class Aviary(LLM):
response = aviary("Tell me a joke.")
"""
url = self.aviary_url + "query/" + self.model.replace("/", "--")
response = requests.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=TIMEOUT,
kwargs = {"use_prompt_format": self.use_prompt_format}
if self.version:
kwargs["version"] = self.version
output = get_completions(
model=self.model,
prompt=prompt,
**kwargs,
)
try:
text = response.json()[self.model]["generated_text"]
except requests.JSONDecodeError as e:
raise ValueError(
f"Error decoding JSON from {url}. Text response: {response.text}",
) from e
text = cast(str, output["generated_text"])
if stop:
text = enforce_stop_tokens(text, stop)
return text
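
Consistent with the rewritten docstring, configuration now goes through environment variables; the values below are placeholders:

```python
import os
from langchain.llms import Aviary

os.environ["AVIARY_URL"] = "https://aviary.example.com"  # placeholder backend URL
os.environ["AVIARY_TOKEN"] = "my-token"                  # placeholder token

llm = Aviary(model="amazon/LightGPT")
print(llm("How do you make fried rice?"))
```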


@@ -1,4 +1,5 @@
"""Base interface for large language models to expose."""
import asyncio
import inspect
import json
import warnings
@@ -151,6 +152,39 @@ class BaseLLM(BaseLanguageModel, ABC):
prompt_strings, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
self._generate(
prompts,
stop=stop,
# TODO: support multiple run managers
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
for run_manager in run_managers:
run_manager.on_llm_error(e)
raise e
flattened_outputs = output.flatten()
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate(
self,
prompts: List[str],
@@ -161,8 +195,6 @@ class BaseLLM(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
@@ -185,60 +217,77 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = callback_manager.on_llm_start(
run_managers = callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
self._generate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
run_managers = callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
)
new_results = self._generate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
try:
new_results = (
self._generate(
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
async def _agenerate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
await self._agenerate(
prompts,
stop=stop,
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers]
)
raise e
flattened_outputs = output.flatten()
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
async def agenerate(
self,
prompts: List[str],
@@ -266,54 +315,32 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = await callback_manager.on_llm_start(
run_managers = await callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
await self._agenerate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
await run_manager.on_llm_end(output, verbose=self.verbose)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
run_managers = await callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
)
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
try:
new_results = (
await self._agenerate(
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
await run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None


@@ -171,6 +171,16 @@ class BaseOpenAI(BaseLLM):
"""Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
"""Initialize the OpenAI object."""
@@ -491,7 +501,13 @@ class BaseOpenAI(BaseLLM):
"Please install it with `pip install tiktoken`."
)
enc = tiktoken.encoding_for_model(self.model_name)
model_name = self.tiktoken_model_name or self.model_name
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
enc = tiktoken.get_encoding(model)
return enc.encode(
text,


@@ -5,6 +5,8 @@ from langchain.load.serializable import Serializable, to_json_not_implemented
def default(obj: Any) -> Any:
"""Return a default value for a Serializable object or
a SerializedNotImplemented object."""
if isinstance(obj, Serializable):
return obj.to_json()
else:
@@ -12,6 +14,7 @@ def default(obj: Any) -> Any:
def dumps(obj: Any, *, pretty: bool = False) -> str:
"""Return a json string representation of an object."""
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
@@ -19,4 +22,5 @@ def dumps(obj: Any, *, pretty: bool = False) -> str:
def dumpd(obj: Any) -> Dict[str, Any]:
"""Return a json dict representation of an object."""
return json.loads(dumps(obj))


@@ -5,24 +5,34 @@ from pydantic import BaseModel, PrivateAttr
class BaseSerialized(TypedDict):
"""Base class for serialized objects."""
lc: int
id: List[str]
class SerializedConstructor(BaseSerialized):
"""Serialized constructor."""
type: Literal["constructor"]
kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
"""Serialized secret."""
type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
"""Serialized not implemented."""
type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
"""Serializable base class."""
@property
def lc_serializable(self) -> bool:
"""
@@ -130,6 +140,14 @@ def _replace_secrets(
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
"""Serialize a "not implemented" object.
Args:
obj: object to serialize
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
try:
if hasattr(obj, "__name__"):


@@ -15,6 +15,8 @@ DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_his
class PostgresChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Postgres database."""
def __init__(
self,
session_id: str,


@@ -13,6 +13,8 @@ logger = logging.getLogger(__name__)
class RedisChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Redis database."""
def __init__(
self,
session_id: str,


@@ -21,6 +21,17 @@ logger = logging.getLogger(__name__)
def create_message_model(table_name, DynamicBase): # type: ignore
"""
Create a message model for a given table name.
Args:
table_name: The name of the table to use.
DynamicBase: The base class to use for the model.
Returns:
The model class.
"""
# Model declared inside a function to have a dynamic table name
class Message(DynamicBase):
__tablename__ = table_name
@@ -32,6 +43,8 @@ def create_message_model(table_name, DynamicBase): # type: ignore
class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""
def __init__(
self,
session_id: str,


@@ -86,3 +86,7 @@ class MotorheadMemory(BaseChatMemory):
headers=self.__get_headers(),
)
super().save_context(inputs, outputs)
def delete_session(self) -> None:
"""Delete a session"""
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")


@@ -4,6 +4,16 @@ from langchain.schema import get_buffer_string # noqa: 401
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""
Get the prompt input key.
Args:
inputs: Dict[str, Any]
memory_variables: List[str]
Returns:
A prompt input key.
"""
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))


@@ -8,6 +8,15 @@ from langchain.schema import OutputParserException
def parse_json_markdown(json_string: str) -> dict:
"""
Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Try to find JSON string within triple backticks
match = re.search(r"```(json)?(.*?)```", json_string, re.DOTALL)
@@ -28,6 +37,17 @@ def parse_json_markdown(json_string: str) -> dict:
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
contains the expected keys.
Args:
text: The Markdown string.
expected_keys: The expected keys in the JSON string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
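
A quick illustration of what `parse_json_markdown` accepts, assuming the helpers live in `langchain.output_parsers.json`; the backtick fence is built at runtime only to keep this snippet readable:

```python
from langchain.output_parsers.json import parse_json_markdown  # assumed module path

fence = "```"
markdown = f'Model said:\n{fence}json\n{{"action": "search", "action_input": "weather"}}\n{fence}'
print(parse_json_markdown(markdown))  # -> {'action': 'search', 'action_input': 'weather'}
```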


@@ -28,6 +28,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Raise an exception if missing or extra variables are found.
Args:
template: The template string.
input_variables: The input variables.
"""
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set


@@ -1,11 +1,11 @@
from langchain.retrievers.arxiv import ArxivRetriever
from langchain.retrievers.aws_kendra_index_retriever import AwsKendraIndexRetriever
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.databerry import DataberryRetriever
from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.kendra import AmazonKendraRetriever
from langchain.retrievers.knn import KNNRetriever
from langchain.retrievers.llama_index import (
LlamaIndexGraphRetriever,
@@ -30,8 +30,8 @@ from langchain.retrievers.zep import ZepRetriever
from langchain.retrievers.zilliz import ZillizRetriever
__all__ = [
"AmazonKendraRetriever",
"ArxivRetriever",
"AwsKendraIndexRetriever",
"AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever",
"ContextualCompressionRetriever",


@@ -1,95 +0,0 @@
"""Retriever wrapper for AWS Kendra."""
import re
from typing import Any, Dict, List
from langchain.schema import BaseRetriever, Document
class AwsKendraIndexRetriever(BaseRetriever):
"""Wrapper around AWS Kendra."""
kendraindex: str
"""Kendra index id"""
k: int
"""Number of documents to query for."""
languagecode: str
"""Languagecode used for querying."""
kclient: Any
""" boto3 client for Kendra. """
def __init__(
self, kclient: Any, kendraindex: str, k: int = 3, languagecode: str = "en"
):
self.kendraindex = kendraindex
self.k = k
self.languagecode = languagecode
self.kclient = kclient
def _clean_result(self, res_text: str) -> str:
return re.sub("\s+", " ", res_text).replace("...", "")
def _get_top_n_results(self, resp: Dict, count: int) -> Document:
r = resp["ResultItems"][count]
doc_title = r["DocumentTitle"]["Text"]
doc_uri = r["DocumentURI"]
r_type = r["Type"]
if (
r["AdditionalAttributes"]
and r["AdditionalAttributes"][0]["Key"] == "AnswerText"
):
res_text = r["AdditionalAttributes"][0]["Value"]["TextWithHighlightsValue"][
"Text"
]
else:
res_text = r["DocumentExcerpt"]["Text"]
doc_excerpt = self._clean_result(res_text)
combined_text = f"""Document Title: {doc_title}
Document Excerpt: {doc_excerpt}
"""
return Document(
page_content=combined_text,
metadata={
"source": doc_uri,
"title": doc_title,
"excerpt": doc_excerpt,
"type": r_type,
},
)
def _kendra_query(self, kquery: str) -> List[Document]:
response = self.kclient.query(
IndexId=self.kendraindex,
QueryText=kquery.strip(),
AttributeFilter={
"AndAllFilters": [
{
"EqualsTo": {
"Key": "_language_code",
"Value": {
"StringValue": self.languagecode,
},
}
}
]
},
)
if len(response["ResultItems"]) > self.k:
r_count = self.k
else:
r_count = len(response["ResultItems"])
return [self._get_top_n_results(response, i) for i in range(0, r_count)]
def get_relevant_documents(self, query: str) -> List[Document]:
"""Run search on Kendra index and get top k documents
docs = get_relevant_documents('This is my query')
"""
return self._kendra_query(query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("AwsKendraIndexRetriever does not support async")


@@ -7,6 +7,8 @@ from langchain.schema import BaseRetriever, Document
class DataberryRetriever(BaseRetriever):
"""Retriever that uses the Databerry API."""
datastore_url: str
top_k: Optional[int]
api_key: Optional[str]


@@ -10,6 +10,8 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
class SearchType(str, Enum):
"""Enumerator of the types of search to perform."""
similarity = "similarity"
mmr = "mmr"


@@ -0,0 +1,272 @@
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Extra
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
def clean_excerpt(excerpt: str) -> str:
if not excerpt:
return excerpt
res = re.sub("\s+", " ", excerpt).replace("...", "")
return res
def combined_text(title: str, excerpt: str) -> str:
if not title or not excerpt:
return ""
return f"Document Title: {title} \nDocument Excerpt: \n{excerpt}\n"
class Highlight(BaseModel, extra=Extra.allow):
BeginOffset: int
EndOffset: int
TopAnswer: Optional[bool]
Type: Optional[str]
class TextWithHighLights(BaseModel, extra=Extra.allow):
Text: str
Highlights: Optional[Any]
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
Key: str
ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
Value: Optional[TextWithHighLights]
def get_value_text(self) -> str:
if not self.Value:
return ""
else:
return self.Value.Text
class QueryResultItem(BaseModel, extra=Extra.allow):
DocumentId: str
DocumentTitle: TextWithHighLights
DocumentURI: Optional[str]
FeedbackToken: Optional[str]
Format: Optional[str]
Id: Optional[str]
Type: Optional[str]
AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
DocumentExcerpt: Optional[TextWithHighLights]
def get_attribute_value(self) -> str:
if not self.AdditionalAttributes:
return ""
if not self.AdditionalAttributes[0]:
return ""
else:
return self.AdditionalAttributes[0].get_value_text()
def get_excerpt(self) -> str:
if (
self.AdditionalAttributes
and self.AdditionalAttributes[0].Key == "AnswerText"
):
excerpt = self.get_attribute_value()
elif self.DocumentExcerpt:
excerpt = self.DocumentExcerpt.Text
else:
excerpt = ""
return clean_excerpt(excerpt)
def to_doc(self) -> Document:
title = self.DocumentTitle.Text
source = self.DocumentURI
excerpt = self.get_excerpt()
type = self.Type
page_content = combined_text(title, excerpt)
metadata = {"source": source, "title": title, "excerpt": excerpt, "type": type}
return Document(page_content=page_content, metadata=metadata)
class QueryResult(BaseModel, extra=Extra.allow):
ResultItems: List[QueryResultItem]
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class DocumentAttributeValue(BaseModel, extra=Extra.allow):
DateValue: Optional[str]
LongValue: Optional[int]
StringListValue: Optional[List[str]]
StringValue: Optional[str]
class DocumentAttribute(BaseModel, extra=Extra.allow):
Key: str
Value: DocumentAttributeValue
class RetrieveResultItem(BaseModel, extra=Extra.allow):
Content: Optional[str]
DocumentAttributes: Optional[List[DocumentAttribute]] = []
DocumentId: Optional[str]
DocumentTitle: Optional[str]
DocumentURI: Optional[str]
Id: Optional[str]
def get_excerpt(self) -> str:
if not self.Content:
return ""
return clean_excerpt(self.Content)
def to_doc(self) -> Document:
title = self.DocumentTitle if self.DocumentTitle else ""
source = self.DocumentURI
excerpt = self.get_excerpt()
page_content = combined_text(title, excerpt)
metadata = {"source": source, "title": title, "excerpt": excerpt}
return Document(page_content=page_content, metadata=metadata)
class RetrieveResult(BaseModel, extra=Extra.allow):
QueryId: str
ResultItems: List[RetrieveResultItem]
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class AmazonKendraRetriever(BaseRetriever):
"""Retriever class to query documents from Amazon Kendra Index.
Args:
index_id: Kendra index id
region_name: The aws region e.g., `us-west-2`.
Falls back to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
top_k: Number of results to return
attribute_filter: Additional filtering of results based on metadata
See: https://docs.aws.amazon.com/kendra/latest/APIReference
client: boto3 client for Kendra
Example:
.. code-block:: python
retriever = AmazonKendraRetriever(
index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
)
"""
def __init__(
self,
index_id: str,
region_name: Optional[str] = None,
credentials_profile_name: Optional[str] = None,
top_k: int = 3,
attribute_filter: Optional[Dict] = None,
client: Optional[Any] = None,
):
self.index_id = index_id
self.top_k = top_k
self.attribute_filter = attribute_filter
if client is not None:
self.client = client
return
try:
import boto3
if credentials_profile_name is not None:
session = boto3.Session(profile_name=credentials_profile_name)
else:
# use default credentials
session = boto3.Session()
client_params = {}
if region_name is not None:
client_params["region_name"] = region_name
self.client = session.client("kendra", **client_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
def _kendra_query(
self,
query: str,
top_k: int,
attribute_filter: Optional[Dict] = None,
) -> List[Document]:
if attribute_filter is not None:
response = self.client.retrieve(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.retrieve(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
r_result = RetrieveResult.parse_obj(response)
result_len = len(r_result.ResultItems)
if result_len == 0:
# retrieve API returned 0 results, call query API
if attribute_filter is not None:
response = self.client.query(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.query(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
q_result = QueryResult.parse_obj(response)
docs = q_result.get_top_k_docs(top_k)
else:
docs = r_result.get_top_k_docs(top_k)
return docs
def get_relevant_documents(self, query: str) -> List[Document]:
"""Run search on Kendra index and get top k documents
Example:
.. code-block:: python
docs = retriever.get_relevant_documents('This is my query')
"""
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")
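
A minimal usage sketch matching the class docstring above; the index id is the placeholder from that docstring, and valid AWS credentials are assumed:

```python
from langchain.retrievers import AmazonKendraRetriever

retriever = AmazonKendraRetriever(
    index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03",  # placeholder Kendra index id
    top_k=3,
)
docs = retriever.get_relevant_documents("What is the leave policy?")
```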


@@ -15,11 +15,23 @@ from langchain.schema import BaseRetriever, Document
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
"""
Create an index of embeddings for a list of contexts.
Args:
contexts: List of contexts to embed.
embeddings: Embeddings model to use.
Returns:
Index of embeddings.
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class KNNRetriever(BaseRetriever, BaseModel):
"""KNN Retriever."""
embeddings: Embeddings
index: Any
texts: List[str]
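
For reference, a sketch of building this retriever; it assumes the `from_texts` constructor (which calls `create_index` internally) and an OpenAI API key for the embeddings:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import KNNRetriever

texts = ["foo", "bar", "world", "hello", "foo bar"]
retriever = KNNRetriever.from_texts(texts, OpenAIEmbeddings())  # assumed from_texts helper
docs = retriever.get_relevant_documents("foo")
```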


@@ -4,6 +4,8 @@ from langchain.schema import BaseRetriever, Document
class MetalRetriever(BaseRetriever):
"""Retriever that uses the Metal API."""
def __init__(self, client: Any, params: Optional[dict] = None):
from metal_sdk.metal import Metal


@@ -10,6 +10,8 @@ from langchain.vectorstores.milvus import Milvus
class MilvusRetriever(BaseRetriever):
"""Retriever that uses the Milvus API."""
def __init__(
self,
embedding_function: Embeddings,
@@ -45,6 +47,15 @@ class MilvusRetriever(BaseRetriever):
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
"""Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.
Args:
*args:
**kwargs:
Returns:
MilvusRetriever
"""
warnings.warn(
"MilvusRetreiver will be deprecated in the future. "
"Please use MilvusRetriever ('i' before 'e') instead.",

Some files were not shown because too many files have changed in this diff.