mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-05 16:50:03 +00:00
Compare commits
29 Commits
fork/async
...
jacob/chat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f90e665413 | ||
|
|
fad076fa06 | ||
|
|
59ffccf27d | ||
|
|
ac85fca6f0 | ||
|
|
f29ad020a0 | ||
|
|
b67561890b | ||
|
|
b970bfe8da | ||
|
|
36e432672a | ||
|
|
38425c99d2 | ||
|
|
3c387bc12d | ||
|
|
481493dbce | ||
|
|
f01fb47597 | ||
|
|
508bde7f40 | ||
|
|
5e73603e8a | ||
|
|
3e87b67a3c | ||
|
|
c314137f5b | ||
|
|
27665e3546 | ||
|
|
5975bf39ec | ||
|
|
4915c3cd86 | ||
|
|
e86fd946c8 | ||
|
|
0bc397957b | ||
|
|
52ccae3fb1 | ||
|
|
570b4f8e66 | ||
|
|
4e189cd89a | ||
|
|
a936472512 | ||
|
|
6543e585a5 | ||
|
|
6a75ef74ca | ||
|
|
70ff54eace | ||
|
|
5b5115c408 |
2
.github/CONTRIBUTING.md
vendored
2
.github/CONTRIBUTING.md
vendored
@@ -13,7 +13,7 @@ There are many ways to contribute to LangChain. Here are some common ways people
|
||||
|
||||
- [**Documentation**](https://python.langchain.com/docs/contributing/documentation): Help improve our docs, including this one!
|
||||
- [**Code**](https://python.langchain.com/docs/contributing/code): Help us write code, fix bugs, or improve our infrastructure.
|
||||
- [**Integrations**](https://python.langchain.com/docs/contributing/integration): Help us integrate with your favorite vendors and tools.
|
||||
- [**Integrations**](https://python.langchain.com/docs/contributing/integrations): Help us integrate with your favorite vendors and tools.
|
||||
|
||||
### 🚩GitHub Issues
|
||||
|
||||
|
||||
13
.github/workflows/langchain_cli_release.yml
vendored
13
.github/workflows/langchain_cli_release.yml
vendored
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: libs/cli Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/cli
|
||||
secrets: inherit
|
||||
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: libs/community Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/community
|
||||
secrets: inherit
|
||||
13
.github/workflows/langchain_core_release.yml
vendored
13
.github/workflows/langchain_core_release.yml
vendored
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: libs/core Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/core
|
||||
secrets: inherit
|
||||
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: libs/experimental Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/experimental
|
||||
secrets: inherit
|
||||
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: Experimental Test Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_test_release.yml
|
||||
with:
|
||||
working-directory: libs/experimental
|
||||
secrets: inherit
|
||||
13
.github/workflows/langchain_openai_release.yml
vendored
13
.github/workflows/langchain_openai_release.yml
vendored
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: libs/core Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/core
|
||||
secrets: inherit
|
||||
27
.github/workflows/langchain_release.yml
vendored
27
.github/workflows/langchain_release.yml
vendored
@@ -1,27 +0,0 @@
|
||||
---
|
||||
name: libs/langchain Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_release.yml
|
||||
with:
|
||||
working-directory: libs/langchain
|
||||
secrets: inherit
|
||||
|
||||
# N.B.: It's possible that PyPI doesn't make the new release visible / available
|
||||
# immediately after publishing. If that happens, the docker build might not
|
||||
# create a new docker image for the new release, since it won't see it.
|
||||
#
|
||||
# If this ends up being a problem, add a check to the end of the `_release.yml`
|
||||
# workflow that prevents the workflow from finishing until the new release
|
||||
# is visible and installable on PyPI.
|
||||
release-docker:
|
||||
needs:
|
||||
- release
|
||||
uses:
|
||||
./.github/workflows/langchain_release_docker.yml
|
||||
secrets: inherit
|
||||
13
.github/workflows/langchain_test_release.yml
vendored
13
.github/workflows/langchain_test_release.yml
vendored
@@ -1,13 +0,0 @@
|
||||
---
|
||||
name: Test Release
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses:
|
||||
./.github/workflows/_test_release.yml
|
||||
with:
|
||||
working-directory: libs/langchain
|
||||
secrets: inherit
|
||||
13
docs/docs/integrations/providers/baichuan.mdx
Normal file
13
docs/docs/integrations/providers/baichuan.mdx
Normal file
@@ -0,0 +1,13 @@
|
||||
# Baichuan
|
||||
|
||||
>[Baichuan Inc.](https://www.baichuan-ai.com/) is a Chinese startup in the era of AGI, dedicated to addressing fundamental human needs: Efficiency, Health, and Happiness.
|
||||
|
||||
## Visit Us
|
||||
Visit us at https://www.baichuan-ai.com/.
|
||||
Register and get an API key if you are trying out our APIs.
|
||||
|
||||
## Baichuan Chat Model
|
||||
An example is available at [example](/docs/integrations/chat/baichuan).
|
||||
|
||||
## Baichuan Text Embedding Model
|
||||
An example is available at [example] (/docs/integrations/text_embedding/baichuan)
|
||||
@@ -1,45 +1,52 @@
|
||||
# DeepInfra
|
||||
|
||||
This page covers how to use the DeepInfra ecosystem within LangChain.
|
||||
>[DeepInfra](https://deepinfra.com/docs) allows us to run the
|
||||
> [latest machine learning models](https://deepinfra.com/models) with ease.
|
||||
> DeepInfra takes care of all the heavy lifting related to running, scaling and monitoring
|
||||
> the models. Users can focus on your application and integrate the models with simple REST API calls.
|
||||
|
||||
>DeepInfra provides [examples](https://deepinfra.com/docs/advanced/langchain) of integration with LangChain.
|
||||
|
||||
This page covers how to use the `DeepInfra` ecosystem within `LangChain`.
|
||||
It is broken into two parts: installation and setup, and then references to specific DeepInfra wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Get your DeepInfra api key from this link [here](https://deepinfra.com/).
|
||||
- Get an DeepInfra api key and set it as an environment variable (`DEEPINFRA_API_TOKEN`)
|
||||
|
||||
## Available Models
|
||||
|
||||
DeepInfra provides a range of Open Source LLMs ready for deployment.
|
||||
You can list supported models for
|
||||
|
||||
You can see supported models for
|
||||
[text-generation](https://deepinfra.com/models?type=text-generation) and
|
||||
[embeddings](https://deepinfra.com/models?type=embeddings).
|
||||
google/flan\* models can be viewed [here](https://deepinfra.com/models?type=text2text-generation).
|
||||
|
||||
You can view a [list of request and response parameters](https://deepinfra.com/meta-llama/Llama-2-70b-chat-hf/api).
|
||||
|
||||
Chat models [follow openai api](https://deepinfra.com/meta-llama/Llama-2-70b-chat-hf/api?example=openai-http)
|
||||
|
||||
## Wrappers
|
||||
|
||||
### LLM
|
||||
## LLM
|
||||
|
||||
There exists an DeepInfra LLM wrapper, which you can access with
|
||||
See a [usage example](/docs/integrations/llms/deepinfra).
|
||||
|
||||
```python
|
||||
from langchain_community.llms import DeepInfra
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
## Embeddings
|
||||
|
||||
There is also an DeepInfra Embeddings wrapper, you can access with
|
||||
See a [usage example](/docs/integrations/text_embedding/deepinfra).
|
||||
|
||||
```python
|
||||
from langchain_community.embeddings import DeepInfraEmbeddings
|
||||
```
|
||||
|
||||
### Chat Models
|
||||
## Chat Models
|
||||
|
||||
There is a chat-oriented wrapper as well, accessible with
|
||||
See a [usage example](/docs/integrations/chat/deepinfra).
|
||||
|
||||
```python
|
||||
from langchain_community.chat_models import ChatDeepInfra
|
||||
|
||||
75
docs/docs/integrations/text_embedding/baichuan.ipynb
Normal file
75
docs/docs/integrations/text_embedding/baichuan.ipynb
Normal file
@@ -0,0 +1,75 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Baichuan Text Embeddings\n",
|
||||
"\n",
|
||||
"As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB (Chinese Multi-Task Embedding Benchmark) leaderboard.\n",
|
||||
"\n",
|
||||
"Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard\n",
|
||||
"\n",
|
||||
"Official Website: https://platform.baichuan-ai.com/docs/text-Embedding\n",
|
||||
"An API-key is required to use this embedding model. You can get one by registering at https://platform.baichuan-ai.com/docs/text-Embedding.\n",
|
||||
"BaichuanTextEmbeddings support 512 token window and preduces vectors with 1024 dimensions. \n",
|
||||
"\n",
|
||||
"Please NOTE that BaichuanTextEmbeddings only supports Chinese text embedding. Multi-language support is coming soon.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.embeddings import BaichuanTextEmbeddings\n",
|
||||
"\n",
|
||||
"# Place your Baichuan API-key here.\n",
|
||||
"embeddings = BaichuanTextEmbeddings(baichuan_api_key=\"sk-*\")\n",
|
||||
"\n",
|
||||
"text_1 = \"今天天气不错\"\n",
|
||||
"text_2 = \"今天阳光很好\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = embeddings.embed_query(text_1)\n",
|
||||
"query_result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_result = embeddings.embed_documents([text_1, text_2])\n",
|
||||
"doc_result"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
160
docs/docs/integrations/tools/ionic.ipynb
Normal file
160
docs/docs/integrations/tools/ionic.ipynb
Normal file
@@ -0,0 +1,160 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Ionic\n",
|
||||
">[Ionic](https://www.ioniccommerce.com/) stands at the forefront of commerce innovation, offering a suite of APIs that serve as the backbone for AI assistants and their developers. With Ionic, you unlock a new realm of possibility where convenience and intelligence converge, enabling users to navigate purchases with unprecedented ease. Experience the synergy of human desire and AI capability, all through Ionic's seamless integration.\n",
|
||||
"\n",
|
||||
"By including an `IonicTool` in the list of tools provided to an Agent, you are effortlessly adding e-commerce capabilities to your Agent. For more documetation on setting up your Agent with Ionic, see the [Ionic documentation](https://docs.ioniccommerce.com/guides/langchain).\n",
|
||||
"\n",
|
||||
"This Jupyter Notebook demonstrates how to use the `Ionic` tool with an Agent.\n",
|
||||
"\n",
|
||||
"First, let's install the `ionic-langchain` package.\n",
|
||||
"**The `ionic-langchain` package is maintained by the Ionic team, not the LangChain maintainers.**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "shellscript"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install ionic-langchain > /dev/null"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, let's create an `IonicTool` instance and initialize an Agent with the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-01-24T17:33:11.755683Z",
|
||||
"start_time": "2024-01-24T17:33:11.174044Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from ionic_langchain.tool import Ionic, IonicTool\n",
|
||||
"from langchain import hub\n",
|
||||
"from langchain.agents import AgentExecutor, Tool, create_react_agent\n",
|
||||
"from langchain_openai import OpenAI\n",
|
||||
"\n",
|
||||
"open_ai_key = os.environ[\"OPENAI_API_KEY\"]\n",
|
||||
"\n",
|
||||
"llm = OpenAI(openai_api_key=open_ai_key, temperature=0.5)\n",
|
||||
"\n",
|
||||
"tools: list[Tool] = [IonicTool().tool()]\n",
|
||||
"\n",
|
||||
"prompt = hub.pull(\"hwchase17/react\") # the example prompt for create_react_agent\n",
|
||||
"\n",
|
||||
"agent = create_react_agent(\n",
|
||||
" llm,\n",
|
||||
" tools,\n",
|
||||
" prompt=prompt,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we can use the Agent to shop for products and get product information from Ionic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-01-24T17:34:31.257036Z",
|
||||
"start_time": "2024-01-24T17:33:45.849440Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
|
||||
"\u001B[32;1m\u001B[1;3m Since the user is looking for a specific product, we should use Ionic Commerce Shopping Tool to find and compare products.\n",
|
||||
"Action: Ionic Commerce Shopping Tool\n",
|
||||
"Action Input: 4K Monitor, 5, 100000, 1000000\u001B[0m\u001B[36;1m\u001B[1;3m[{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2a4d88fb366f5880b41eef03.png?odnHeight=100&odnWidth=100&odnBg=ffffff', 'brand_name': 'ASUS', 'upc': '192876749388'}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BHXNL922?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BHXNL922', 'name': 'LG Ultrafine™ OLED Monitor (27EQ850) – 27 inch 4K UHD (3840 x 2160) OLED Pro Display with Adobe RGB 99%, DCI-P3 99%, 1M:1 Contrast Ratio, Hardware Calibration, Multi-Interface, USB Type-C™ (PD 90W)', 'price': '$1796.99', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41VEl4V2U4L._SL160_.jpg', 'brand_name': 'LG', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BZR81SQG?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BZR81SQG', 'name': 'ASUS ROG Swift 38” 4K HDMI 2.1 HDR DSC Gaming Monitor (PG38UQ) - UHD (3840 x 2160), 144Hz, 1ms, Fast IPS, G-SYNC Compatible, Speakers, FreeSync Premium Pro, DisplayPort, DisplayHDR600, 98% DCI-P3', 'price': '$1001.42', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41ULH0sb1zL._SL160_.jpg', 'brand_name': 'ASUS', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BBSV1LK5?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BBSV1LK5', 'name': 'ASUS ROG Swift 41.5\" 4K OLED 138Hz 0.1ms Gaming Monitor PG42UQ', 'price': '$1367.09', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/51ZM41brvHL._SL160_.jpg', 'brand_name': 'ASUS', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B07K8877Y5?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B07K8877Y5', 'name': 'LG 32UL950-W 32\" Class Ultrafine 4K UHD LED Monitor with Thunderbolt 3 Connectivity Silver (31.5\" Display)', 'price': '$1149.33', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41Q2OE2NnDL._SL160_.jpg', 'brand_name': 'LG', 'upc': None}], 'query': {'query': '4K Monitor', 'max_price': 1000000, 'min_price': 100000, 'num_results': 5}}]\u001B[0m\u001B[32;1m\u001B[1;3m Since the results are in cents, we should convert them back to dollars before displaying the results to the user.\n",
|
||||
"Action: Convert prices to dollars\n",
|
||||
"Action Input: [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197\u001B[0mConvert prices to dollars is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m The results are in a list format, we should display them to the user in a more readable format.\n",
|
||||
"Action: Display results in readable format\n",
|
||||
"Action Input: [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197\u001B[0mDisplay results in readable format is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m We should check if the user is satisfied with the results or if they have additional requirements.\n",
|
||||
"Action: Check user satisfaction\n",
|
||||
"Action Input: None\u001B[0mCheck user satisfaction is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m I now know the final answer\n",
|
||||
"Final Answer: The final answer is [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2\u001B[0m\n",
|
||||
"\n",
|
||||
"\u001B[1m> Finished chain.\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "{'input': \"I'm looking for a new 4K Monitor with 1000R under $1000\",\n 'output': \"The final answer is [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2\"}"
|
||||
},
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"input = \"I'm looking for a new 4K Monitor under $1000\"\n",
|
||||
"\n",
|
||||
"agent_executor.invoke({\"input\": input})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "f85209c3c4c190dca7367d6a1e623da50a9a4392fd53313a7cf9d4bda9c4b85b"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -167,9 +167,9 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"URL = 'https://www.conseil-constitutionnel.fr/node/3850/pdf'\n",
|
||||
"PDF = 'Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf'\n",
|
||||
"open(PDF, 'wb').write(requests.get(URL).content)"
|
||||
"URL = \"https://www.conseil-constitutionnel.fr/node/3850/pdf\"\n",
|
||||
"PDF = \"Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf\"\n",
|
||||
"open(PDF, \"wb\").write(requests.get(URL).content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -208,7 +208,7 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"print('Read a PDF...')\n",
|
||||
"print(\"Read a PDF...\")\n",
|
||||
"loader = PyPDFLoader(PDF)\n",
|
||||
"pages = loader.load_and_split()\n",
|
||||
"len(pages)"
|
||||
@@ -252,12 +252,14 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"print('Create a Vector Database from PDF text...')\n",
|
||||
"embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')\n",
|
||||
"print(\"Create a Vector Database from PDF text...\")\n",
|
||||
"embeddings = OpenAIEmbeddings(model=\"text-embedding-ada-002\")\n",
|
||||
"texts = [p.page_content for p in pages]\n",
|
||||
"metadata = pd.DataFrame(index=list(range(len(texts))))\n",
|
||||
"metadata['tag'] = 'law'\n",
|
||||
"metadata['title'] = 'Déclaration des Droits de l\\'Homme et du Citoyen de 1789'.encode('utf-8')\n",
|
||||
"metadata[\"tag\"] = \"law\"\n",
|
||||
"metadata[\"title\"] = \"Déclaration des Droits de l'Homme et du Citoyen de 1789\".encode(\n",
|
||||
" \"utf-8\"\n",
|
||||
")\n",
|
||||
"vectordb = KDBAI(table, embeddings)\n",
|
||||
"vectordb.add_texts(texts=texts, metadatas=metadata)"
|
||||
]
|
||||
@@ -288,11 +290,13 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"print('Create LangChain Pipeline...')\n",
|
||||
"qabot = RetrievalQA.from_chain_type(chain_type='stuff',\n",
|
||||
" llm=ChatOpenAI(model='gpt-3.5-turbo-16k', temperature=TEMP), \n",
|
||||
" retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n",
|
||||
" return_source_documents=True)"
|
||||
"print(\"Create LangChain Pipeline...\")\n",
|
||||
"qabot = RetrievalQA.from_chain_type(\n",
|
||||
" chain_type=\"stuff\",\n",
|
||||
" llm=ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=TEMP),\n",
|
||||
" retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n",
|
||||
" return_source_documents=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -325,9 +329,9 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"Q = 'Summarize the document in English:'\n",
|
||||
"print(f'\\n\\n{Q}\\n')\n",
|
||||
"print(qabot.invoke(dict(query=Q))['result'])"
|
||||
"Q = \"Summarize the document in English:\"\n",
|
||||
"print(f\"\\n\\n{Q}\\n\")\n",
|
||||
"print(qabot.invoke(dict(query=Q))[\"result\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -362,9 +366,9 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"Q = 'Is it a fair law and why ?'\n",
|
||||
"print(f'\\n\\n{Q}\\n')\n",
|
||||
"print(qabot.invoke(dict(query=Q))['result'])"
|
||||
"Q = \"Is it a fair law and why ?\"\n",
|
||||
"print(f\"\\n\\n{Q}\\n\")\n",
|
||||
"print(qabot.invoke(dict(query=Q))[\"result\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -414,9 +418,9 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"Q = 'What are the rights and duties of the man, the citizen and the society ?'\n",
|
||||
"print(f'\\n\\n{Q}\\n')\n",
|
||||
"print(qabot.invoke(dict(query=Q))['result'])"
|
||||
"Q = \"What are the rights and duties of the man, the citizen and the society ?\"\n",
|
||||
"print(f\"\\n\\n{Q}\\n\")\n",
|
||||
"print(qabot.invoke(dict(query=Q))[\"result\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -441,9 +445,9 @@
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"Q = 'Is this law practical ?'\n",
|
||||
"print(f'\\n\\n{Q}\\n')\n",
|
||||
"print(qabot.invoke(dict(query=Q))['result'])"
|
||||
"Q = \"Is this law practical ?\"\n",
|
||||
"print(f\"\\n\\n{Q}\\n\")\n",
|
||||
"print(qabot.invoke(dict(query=Q))[\"result\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
"\n",
|
||||
"Newer OpenAI models have been fine-tuned to detect when **one or more** function(s) should be called and respond with the inputs that should be passed to the function(s). In an API call, you can describe functions and have the model intelligently choose to output a JSON object containing arguments to call these functions. The goal of the OpenAI tools APIs is to more reliably return valid and useful function calls than what can be done using a generic text completion or chat API.\n",
|
||||
"\n",
|
||||
"OpenAI termed the capability to invoke a **single** function as **functions**, and the capability to invoke **one or more** funcitons as **tools**.\n",
|
||||
"OpenAI termed the capability to invoke a **single** function as **functions**, and the capability to invoke **one or more** functions as **tools**.\n",
|
||||
"\n",
|
||||
":::tip\n",
|
||||
"\n",
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
"\n",
|
||||
"* Use with regular LLMs, not with chat models.\n",
|
||||
"* Use only with unstructured tools; i.e., tools that accept a single string input.\n",
|
||||
"* See [AgentTypes](/docs/moduels/agents/agent_types/) documentation for more agent types.\n",
|
||||
"* See [AgentTypes](/docs/modules/agents/agent_types/) documentation for more agent types.\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -430,7 +430,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"id": "64650362",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -452,8 +452,8 @@
|
||||
"from langchain.prompts import (\n",
|
||||
" PromptTemplate,\n",
|
||||
")\n",
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
|
||||
"from langchain_openai import OpenAI\n",
|
||||
"from pydantic import BaseModel, Field, validator\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Person(BaseModel):\n",
|
||||
@@ -531,8 +531,8 @@
|
||||
"from langchain.prompts import (\n",
|
||||
" PromptTemplate,\n",
|
||||
")\n",
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
|
||||
"from langchain_openai import OpenAI\n",
|
||||
"from pydantic import BaseModel, Field, validator\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define your desired data structure.\n",
|
||||
|
||||
@@ -94,15 +94,19 @@ def analyze_text(
|
||||
files serialized to HTML string.
|
||||
"""
|
||||
resp: Dict[str, Any] = {}
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
text_complexity_metrics = {
|
||||
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
||||
}
|
||||
resp.update({"text_complexity_metrics": text_complexity_metrics})
|
||||
resp.update(text_complexity_metrics)
|
||||
try:
|
||||
textstat = import_textstat()
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
text_complexity_metrics = {
|
||||
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
||||
}
|
||||
resp.update({"text_complexity_metrics": text_complexity_metrics})
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if nlp is not None:
|
||||
spacy = import_spacy()
|
||||
doc = nlp(text)
|
||||
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
@@ -279,9 +283,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
import_mlflow()
|
||||
spacy = import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
@@ -303,14 +305,19 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
)
|
||||
|
||||
self.action_records: list = []
|
||||
self.nlp = None
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
except OSError:
|
||||
logger.warning(
|
||||
"Run `python -m spacy download en_core_web_sm` "
|
||||
"to download en_core_web_sm model for text visualization."
|
||||
)
|
||||
self.nlp = None
|
||||
spacy = import_spacy()
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
except OSError:
|
||||
logger.warning(
|
||||
"Run `python -m spacy download en_core_web_sm` "
|
||||
"to download en_core_web_sm model for text visualization."
|
||||
)
|
||||
|
||||
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
|
||||
|
||||
|
||||
@@ -142,9 +142,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
stream_resp = self.client.completions.create(**params, stream=True)
|
||||
for data in stream_resp:
|
||||
delta = data.completion
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta)
|
||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@@ -161,9 +162,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
stream_resp = await self.async_client.completions.create(**params, stream=True)
|
||||
async for data in stream_resp:
|
||||
delta = data.completion
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta)
|
||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
||||
@@ -59,6 +59,7 @@ from langchain_community.document_loaders.blob_loaders import (
|
||||
from langchain_community.document_loaders.blockchain import BlockchainDocumentLoader
|
||||
from langchain_community.document_loaders.brave_search import BraveSearchLoader
|
||||
from langchain_community.document_loaders.browserless import BrowserlessLoader
|
||||
from langchain_community.document_loaders.cassandra import CassandraLoader
|
||||
from langchain_community.document_loaders.chatgpt import ChatGPTLoader
|
||||
from langchain_community.document_loaders.chromium import AsyncChromiumLoader
|
||||
from langchain_community.document_loaders.college_confidential import (
|
||||
@@ -267,6 +268,7 @@ __all__ = [
|
||||
"BlockchainDocumentLoader",
|
||||
"BraveSearchLoader",
|
||||
"BrowserlessLoader",
|
||||
"CassandraLoader",
|
||||
"CSVLoader",
|
||||
"ChatGPTLoader",
|
||||
"CoNLLULoader",
|
||||
|
||||
@@ -2,12 +2,24 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -19,7 +31,8 @@ class AstraDBLoader(BaseLoader):
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional["AstraDB"] = None,
|
||||
async_astra_db_client: Optional["AsyncAstraDB"] = None,
|
||||
namespace: Optional[str] = None,
|
||||
filter_criteria: Optional[Dict[str, Any]] = None,
|
||||
projection: Optional[Dict[str, Any]] = None,
|
||||
@@ -36,34 +49,60 @@ class AstraDBLoader(BaseLoader):
|
||||
)
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if astra_db_client is not None or async_astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
|
||||
"AstraDB if passing 'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.filter = filter_criteria
|
||||
self.projection = projection
|
||||
self.find_options = find_options or {}
|
||||
self.nb_prefetched = nb_prefetched
|
||||
self.extraction_function = extraction_function
|
||||
|
||||
if astra_db_client is not None:
|
||||
astra_db = astra_db_client
|
||||
else:
|
||||
astra_db = astra_db_client
|
||||
async_astra_db = async_astra_db_client
|
||||
|
||||
if token and api_endpoint:
|
||||
astra_db = AstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.collection = astra_db.collection(collection_name)
|
||||
try:
|
||||
from astrapy.db import AsyncAstraDB
|
||||
|
||||
async_astra_db = AsyncAstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=namespace,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
if not astra_db and not async_astra_db:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' "
|
||||
"and 'api_endpoint'"
|
||||
)
|
||||
self.collection = astra_db.collection(collection_name) if astra_db else None
|
||||
if async_astra_db:
|
||||
from astrapy.db import AsyncAstraDBCollection
|
||||
|
||||
self.async_collection = AsyncAstraDBCollection(
|
||||
astra_db=async_astra_db, collection_name=collection_name
|
||||
)
|
||||
else:
|
||||
self.async_collection = None
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Eagerly load the content."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
if not self.collection:
|
||||
raise ValueError("Missing AstraDB client")
|
||||
queue = Queue(self.nb_prefetched)
|
||||
t = threading.Thread(target=self.fetch_results, args=(queue,))
|
||||
t.start()
|
||||
@@ -74,6 +113,29 @@ class AstraDBLoader(BaseLoader):
|
||||
yield doc
|
||||
t.join()
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return [doc async for doc in self.alazy_load()]
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
if not self.async_collection:
|
||||
raise ValueError("Missing AsyncAstraDB client")
|
||||
async for doc in self.async_collection.paginated_find(
|
||||
filter=self.filter,
|
||||
options=self.find_options,
|
||||
projection=self.projection,
|
||||
sort=None,
|
||||
prefetched=True,
|
||||
):
|
||||
yield Document(
|
||||
page_content=self.extraction_function(doc),
|
||||
metadata={
|
||||
"namespace": self.async_collection.astra_db.namespace,
|
||||
"api_endpoint": self.async_collection.astra_db.base_url,
|
||||
"collection": self.collection_name,
|
||||
},
|
||||
)
|
||||
|
||||
def fetch_results(self, queue: Queue):
|
||||
self.fetch_page_result(queue)
|
||||
while self.find_options.get("pageState"):
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
|
||||
@@ -53,22 +52,14 @@ class BaseLoader(ABC):
|
||||
|
||||
# Attention: This method will be upgraded into an abstractmethod once it's
|
||||
# implemented in all the existing subclasses.
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
def lazy_load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""A lazy loader for Documents."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement lazy_load()"
|
||||
)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""A lazy loader for Documents."""
|
||||
iterator = await run_in_executor(None, self.lazy_load)
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc
|
||||
|
||||
|
||||
class BaseBlobParser(ABC):
|
||||
"""Abstract interface for blob parsers.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -14,13 +13,6 @@ from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def default_page_content_mapper(row: Any) -> str:
|
||||
if hasattr(row, "_asdict"):
|
||||
return json.dumps(row._asdict())
|
||||
return json.dumps(row)
|
||||
|
||||
|
||||
_NOT_SET = object()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -36,7 +28,7 @@ class CassandraLoader(BaseLoader):
|
||||
session: Optional["Session"] = None,
|
||||
keyspace: Optional[str] = None,
|
||||
query: Optional[Union[str, "Statement"]] = None,
|
||||
page_content_mapper: Callable[[Any], str] = default_page_content_mapper,
|
||||
page_content_mapper: Callable[[Any], str] = str,
|
||||
metadata_mapper: Callable[[Any], dict] = lambda _: {},
|
||||
*,
|
||||
query_parameters: Union[dict, Sequence] = None,
|
||||
@@ -61,6 +53,7 @@ class CassandraLoader(BaseLoader):
|
||||
query: The query used to load the data.
|
||||
(do not use together with the table parameter)
|
||||
page_content_mapper: a function to convert a row to string page content.
|
||||
Defaults to the str representation of the row.
|
||||
query_parameters: The query parameters used when calling session.execute .
|
||||
query_timeout: The query timeout used when calling session.execute .
|
||||
query_custom_payload: The query custom_payload used when calling
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import AsyncIterator, Iterator, List
|
||||
from typing import Iterator, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -26,9 +26,3 @@ class MergedDataLoader(BaseLoader):
|
||||
def load(self) -> List[Document]:
|
||||
"""Load docs."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Lazy load docs from each individual loader."""
|
||||
for loader in self.loaders:
|
||||
async for document in loader.alazy_load():
|
||||
yield document
|
||||
|
||||
@@ -132,6 +132,7 @@ class WebBaseLoader(BaseLoader):
|
||||
url,
|
||||
headers=self.session.headers,
|
||||
ssl=None if self.session.verify else False,
|
||||
cookies=self.session.cookies.get_dict(),
|
||||
) as response:
|
||||
return await response.text()
|
||||
except aiohttp.ClientConnectionError as e:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
@@ -139,6 +140,11 @@ def _parse_video_id(url: str) -> Optional[str]:
|
||||
return video_id
|
||||
|
||||
|
||||
class TranscriptFormat(Enum):
|
||||
TEXT = "text"
|
||||
LINES = "lines"
|
||||
|
||||
|
||||
class YoutubeLoader(BaseLoader):
|
||||
"""Load `YouTube` transcripts."""
|
||||
|
||||
@@ -148,6 +154,7 @@ class YoutubeLoader(BaseLoader):
|
||||
add_video_info: bool = False,
|
||||
language: Union[str, Sequence[str]] = "en",
|
||||
translation: Optional[str] = None,
|
||||
transcript_format: TranscriptFormat = TranscriptFormat.TEXT,
|
||||
continue_on_failure: bool = False,
|
||||
):
|
||||
"""Initialize with YouTube video ID."""
|
||||
@@ -159,6 +166,7 @@ class YoutubeLoader(BaseLoader):
|
||||
else:
|
||||
self.language = language
|
||||
self.translation = translation
|
||||
self.transcript_format = transcript_format
|
||||
self.continue_on_failure = continue_on_failure
|
||||
|
||||
@staticmethod
|
||||
@@ -214,9 +222,19 @@ class YoutubeLoader(BaseLoader):
|
||||
|
||||
transcript_pieces = transcript.fetch()
|
||||
|
||||
transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
|
||||
|
||||
return [Document(page_content=transcript, metadata=metadata)]
|
||||
if self.transcript_format == TranscriptFormat.TEXT:
|
||||
transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
|
||||
return [Document(page_content=transcript, metadata=metadata)]
|
||||
elif self.transcript_format == TranscriptFormat.LINES:
|
||||
return [
|
||||
Document(
|
||||
page_content=t["text"].strip(" "),
|
||||
metadata=dict((key, t[key]) for key in t if key != "text"),
|
||||
)
|
||||
for t in transcript_pieces
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unknown transcript format.")
|
||||
|
||||
def _get_video_info(self) -> dict:
|
||||
"""Get important video information.
|
||||
|
||||
@@ -20,6 +20,7 @@ from langchain_community.embeddings.aleph_alpha import (
|
||||
)
|
||||
from langchain_community.embeddings.awa import AwaEmbeddings
|
||||
from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings
|
||||
from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings
|
||||
from langchain_community.embeddings.baidu_qianfan_endpoint import (
|
||||
QianfanEmbeddingsEndpoint,
|
||||
)
|
||||
@@ -92,6 +93,7 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = [
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"BaichuanTextEmbeddings",
|
||||
"ClarifaiEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"DatabricksEmbeddings",
|
||||
|
||||
113
libs/community/langchain_community/embeddings/baichuan.py
Normal file
113
libs/community/langchain_community/embeddings/baichuan.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
|
||||
|
||||
# BaichuanTextEmbeddings is an embedding model provided by Baichuan Inc. (https://www.baichuan-ai.com/home).
|
||||
# As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB
|
||||
# (Chinese Multi-Task Embedding Benchmark) leaderboard.
|
||||
# Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard
|
||||
|
||||
# Official Website: https://platform.baichuan-ai.com/docs/text-Embedding
|
||||
# An API-key is required to use this embedding model. You can get one by registering
|
||||
# at https://platform.baichuan-ai.com/docs/text-Embedding.
|
||||
# BaichuanTextEmbeddings support 512 token window and preduces vectors with
|
||||
# 1024 dimensions.
|
||||
|
||||
|
||||
# NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding.
|
||||
# Multi-language support is coming soon.
|
||||
class BaichuanTextEmbeddings(BaseModel, Embeddings):
|
||||
"""Baichuan Text Embedding models."""
|
||||
|
||||
session: Any #: :meta private:
|
||||
model_name: str = "Baichuan-Text-Embedding"
|
||||
baichuan_api_key: Optional[SecretStr] = None
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that auth token exists in environment."""
|
||||
try:
|
||||
baichuan_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
|
||||
)
|
||||
except ValueError as original_exc:
|
||||
try:
|
||||
baichuan_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
raise original_exc
|
||||
session = requests.Session()
|
||||
session.headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {baichuan_api_key.get_secret_value()}",
|
||||
"Accept-Encoding": "identity",
|
||||
"Content-type": "application/json",
|
||||
}
|
||||
)
|
||||
values["session"] = session
|
||||
return values
|
||||
|
||||
def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
|
||||
"""Internal method to call Baichuan Embedding API and return embeddings.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
|
||||
Returns:
|
||||
A list of list of floats representing the embeddings, or None if an
|
||||
error occurs.
|
||||
"""
|
||||
try:
|
||||
response = self.session.post(
|
||||
BAICHUAN_API_URL, json={"input": texts, "model": self.model_name}
|
||||
)
|
||||
# Check if the response status code indicates success
|
||||
if response.status_code == 200:
|
||||
resp = response.json()
|
||||
embeddings = resp.get("data", [])
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0))
|
||||
# Return just the embeddings
|
||||
return [result.get("embedding", []) for result in sorted_embeddings]
|
||||
else:
|
||||
# Log error or handle unsuccessful response appropriately
|
||||
print(
|
||||
f"""Error: Received status code {response.status_code} from
|
||||
embedding API"""
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Log the exception or handle it as needed
|
||||
print(f"Exception occurred while trying to get embeddings: {str(e)}")
|
||||
return None
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:
|
||||
"""Public method to get embeddings for a list of documents.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings, one for each text, or None if an error occurs.
|
||||
"""
|
||||
return self._embed(texts)
|
||||
|
||||
def embed_query(self, text: str) -> Optional[List[float]]:
|
||||
"""Public method to get embedding for a single query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text, or None if an error occurs.
|
||||
"""
|
||||
result = self._embed([text])
|
||||
return result[0] if result is not None else None
|
||||
@@ -71,9 +71,9 @@ class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings):
|
||||
|
||||
from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings
|
||||
import runhouse as rh
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
model_id = "sentence-transformers/all-mpnet-base-v2"
|
||||
gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
|
||||
hf = SelfHostedHuggingFaceEmbeddings(model_name=model_name, hardware=gpu)
|
||||
hf = SelfHostedHuggingFaceEmbeddings(model_id=model_id, hardware=gpu)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
|
||||
@@ -64,6 +64,10 @@ class _OllamaCommon(BaseLanguageModel):
|
||||
It is recommended to set this value to the number of physical
|
||||
CPU cores your system has (as opposed to the logical number of cores)."""
|
||||
|
||||
num_predict: Optional[int] = None
|
||||
"""Maximum number of tokens to predict when generating text.
|
||||
(Default: 128, -1 = infinite generation, -2 = fill context)"""
|
||||
|
||||
repeat_last_n: Optional[int] = None
|
||||
"""Sets how far back for the model to look back to prevent
|
||||
repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
|
||||
@@ -126,6 +130,7 @@ class _OllamaCommon(BaseLanguageModel):
|
||||
"num_ctx": self.num_ctx,
|
||||
"num_gpu": self.num_gpu,
|
||||
"num_thread": self.num_thread,
|
||||
"num_predict": self.num_predict,
|
||||
"repeat_last_n": self.repeat_last_n,
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"temperature": self.temperature,
|
||||
@@ -279,7 +284,10 @@ class _OllamaCommon(BaseLanguageModel):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url=api_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**(self.headers if isinstance(self.headers, dict) else {}),
|
||||
},
|
||||
json=request_payload,
|
||||
timeout=self.timeout,
|
||||
) as response:
|
||||
|
||||
@@ -109,6 +109,12 @@ class Bagel(VectorStore):
|
||||
import bagel # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("Please install bagel `pip install betabageldb`.")
|
||||
|
||||
if self._embedding_function and query_embeddings is None and query_texts:
|
||||
texts = list(query_texts)
|
||||
query_embeddings = self._embedding_function.embed_documents(texts)
|
||||
query_texts = None
|
||||
|
||||
return self._cluster.find(
|
||||
query_texts=query_texts,
|
||||
query_embeddings=query_embeddings,
|
||||
|
||||
@@ -13,11 +13,18 @@ Required to run this test:
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.document_loaders.astradb import AstraDBLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import (
|
||||
AstraDBCollection,
|
||||
AsyncAstraDBCollection,
|
||||
)
|
||||
|
||||
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
|
||||
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
|
||||
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")
|
||||
@@ -28,7 +35,7 @@ def _has_env_vars() -> bool:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def astra_db_collection():
|
||||
def astra_db_collection() -> "AstraDBCollection":
|
||||
from astrapy.db import AstraDB
|
||||
|
||||
astra_db = AstraDB(
|
||||
@@ -38,21 +45,41 @@ def astra_db_collection():
|
||||
)
|
||||
collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
|
||||
collection = astra_db.create_collection(collection_name)
|
||||
collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||
collection.insert_many(
|
||||
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
|
||||
)
|
||||
|
||||
yield collection
|
||||
|
||||
astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_astra_db_collection() -> "AsyncAstraDBCollection":
|
||||
from astrapy.db import AsyncAstraDB
|
||||
|
||||
astra_db = AsyncAstraDB(
|
||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||
api_endpoint=ASTRA_DB_API_ENDPOINT,
|
||||
namespace=ASTRA_DB_KEYSPACE,
|
||||
)
|
||||
collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
|
||||
collection = await astra_db.create_collection(collection_name)
|
||||
await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||
await collection.insert_many(
|
||||
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
|
||||
)
|
||||
|
||||
yield collection
|
||||
|
||||
await astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
class TestAstraDB:
|
||||
def test_astradb_loader(self, astra_db_collection) -> None:
|
||||
astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||
astra_db_collection.insert_many(
|
||||
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
|
||||
)
|
||||
|
||||
def test_astradb_loader(self, astra_db_collection: "AstraDBCollection") -> None:
|
||||
loader = AstraDBLoader(
|
||||
astra_db_collection.collection_name,
|
||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||
@@ -79,9 +106,9 @@ class TestAstraDB:
|
||||
"collection": astra_db_collection.collection_name,
|
||||
}
|
||||
|
||||
def test_extraction_function(self, astra_db_collection) -> None:
|
||||
astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||
|
||||
def test_extraction_function(
|
||||
self, astra_db_collection: "AstraDBCollection"
|
||||
) -> None:
|
||||
loader = AstraDBLoader(
|
||||
astra_db_collection.collection_name,
|
||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||
@@ -94,3 +121,51 @@ class TestAstraDB:
|
||||
doc = next(docs)
|
||||
|
||||
assert doc.page_content == "bar"
|
||||
|
||||
async def test_astradb_loader_async(
|
||||
self, async_astra_db_collection: "AsyncAstraDBCollection"
|
||||
) -> None:
|
||||
await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
|
||||
await async_astra_db_collection.insert_many(
|
||||
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
|
||||
)
|
||||
|
||||
loader = AstraDBLoader(
|
||||
async_astra_db_collection.collection_name,
|
||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||
api_endpoint=ASTRA_DB_API_ENDPOINT,
|
||||
namespace=ASTRA_DB_KEYSPACE,
|
||||
nb_prefetched=1,
|
||||
projection={"foo": 1},
|
||||
find_options={"limit": 22},
|
||||
filter_criteria={"foo": "bar"},
|
||||
)
|
||||
docs = await loader.aload()
|
||||
|
||||
assert len(docs) == 22
|
||||
ids = set()
|
||||
for doc in docs:
|
||||
content = json.loads(doc.page_content)
|
||||
assert content["foo"] == "bar"
|
||||
assert "baz" not in content
|
||||
assert content["_id"] not in ids
|
||||
ids.add(content["_id"])
|
||||
assert doc.metadata == {
|
||||
"namespace": async_astra_db_collection.astra_db.namespace,
|
||||
"api_endpoint": async_astra_db_collection.astra_db.base_url,
|
||||
"collection": async_astra_db_collection.collection_name,
|
||||
}
|
||||
|
||||
async def test_extraction_function_async(
|
||||
self, async_astra_db_collection: "AsyncAstraDBCollection"
|
||||
) -> None:
|
||||
loader = AstraDBLoader(
|
||||
async_astra_db_collection.collection_name,
|
||||
token=ASTRA_DB_APPLICATION_TOKEN,
|
||||
api_endpoint=ASTRA_DB_API_ENDPOINT,
|
||||
namespace=ASTRA_DB_KEYSPACE,
|
||||
find_options={"limit": 30},
|
||||
extraction_function=lambda x: x["foo"],
|
||||
)
|
||||
doc = await anext(loader.alazy_load())
|
||||
assert doc.page_content == "bar"
|
||||
|
||||
@@ -59,11 +59,11 @@ def test_loader_table(keyspace: str) -> None:
|
||||
loader = CassandraLoader(table=CASSANDRA_TABLE)
|
||||
assert loader.load() == [
|
||||
Document(
|
||||
page_content='{"row_id": "id1", "body_blob": "text1"}',
|
||||
page_content="Row(row_id='id1', body_blob='text1')",
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
),
|
||||
Document(
|
||||
page_content='{"row_id": "id2", "body_blob": "text2"}',
|
||||
page_content="Row(row_id='id2', body_blob='text2')",
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
),
|
||||
]
|
||||
@@ -74,8 +74,8 @@ def test_loader_query(keyspace: str) -> None:
|
||||
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
|
||||
)
|
||||
assert loader.load() == [
|
||||
Document(page_content='{"body_blob": "text1"}'),
|
||||
Document(page_content='{"body_blob": "text2"}'),
|
||||
Document(page_content="Row(body_blob='text1')"),
|
||||
Document(page_content="Row(body_blob='text2')"),
|
||||
]
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
|
||||
loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
|
||||
assert loader.load() == [
|
||||
Document(
|
||||
page_content='{"row_id": "id1", "body_blob": "text1"}',
|
||||
page_content="Row(row_id='id1', body_blob='text1')",
|
||||
metadata={
|
||||
"table": CASSANDRA_TABLE,
|
||||
"keyspace": keyspace,
|
||||
@@ -111,7 +111,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
|
||||
},
|
||||
),
|
||||
Document(
|
||||
page_content='{"row_id": "id2", "body_blob": "text2"}',
|
||||
page_content="Row(row_id='id2', body_blob='text2')",
|
||||
metadata={
|
||||
"table": CASSANDRA_TABLE,
|
||||
"keyspace": keyspace,
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Test Baichuan Text Embedding."""
|
||||
from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings
|
||||
|
||||
|
||||
def test_baichuan_embedding_documents() -> None:
|
||||
"""Test Baichuan Text Embedding for documents."""
|
||||
documents = ["今天天气不错", "今天阳光灿烂"]
|
||||
embedding = BaichuanTextEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 1024
|
||||
|
||||
|
||||
def test_baichuan_embedding_query() -> None:
|
||||
"""Test Baichuan Text Embedding for query."""
|
||||
document = "所有的小学生都会学过只因兔同笼问题。"
|
||||
embedding = BaichuanTextEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 1024
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Test Base Schema of documents."""
|
||||
from typing import Iterator, List
|
||||
from typing import Iterator
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
|
||||
from langchain_community.document_loaders.base import BaseBlobParser
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
|
||||
|
||||
@@ -27,25 +27,3 @@ def test_base_blob_parser() -> None:
|
||||
docs = parser.parse(Blob(data="who?"))
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "foo"
|
||||
|
||||
|
||||
async def test_default_aload() -> None:
|
||||
class FakeLoader(BaseLoader):
|
||||
def load(self) -> List[Document]:
|
||||
return list(self.lazy_load())
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
yield from [
|
||||
Document(page_content="foo"),
|
||||
Document(page_content="bar")
|
||||
]
|
||||
|
||||
loader = FakeLoader()
|
||||
docs = loader.load()
|
||||
assert docs == [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
# Test that async lazy loading works
|
||||
docs = [doc async for doc in loader.alazy_load()]
|
||||
assert docs == [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ EXPECTED_ALL = [
|
||||
"BlockchainDocumentLoader",
|
||||
"BraveSearchLoader",
|
||||
"BrowserlessLoader",
|
||||
"CassandraLoader",
|
||||
"CSVLoader",
|
||||
"ChatGPTLoader",
|
||||
"CoNLLULoader",
|
||||
|
||||
@@ -3,6 +3,7 @@ from langchain_community.embeddings import __all__
|
||||
EXPECTED_ALL = [
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"BaichuanTextEmbeddings",
|
||||
"ClarifaiEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"DatabricksEmbeddings",
|
||||
|
||||
@@ -88,6 +88,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
|
||||
"num_ctx": None,
|
||||
"num_gpu": None,
|
||||
"num_thread": None,
|
||||
"num_predict": None,
|
||||
"repeat_last_n": None,
|
||||
"repeat_penalty": None,
|
||||
"stop": [],
|
||||
@@ -133,6 +134,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
|
||||
"num_ctx": None,
|
||||
"num_gpu": None,
|
||||
"num_thread": None,
|
||||
"num_predict": None,
|
||||
"repeat_last_n": None,
|
||||
"repeat_penalty": None,
|
||||
"stop": [],
|
||||
|
||||
@@ -16,7 +16,12 @@ from typing import (
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
MessageLikeRepresentation,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
|
||||
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
|
||||
LanguageModelOutput = Union[BaseMessage, str]
|
||||
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
|
||||
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
|
||||
|
||||
@@ -34,6 +34,7 @@ from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
convert_to_messages,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
@@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
elif isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
elif isinstance(input, Sequence):
|
||||
return ChatPromptValue(messages=input)
|
||||
return ChatPromptValue(messages=convert_to_messages(input))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input)}. "
|
||||
|
||||
@@ -48,7 +48,12 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
convert_to_messages,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||
@@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
elif isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
elif isinstance(input, Sequence):
|
||||
return ChatPromptValue(messages=input)
|
||||
return ChatPromptValue(messages=convert_to_messages(input))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input)}. "
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import (
|
||||
@@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
)
|
||||
|
||||
|
||||
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
|
||||
|
||||
|
||||
def _create_message_from_message_type(
|
||||
message_type: str,
|
||||
content: str,
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**additional_kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Create a message from a message type and content string.
|
||||
|
||||
Args:
|
||||
message_type: str the type of the message (e.g., "human", "ai", etc.)
|
||||
content: str the content string.
|
||||
|
||||
Returns:
|
||||
a message of the appropriate type.
|
||||
"""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if tool_call_id is not None:
|
||||
kwargs["tool_call_id"] = tool_call_id
|
||||
if additional_kwargs:
|
||||
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
|
||||
if message_type in ("human", "user"):
|
||||
message: BaseMessage = HumanMessage(content=content, **kwargs)
|
||||
elif message_type in ("ai", "assistant"):
|
||||
message = AIMessage(content=content, **kwargs)
|
||||
elif message_type == "system":
|
||||
message = SystemMessage(content=content, **kwargs)
|
||||
elif message_type == "function":
|
||||
message = FunctionMessage(content=content, **kwargs)
|
||||
elif message_type == "tool":
|
||||
message = ToolMessage(content=content, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected message type: {message_type}. Use one of 'human',"
|
||||
f" 'user', 'ai', 'assistant', or 'system'."
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
) -> BaseMessage:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
The message format can be one of the following:
|
||||
|
||||
- BaseMessagePromptTemplate
|
||||
- BaseMessage
|
||||
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
|
||||
- dict: a message dict with role and content keys
|
||||
- string: shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Args:
|
||||
message: a representation of a message in one of the supported formats
|
||||
|
||||
Returns:
|
||||
an instance of a message or a message template
|
||||
"""
|
||||
if isinstance(message, BaseMessage):
|
||||
_message = message
|
||||
elif isinstance(message, str):
|
||||
_message = _create_message_from_message_type("human", message)
|
||||
elif isinstance(message, tuple):
|
||||
if len(message) != 2:
|
||||
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
|
||||
message_type_str, template = message
|
||||
_message = _create_message_from_message_type(message_type_str, template)
|
||||
elif isinstance(message, dict):
|
||||
msg_kwargs = message.copy()
|
||||
try:
|
||||
msg_type = msg_kwargs.pop("role")
|
||||
msg_content = msg_kwargs.pop("content")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
)
|
||||
_message = _create_message_from_message_type(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
return _message
|
||||
|
||||
|
||||
def convert_to_messages(
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages to a list of messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_convert_to_message(m) for m in messages]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
@@ -133,6 +237,7 @@ __all__ = [
|
||||
"SystemMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"convert_to_messages",
|
||||
"get_buffer_string",
|
||||
"message_chunk_to_message",
|
||||
"messages_from_dict",
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
@@ -82,6 +84,30 @@ class ChatPromptValue(PromptValue):
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
|
||||
class ImageURL(TypedDict, total=False):
|
||||
detail: Literal["auto", "low", "high"]
|
||||
"""Specifies the detail level of the image."""
|
||||
|
||||
url: str
|
||||
"""Either a URL of the image or the base64 encoded image data."""
|
||||
|
||||
|
||||
class ImagePromptValue(PromptValue):
|
||||
"""Image prompt value."""
|
||||
|
||||
image_url: ImageURL
|
||||
"""Prompt image."""
|
||||
type: Literal["ImagePromptValue"] = "ImagePromptValue"
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return self.image_url["url"]
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=[self.image_url])]
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
For use in external schemas."""
|
||||
|
||||
@@ -8,10 +8,12 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -30,7 +32,12 @@ if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
FormatOutputType = TypeVar("FormatOutputType")
|
||||
|
||||
|
||||
class BasePromptTemplate(
|
||||
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
|
||||
):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
@@ -142,7 +149,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
return {**partial_kwargs, **kwargs}
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
def format(self, **kwargs: Any) -> FormatOutputType:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
@@ -210,7 +217,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||
"""Format a document into a string based on a prompt template.
|
||||
|
||||
First, this pulls information from the document from two sources:
|
||||
@@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core import Document
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
doc = Document(page_content="This is a joke", metadata={"page": "1"})
|
||||
|
||||
@@ -13,8 +13,10 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
@@ -27,12 +29,14 @@ from langchain_core.messages import (
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
convert_to_messages,
|
||||
)
|
||||
from langchain_core.messages.base import get_msg_title_repr
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -126,7 +130,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
f"variable {self.variable_name} should be a list of base messages, "
|
||||
f"got {value}"
|
||||
)
|
||||
for v in value:
|
||||
for v in convert_to_messages(value):
|
||||
if not isinstance(v, BaseMessage):
|
||||
raise ValueError(
|
||||
f"variable {self.variable_name} should be a list of base messages,"
|
||||
@@ -287,14 +291,153 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
)
|
||||
|
||||
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
_StringImageMessagePromptTemplateT = TypeVar(
|
||||
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
|
||||
)
|
||||
|
||||
|
||||
class _TextTemplateParam(TypedDict, total=False):
|
||||
text: Union[str, Dict]
|
||||
|
||||
|
||||
class _ImageTemplateParam(TypedDict, total=False):
|
||||
image_url: Union[str, Dict]
|
||||
|
||||
|
||||
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
prompt: Union[
|
||||
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
|
||||
]
|
||||
"""Prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
_msg_class: Type[BaseMessage]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[_StringImageMessagePromptTemplateT],
|
||||
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
template_format: str = "f-string",
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
"""Create a class from a string template.
|
||||
|
||||
Args:
|
||||
template: a template.
|
||||
template_format: format of the template.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
if isinstance(template, str):
|
||||
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
|
||||
template, template_format=template_format
|
||||
)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
elif isinstance(template, list):
|
||||
prompt = []
|
||||
for tmpl in template:
|
||||
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
|
||||
if isinstance(tmpl, str):
|
||||
text: str = tmpl
|
||||
else:
|
||||
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
|
||||
prompt.append(
|
||||
PromptTemplate.from_template(
|
||||
text, template_format=template_format
|
||||
)
|
||||
)
|
||||
elif isinstance(tmpl, dict) and "image_url" in tmpl:
|
||||
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
|
||||
if isinstance(img_template, str):
|
||||
vars = get_template_variables(img_template, "f-string")
|
||||
if vars:
|
||||
if len(vars) > 1:
|
||||
raise ValueError(
|
||||
"Only one format variable allowed per image"
|
||||
f" template.\nGot: {vars}"
|
||||
f"\nFrom: {tmpl}"
|
||||
)
|
||||
input_variables = [vars[0]]
|
||||
else:
|
||||
input_variables = None
|
||||
img_template = {"url": img_template}
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
)
|
||||
elif isinstance(img_template, dict):
|
||||
img_template = dict(img_template)
|
||||
if "url" in img_template:
|
||||
input_variables = get_template_variables(
|
||||
img_template["url"], "f-string"
|
||||
)
|
||||
else:
|
||||
input_variables = None
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
prompt.append(img_template_obj)
|
||||
else:
|
||||
raise ValueError()
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: Type[_StringImageMessagePromptTemplateT],
|
||||
template_file: Union[str, Path],
|
||||
input_variables: List[str],
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
"""Create a class from a template file.
|
||||
|
||||
Args:
|
||||
template_file: path to a template file. String or Path.
|
||||
input_variables: list of input variables.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
with open(str(template_file), "r") as f:
|
||||
template = f.read()
|
||||
return cls.from_template(template, input_variables=input_variables, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
return [self.format(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
|
||||
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
|
||||
return input_variables
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
@@ -304,53 +447,55 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
if isinstance(self.prompt, StringPromptTemplate):
|
||||
text = self.prompt.format(**kwargs)
|
||||
return self._msg_class(
|
||||
content=text, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
else:
|
||||
content = []
|
||||
for prompt in self.prompt:
|
||||
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||
if isinstance(prompt, StringPromptTemplate):
|
||||
formatted: Union[str, ImageURL] = prompt.format(**inputs)
|
||||
content.append({"type": "text", "text": formatted})
|
||||
elif isinstance(prompt, ImagePromptTemplate):
|
||||
formatted = prompt.format(**inputs)
|
||||
content.append({"type": "image_url", "image_url": formatted})
|
||||
return self._msg_class(
|
||||
content=content, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
_msg_class: Type[BaseMessage] = HumanMessage
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message sent from the AI."""
|
||||
|
||||
_msg_class: Type[BaseMessage] = AIMessage
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
"""System message prompt template.
|
||||
This is a message that is not sent to the user.
|
||||
"""
|
||||
|
||||
_msg_class: Type[BaseMessage] = SystemMessage
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
@@ -404,8 +549,7 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
MessageLike,
|
||||
Tuple[str, str],
|
||||
Tuple[Type, str],
|
||||
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
|
||||
str,
|
||||
]
|
||||
|
||||
@@ -737,7 +881,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
|
||||
def _create_template_from_message_type(
|
||||
message_type: str, template: str
|
||||
message_type: str, template: Union[str, list]
|
||||
) -> BaseMessagePromptTemplate:
|
||||
"""Create a message prompt template from a message type and template string.
|
||||
|
||||
@@ -753,9 +897,9 @@ def _create_template_from_message_type(
|
||||
template
|
||||
)
|
||||
elif message_type in ("ai", "assistant"):
|
||||
message = AIMessagePromptTemplate.from_template(template)
|
||||
message = AIMessagePromptTemplate.from_template(cast(str, template))
|
||||
elif message_type == "system":
|
||||
message = SystemMessagePromptTemplate.from_template(template)
|
||||
message = SystemMessagePromptTemplate.from_template(cast(str, template))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected message type: {message_type}. Use one of 'human',"
|
||||
@@ -798,7 +942,9 @@ def _convert_to_message(
|
||||
if isinstance(message_type_str, str):
|
||||
_message = _create_template_from_message_type(message_type_str, template)
|
||||
else:
|
||||
_message = message_type_str(prompt=PromptTemplate.from_template(template))
|
||||
_message = message_type_str(
|
||||
prompt=PromptTemplate.from_template(cast(str, template))
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
|
||||
76
libs/core/langchain_core/prompts/image.py
Normal file
76
libs/core/langchain_core/prompts/image.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import image as image_utils
|
||||
|
||||
|
||||
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
"""An image prompt template for a multimodal model."""
|
||||
|
||||
template: dict = Field(default_factory=dict)
|
||||
"""Template for the prompt."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
if "input_variables" not in kwargs:
|
||||
kwargs["input_variables"] = []
|
||||
|
||||
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
|
||||
if overlap:
|
||||
raise ValueError(
|
||||
"input_variables for the image template cannot contain"
|
||||
" any of 'url', 'path', or 'detail'."
|
||||
f" Found: {overlap}"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
return "image-prompt"
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
return ImagePromptValue(image_url=self.format(**kwargs))
|
||||
|
||||
def format(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> ImageURL:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
formatted = {}
|
||||
for k, v in self.template.items():
|
||||
if isinstance(v, str):
|
||||
formatted[k] = v.format(**kwargs)
|
||||
else:
|
||||
formatted[k] = v
|
||||
url = kwargs.get("url") or formatted.get("url")
|
||||
path = kwargs.get("path") or formatted.get("path")
|
||||
detail = kwargs.get("detail") or formatted.get("detail")
|
||||
if not url and not path:
|
||||
raise ValueError("Must provide either url or path.")
|
||||
if not url:
|
||||
if not isinstance(path, str):
|
||||
raise ValueError("path must be a string.")
|
||||
url = image_utils.image_to_data_url(path)
|
||||
if not isinstance(url, str):
|
||||
raise ValueError("url must be a string.")
|
||||
output: ImageURL = {"url": url}
|
||||
if detail:
|
||||
# Don't check literal values here: let the API check them
|
||||
output["detail"] = detail # type: ignore[typeddict-item]
|
||||
return output
|
||||
@@ -20,7 +20,7 @@ from typing import (
|
||||
from uuid import UUID
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
from anyio import create_memory_object_stream
|
||||
from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.load import dumps
|
||||
@@ -223,6 +223,14 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
||||
return self.receive_stream.__aiter__()
|
||||
|
||||
def send(self, *ops: Dict[str, Any]) -> bool:
|
||||
"""Send a patch to the stream, return False if the stream is closed."""
|
||||
try:
|
||||
self.send_stream.send_nowait(RunLogPatch(*ops))
|
||||
return True
|
||||
except (ClosedResourceError, BrokenResourceError):
|
||||
return False
|
||||
|
||||
async def tap_output_aiter(
|
||||
self, run_id: UUID, output: AsyncIterator[T]
|
||||
) -> AsyncIterator[T]:
|
||||
@@ -233,15 +241,14 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
# if we can't find the run silently ignore
|
||||
# eg. because this run wasn't included in the log
|
||||
if key := self._key_map_by_run_id.get(run_id):
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{key}/streamed_output/-",
|
||||
"value": chunk,
|
||||
}
|
||||
)
|
||||
)
|
||||
if not self.send(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{key}/streamed_output/-",
|
||||
"value": chunk,
|
||||
}
|
||||
):
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
@@ -285,22 +292,21 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
"""Start a run."""
|
||||
if self.root_id is None:
|
||||
self.root_id = run.id
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "",
|
||||
"value": RunState(
|
||||
id=str(run.id),
|
||||
streamed_output=[],
|
||||
final_output=None,
|
||||
logs={},
|
||||
name=run.name,
|
||||
type=run.run_type,
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
if not self.send(
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "",
|
||||
"value": RunState(
|
||||
id=str(run.id),
|
||||
streamed_output=[],
|
||||
final_output=None,
|
||||
logs={},
|
||||
name=run.name,
|
||||
type=run.run_type,
|
||||
),
|
||||
}
|
||||
):
|
||||
return
|
||||
|
||||
if not self.include_run(run):
|
||||
return
|
||||
@@ -331,14 +337,12 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
entry["inputs"] = _get_standardized_inputs(run, self._schema_format)
|
||||
|
||||
# Add the run to the stream
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
|
||||
"value": entry,
|
||||
}
|
||||
)
|
||||
self.send(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
|
||||
"value": entry,
|
||||
}
|
||||
)
|
||||
|
||||
def _on_run_update(self, run: Run) -> None:
|
||||
@@ -382,7 +386,7 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
]
|
||||
)
|
||||
|
||||
self.send_stream.send_nowait(RunLogPatch(*ops))
|
||||
self.send(*ops)
|
||||
finally:
|
||||
if run.id == self.root_id:
|
||||
if self.auto_close:
|
||||
@@ -400,21 +404,19 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
if index is None:
|
||||
return
|
||||
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output_str/-",
|
||||
"value": token,
|
||||
},
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output/-",
|
||||
"value": chunk.message
|
||||
if isinstance(chunk, ChatGenerationChunk)
|
||||
else token,
|
||||
},
|
||||
)
|
||||
self.send(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output_str/-",
|
||||
"value": token,
|
||||
},
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output/-",
|
||||
"value": chunk.message
|
||||
if isinstance(chunk, ChatGenerationChunk)
|
||||
else token,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
These functions do not depend on any other LangChain module.
|
||||
"""
|
||||
|
||||
from langchain_core.utils import image
|
||||
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
|
||||
from langchain_core.utils.formatting import StrictFormatter, formatter
|
||||
from langchain_core.utils.input import (
|
||||
@@ -41,6 +42,7 @@ __all__ = [
|
||||
"xor_args",
|
||||
"try_load_from_hub",
|
||||
"build_extra_kwargs",
|
||||
"image",
|
||||
"get_from_env",
|
||||
"get_from_dict_or_env",
|
||||
"stringify_dict",
|
||||
|
||||
14
libs/core/langchain_core/utils/image.py
Normal file
14
libs/core/langchain_core/utils/image.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
|
||||
def encode_image(image_path: str) -> str:
|
||||
"""Get base64 string from image URI."""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
|
||||
def image_to_data_url(image_path: str) -> str:
|
||||
encoding = encode_image(image_path)
|
||||
mime_type = mimetypes.guess_type(image_path)[0]
|
||||
return f"data:{mime_type};base64,{encoding}"
|
||||
@@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "generic-fake-chat-model"
|
||||
|
||||
|
||||
class ParrotFakeChatModel(BaseChatModel):
|
||||
"""A generic fake chat model that can be used to test the chat model interface.
|
||||
|
||||
* Chat model should be usable in both sync and async tests
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(generations=[ChatGeneration(message=messages[-1])])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "parrot-fake-chat-model"
|
||||
|
||||
@@ -5,8 +5,9 @@ from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
@@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
|
||||
AIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert tokens == ["hello", " ", "goodbye"]
|
||||
|
||||
|
||||
def test_chat_model_inputs() -> None:
|
||||
fake = ParrotFakeChatModel()
|
||||
|
||||
assert fake.invoke("hello") == HumanMessage(content="hello")
|
||||
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
|
||||
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
|
||||
|
||||
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
|
||||
"SystemMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"convert_to_messages",
|
||||
"get_buffer_string",
|
||||
"message_chunk_to_message",
|
||||
"messages_from_dict",
|
||||
|
||||
@@ -3,6 +3,9 @@ from typing import Any, List, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core._api.deprecation import (
|
||||
LangChainPendingDeprecationWarning,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@@ -243,14 +246,15 @@ def test_chat_valid_infer_variables() -> None:
|
||||
|
||||
def test_chat_from_role_strings() -> None:
|
||||
"""Test instantiation of chat template from role strings."""
|
||||
template = ChatPromptTemplate.from_role_strings(
|
||||
[
|
||||
("system", "You are a bot."),
|
||||
("assistant", "hello!"),
|
||||
("human", "{question}"),
|
||||
("other", "{quack}"),
|
||||
]
|
||||
)
|
||||
with pytest.warns(LangChainPendingDeprecationWarning):
|
||||
template = ChatPromptTemplate.from_role_strings(
|
||||
[
|
||||
("system", "You are a bot."),
|
||||
("assistant", "hello!"),
|
||||
("human", "{question}"),
|
||||
("other", "{quack}"),
|
||||
]
|
||||
)
|
||||
|
||||
messages = template.format_messages(question="How are you?", quack="duck")
|
||||
assert messages == [
|
||||
@@ -363,9 +367,145 @@ def test_chat_message_partial() -> None:
|
||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||
|
||||
|
||||
def test_chat_tmpl_from_messages_multipart_text() -> None:
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
(
|
||||
"human",
|
||||
[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "text", "text": "Oh nvm"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = template.format_messages(name="R2D2")
|
||||
expected = [
|
||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "text", "text": "Oh nvm"},
|
||||
]
|
||||
),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
|
||||
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
(
|
||||
"human",
|
||||
[
|
||||
{"type": "text", "text": "What's in this {object_name}?"},
|
||||
{"type": "text", "text": "Oh nvm"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = template.format_messages(name="R2D2", object_name="image")
|
||||
expected = [
|
||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "text", "text": "Oh nvm"},
|
||||
]
|
||||
),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
|
||||
def test_chat_tmpl_from_messages_multipart_image() -> None:
|
||||
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
(
|
||||
"human",
|
||||
[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/jpeg;base64,{my_image}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
|
||||
},
|
||||
{"type": "image_url", "image_url": "{my_other_image}"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "{my_other_image}", "detail": "medium"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://www.langchain.com/image.png"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,foobar"},
|
||||
},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = template.format_messages(
|
||||
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
|
||||
)
|
||||
expected = [
|
||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{other_base64_image}"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"{other_base64_image}"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"{other_base64_image}",
|
||||
"detail": "medium",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://www.langchain.com/image.png"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,foobar"},
|
||||
},
|
||||
]
|
||||
),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
|
||||
def test_messages_placeholder() -> None:
|
||||
prompt = MessagesPlaceholder("history")
|
||||
with pytest.raises(KeyError):
|
||||
prompt.format_messages()
|
||||
prompt = MessagesPlaceholder("history", optional=True)
|
||||
assert prompt.format_messages() == []
|
||||
prompt.format_messages(
|
||||
history=[("system", "You are an AI assistant."), "Hello!"]
|
||||
) == [
|
||||
SystemMessage(content="You are an AI assistant."),
|
||||
HumanMessage(content="Hello!"),
|
||||
]
|
||||
|
||||
@@ -14,6 +14,7 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
convert_to_messages,
|
||||
get_buffer_string,
|
||||
message_chunk_to_message,
|
||||
messages_from_dict,
|
||||
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_messages() -> None:
|
||||
# dicts
|
||||
assert convert_to_messages(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "ai", "content": "Hi!"},
|
||||
{"role": "human", "content": "Hello!", "name": "Jane"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi!",
|
||||
"name": "JaneBot",
|
||||
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
|
||||
},
|
||||
{"role": "function", "name": "greet", "content": "Hi!"},
|
||||
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
|
||||
]
|
||||
) == [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="Hello!"),
|
||||
AIMessage(content="Hi!"),
|
||||
HumanMessage(content="Hello!", name="Jane"),
|
||||
AIMessage(
|
||||
content="Hi!",
|
||||
name="JaneBot",
|
||||
additional_kwargs={
|
||||
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
|
||||
},
|
||||
),
|
||||
FunctionMessage(name="greet", content="Hi!"),
|
||||
ToolMessage(tool_call_id="tool_id", content="Hi!"),
|
||||
]
|
||||
|
||||
# tuples
|
||||
assert convert_to_messages(
|
||||
[
|
||||
("system", "You are a helpful assistant."),
|
||||
"hello!",
|
||||
("ai", "Hi!"),
|
||||
("human", "Hello!"),
|
||||
("assistant", "Hi!"),
|
||||
]
|
||||
) == [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="hello!"),
|
||||
AIMessage(content="Hi!"),
|
||||
HumanMessage(content="Hello!"),
|
||||
AIMessage(content="Hi!"),
|
||||
]
|
||||
|
||||
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
|
||||
"xor_args",
|
||||
"try_load_from_hub",
|
||||
"build_extra_kwargs",
|
||||
"image",
|
||||
"get_from_dict_or_env",
|
||||
"get_from_env",
|
||||
"stringify_dict",
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike
|
||||
@@ -48,13 +51,31 @@ def create_history_aware_retriever(
|
||||
chain.invoke({"input": "...", "chat_history": })
|
||||
|
||||
"""
|
||||
if "input" not in prompt.input_variables:
|
||||
input_vars = prompt.input_variables
|
||||
if "input" not in input_vars and "messages" not in input_vars:
|
||||
raise ValueError(
|
||||
"Expected `input` to be a prompt variable, "
|
||||
f"but got {prompt.input_variables}"
|
||||
"Expected either `input` or `messages` to be prompt variables, "
|
||||
f"but got {input_vars}"
|
||||
)
|
||||
|
||||
def messages_param_is_message_list(x: Dict) -> bool:
|
||||
return (
|
||||
isinstance(x.get("messages", []), list)
|
||||
and len(x.get("messages", [])) > 0
|
||||
and all(isinstance(i, BaseMessage) for i in x.get("messages", []))
|
||||
)
|
||||
|
||||
retrieve_documents: RetrieverOutputLike = RunnableBranch(
|
||||
(
|
||||
lambda x: messages_param_is_message_list(x)
|
||||
and len(x.get("messages", [])) > 1,
|
||||
prompt | llm | StrOutputParser() | retriever,
|
||||
),
|
||||
(
|
||||
lambda x: messages_param_is_message_list(x)
|
||||
and len(x.get("messages", [])) == 1,
|
||||
(lambda x: x["messages"][-1].content) | retriever,
|
||||
),
|
||||
(
|
||||
# Both empty string and empty list evaluate to False
|
||||
lambda x: not x.get("chat_history", False),
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.retrievers import (
|
||||
BaseRetriever,
|
||||
RetrieverOutput,
|
||||
@@ -55,10 +56,30 @@ def create_retrieval_chain(
|
||||
chain.invoke({"input": "..."})
|
||||
|
||||
"""
|
||||
|
||||
def messages_param_is_message_list(x: Dict) -> bool:
|
||||
return (
|
||||
isinstance(x.get("messages", []), list)
|
||||
and len(x.get("messages", [])) > 0
|
||||
and all(isinstance(i, BaseMessage) for i in x.get("messages", []))
|
||||
)
|
||||
|
||||
def extract_retriever_input_string(x: Dict) -> str:
|
||||
if not x.get("input"):
|
||||
if messages_param_is_message_list(x):
|
||||
return x["messages"][-1].content
|
||||
else:
|
||||
raise ValueError(
|
||||
"If `input` not provided, ",
|
||||
"`messages` parameter must be a list of messages.",
|
||||
)
|
||||
else:
|
||||
return x["input"]
|
||||
|
||||
if not isinstance(retriever, BaseRetriever):
|
||||
retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
|
||||
else:
|
||||
retrieval_docs = (lambda x: x["input"]) | retriever
|
||||
retrieval_docs = extract_retriever_input_string | retriever
|
||||
|
||||
retrieval_chain = (
|
||||
RunnablePassthrough.assign(
|
||||
|
||||
@@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
|
||||
|
||||
|
||||
async def aindex(
|
||||
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
|
||||
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
|
||||
record_manager: RecordManager,
|
||||
vector_store: VectorStore,
|
||||
*,
|
||||
@@ -469,17 +469,16 @@ async def aindex(
|
||||
# implementation which just raises a NotImplementedError
|
||||
raise ValueError("Vectorstore has not implemented the delete method")
|
||||
|
||||
async_doc_iterator: AsyncIterator[Document]
|
||||
if isinstance(docs_source, BaseLoader):
|
||||
try:
|
||||
async_doc_iterator = docs_source.alazy_load()
|
||||
except NotImplementedError:
|
||||
async_doc_iterator = _to_async_iterator(await docs_source.aload())
|
||||
raise NotImplementedError(
|
||||
"Not supported yet. Please pass an async iterator of documents."
|
||||
)
|
||||
async_doc_iterator: AsyncIterator[Document]
|
||||
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
else:
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
else:
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
|
||||
source_id_assigner = _get_source_id_assigner(source_id_key)
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Test conversation chain and memory."""
|
||||
from langchain_community.llms.fake import FakeListLLM
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
)
|
||||
|
||||
from langchain.chains import create_retrieval_chain
|
||||
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever
|
||||
@@ -22,3 +27,31 @@ def test_create() -> None:
|
||||
}
|
||||
output = chain.invoke({"input": "What is the answer?", "chat_history": "foo"})
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_create_with_chat_history_messages_only() -> None:
|
||||
answer = "I know the answer!"
|
||||
llm = FakeListLLM(responses=[answer])
|
||||
retriever = FakeParrotRetriever()
|
||||
question_gen_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
chain = create_retrieval_chain(retriever, question_gen_prompt | llm)
|
||||
|
||||
expected_output = {
|
||||
"answer": "I know the answer!",
|
||||
"messages": [
|
||||
HumanMessage(content="What is the answer?"),
|
||||
],
|
||||
"context": [Document(page_content="What is the answer?")],
|
||||
}
|
||||
output = chain.invoke(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(content="What is the answer?"),
|
||||
],
|
||||
}
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
@@ -40,6 +40,19 @@ class ToyLoader(BaseLoader):
|
||||
"""Load the documents from the source."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def alazy_load(
|
||||
self,
|
||||
) -> AsyncIterator[Document]:
|
||||
async def async_generator() -> AsyncIterator[Document]:
|
||||
for document in self.documents:
|
||||
yield document
|
||||
|
||||
return async_generator()
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load the documents from the source."""
|
||||
return [doc async for doc in await self.alazy_load()]
|
||||
|
||||
|
||||
class InMemoryVectorStore(VectorStore):
|
||||
"""In-memory implementation of VectorStore using a dictionary."""
|
||||
@@ -219,7 +232,7 @@ async def test_aindexing_same_content(
|
||||
]
|
||||
)
|
||||
|
||||
assert await aindex(loader, arecord_manager, vector_store) == {
|
||||
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
@@ -230,7 +243,9 @@ async def test_aindexing_same_content(
|
||||
|
||||
for _ in range(2):
|
||||
# Run the indexing again
|
||||
assert await aindex(loader, arecord_manager, vector_store) == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@@ -332,7 +347,9 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
@@ -342,7 +359,9 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@@ -363,7 +382,9 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 1,
|
||||
"num_deleted": 1,
|
||||
"num_skipped": 1,
|
||||
@@ -381,7 +402,9 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@@ -450,7 +473,7 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
loader,
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@@ -459,7 +482,7 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
loader,
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@@ -570,7 +593,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader,
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@@ -587,7 +610,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader,
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@@ -617,7 +640,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader,
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@@ -756,7 +779,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader.lazy_load(),
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@@ -780,7 +803,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader.lazy_load(),
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@@ -815,7 +838,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
loader.lazy_load(),
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@@ -860,7 +883,9 @@ async def test_aindexing_with_no_docs(
|
||||
"""Check edge case when loader returns no new docs."""
|
||||
loader = ToyLoader(documents=[])
|
||||
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
|
||||
@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
|
||||
|
||||
|
||||
def get_generation_info(
|
||||
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
|
||||
candidate: Union[TextGenerationResponse, Candidate],
|
||||
is_gemini: bool,
|
||||
*,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
if is_gemini:
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
|
||||
return {
|
||||
info = {
|
||||
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": rating.category.name,
|
||||
"probability_label": rating.probability.name,
|
||||
"blocked": rating.blocked,
|
||||
}
|
||||
for rating in candidate.safety_ratings
|
||||
],
|
||||
"citation_metadata": candidate.citation_metadata,
|
||||
}
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
|
||||
candidate_dc = dataclasses.asdict(candidate)
|
||||
candidate_dc.pop("text")
|
||||
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
|
||||
else:
|
||||
info = dataclasses.asdict(candidate)
|
||||
info.pop("text")
|
||||
info = {k: v for k, v in info.items() if not k.startswith("_")}
|
||||
if stream:
|
||||
# Remove non-streamable types, like bools.
|
||||
info.pop("is_blocked")
|
||||
return info
|
||||
|
||||
@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
return result.total_tokens
|
||||
|
||||
def _response_to_generation(
|
||||
self, response: TextGenerationResponse
|
||||
self, response: TextGenerationResponse, *, stream: bool = False
|
||||
) -> GenerationChunk:
|
||||
"""Converts a stream response to a generation chunk."""
|
||||
generation_info = get_generation_info(response, self._is_gemini_model)
|
||||
generation_info = get_generation_info(
|
||||
response, self._is_gemini_model, stream=stream
|
||||
)
|
||||
try:
|
||||
text = response.text
|
||||
except AttributeError:
|
||||
@@ -401,7 +403,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
run_manager=run_manager,
|
||||
**params,
|
||||
):
|
||||
chunk = self._response_to_generation(stream_resp)
|
||||
# Gemini models return GenerationResponse even when streaming, which has a
|
||||
# candidates field.
|
||||
stream_resp = (
|
||||
stream_resp
|
||||
if isinstance(stream_resp, TextGenerationResponse)
|
||||
else stream_resp.candidates[0]
|
||||
)
|
||||
chunk = self._response_to_generation(stream_resp, stream=True)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
|
||||
@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
|
||||
"model_name",
|
||||
model_names_to_test_with_default,
|
||||
)
|
||||
def test_vertex_call(model_name: str) -> None:
|
||||
def test_vertex_invoke(model_name: str) -> None:
|
||||
llm = (
|
||||
VertexAI(model_name=model_name, temperature=0)
|
||||
if model_name
|
||||
else VertexAI(temperature=0.0)
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
output = llm.invoke("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names_to_test_with_default,
|
||||
)
|
||||
def test_vertex_generate(model_name: str) -> None:
|
||||
llm = (
|
||||
VertexAI(model_name=model_name, temperature=0)
|
||||
if model_name
|
||||
else VertexAI(temperature=0.0)
|
||||
)
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
|
||||
def test_vertex_generate() -> None:
|
||||
def test_vertex_generate_multiple_candidates() -> None:
|
||||
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
Reference in New Issue
Block a user