Compare commits

..

29 Commits

Author SHA1 Message Date
jacoblee93
f90e665413 Lint 2024-01-28 20:46:46 -08:00
jacoblee93
fad076fa06 Lint 2024-01-28 20:43:54 -08:00
jacoblee93
59ffccf27d Fix lint 2024-01-28 20:40:39 -08:00
jacoblee93
ac85fca6f0 Switch to messages param 2024-01-28 20:28:00 -08:00
jacoblee93
f29ad020a0 Small tweak 2024-01-28 10:08:41 -08:00
jacoblee93
b67561890b Fix lint + tests 2024-01-28 10:07:21 -08:00
jacoblee93
b970bfe8da Make input param optional for retrieval chain and history aware retriever chain 2024-01-28 09:59:16 -08:00
Christophe Bornet
36e432672a community[minor]: Add async methods to AstraDBLoader (#16652) 2024-01-27 17:05:41 -08:00
William FH
38425c99d2 core[minor]: Image prompt template (#14263)
Builds on Bagatur's (#13227). See unit test for example usage (below)

```python
def test_chat_tmpl_from_messages_multipart_image() -> None:
    base64_image = "abcd123"
    other_base64_image = "abcd123"
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            (
                "human",
                [
                    {"type": "text", "text": "What's in this image?"},
                    # OAI supports all these structures today
                    {
                        "type": "image_url",
                        "image_url": "data:image/jpeg;base64,{my_image}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,{my_image}"},
                    },
                    {"type": "image_url", "image_url": "{my_other_image}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "{my_other_image}", "detail": "medium"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.langchain.com/image.png"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,foobar"},
                    },
                ],
            ),
        ]
    )
    messages = template.format_messages(
        name="R2D2", my_image=base64_image, my_other_image=other_base64_image
    )
    expected = [
        SystemMessage(content="You are an AI assistant named R2D2."),
        HumanMessage(
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{other_base64_image}"
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"{other_base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{other_base64_image}",
                        "detail": "medium",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "https://www.langchain.com/image.png"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,foobar"},
                },
            ]
        ),
    ]
    assert messages == expected
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Brace Sproul <braceasproul@gmail.com>
2024-01-27 17:04:29 -08:00
ARKA1112
3c387bc12d docs: Error when importing packages from pydantic [docs] (#16564)
URL: https://python.langchain.com/docs/use_cases/extraction

Desc:
**While the following statement executes successfully, it throws an error (described below) when we use the imported packages:**
```py
from pydantic import BaseModel, Field, validator
```
Code: 
```python
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import (
    PromptTemplate,
)
from langchain_openai import OpenAI
from pydantic import BaseModel, Field, validator

# Define your desired data structure.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")

    # You can add custom validation logic easily with Pydantic.
    @validator("setup")
    def question_ends_with_question_mark(cls, field):
        if field[-1] != "?":
            raise ValueError("Badly formed question!")
        return field
```

Error:
```md
PydanticUserError: The `field` and `config` parameters are not available
in Pydantic V2, please use the `info` parameter instead.

For further information visit
https://errors.pydantic.dev/2.5/u/validator-field-config-info
```

Solution:
Instead of doing:
```py
from pydantic import BaseModel, Field, validator
```
We should do:
```py
from langchain_core.pydantic_v1 import BaseModel, Field, validator
```
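For completeness, a minimal corrected sketch of the model above (identical apart from the import):
```python
from langchain_core.pydantic_v1 import BaseModel, Field, validator


# Same model as in the snippet above; only the import line changes.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")

    @validator("setup")
    def question_ends_with_question_mark(cls, field: str) -> str:
        if not field.endswith("?"):
            raise ValueError("Badly formed question!")
        return field
```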
Thanks.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-01-27 16:46:48 -08:00
Rashedul Hasan Rijul
481493dbce community[patch]: apply embedding functions during query if defined (#16646)
**Description:** This update ensures that the user-defined embedding
function specified during vector store creation is applied during
queries. Previously, even if a custom embedding function was defined at
the time of store creation, Bagel DB would default to using the standard
embedding function during query execution. This pull request addresses
this issue by consistently using the user-defined embedding function for
queries if one has been specified earlier.
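A hedged sketch of the fixed behavior (cluster and parameter names are illustrative; requires `pip install betabageldb`):
```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Bagel

vectorstore = Bagel.from_texts(
    cluster_name="demo_cluster",  # hypothetical cluster name
    texts=["hello bagel", "hello langchain"],
    embedding=FakeEmbeddings(size=1536),  # the user-defined embedding function
)
# With this fix, the query below is embedded with the same user-defined
# function instead of Bagel's built-in default.
docs = vectorstore.similarity_search("hello", k=1)
```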
2024-01-27 16:46:33 -08:00
Serena Ruan
f01fb47597 community[patch]: MLflowCallbackHandler -- Move textstat and spacy as optional dependency (#16657)
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
2024-01-27 16:15:07 -08:00
Zhuoyun(John) Xu
508bde7f40 community[patch]: Ollama - Pass headers to post request in async method (#16660)
# Description
A previous PR (https://github.com/langchain-ai/langchain/pull/15881)
added an option to pass headers to the Ollama endpoint, but the headers
were not passed to the post request in the async method.
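A hedged sketch, assuming the `headers` field introduced in #15881 (endpoint and token are hypothetical):
```python
import asyncio

from langchain_community.llms import Ollama

llm = Ollama(
    base_url="https://my-ollama.example.com",  # hypothetical proxied endpoint
    model="llama2",
    headers={"Authorization": "Bearer my-token"},  # hypothetical auth header
)

# With this fix, the headers are also sent on the async code path.
print(asyncio.run(llm.ainvoke("Tell me a joke")))
```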
2024-01-27 16:11:32 -08:00
Leonid Ganeline
5e73603e8a docs: DeepInfra provider page update (#16665)
- added description, links
- consistent formatting
- added links to the example pages
2024-01-27 16:05:29 -08:00
João Carlos Ferra de Almeida
3e87b67a3c community[patch]: Add Cookie Support to Fetch Method (#16673)
- **Description:** This change allows the `_fetch` method in the
`WebBaseLoader` class to utilize cookies from an existing
`requests.Session`. It ensures that when the `fetch` method is used, any
cookies in the provided session are included in the request. This
enhancement maintains compatibility with existing functionality while
extending the utility of the `fetch` method for scenarios where cookie
persistence is necessary.
- **Issue:** Not applicable (new feature),
- **Dependencies:** Requires `aiohttp` and `requests` libraries (no new
dependencies introduced),
- **Twitter handle:** N/A
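A minimal sketch of the new behavior (URL and cookie are hypothetical):
```python
from langchain_community.document_loaders import WebBaseLoader

loader = WebBaseLoader("https://example.com/members-only")
# WebBaseLoader keeps a requests.Session; cookies set on it now ride along.
loader.session.cookies.set("sessionid", "abc123")  # hypothetical auth cookie

docs = loader.aload()  # the async fetch now includes the session's cookies
```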

Co-authored-by: Joao Almeida <joao.almeida@mercedes-benz.io>
2024-01-27 16:03:53 -08:00
Daniel Erenrich
c314137f5b docs: Fix broken link in CONTRIBUTING.md (#16681)
- **Description:** link in CONTRIBUTING.md is broken
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** @derenrich
2024-01-27 15:43:44 -08:00
Harrison Chase
27665e3546 [community] fix anthropic streaming (#16682) 2024-01-27 15:16:22 -08:00
Bagatur
5975bf39ec infra: delete old CI workflows (#16680) 2024-01-27 14:14:53 -08:00
Christophe Bornet
4915c3cd86 [Fix] Fix Cassandra Document loader default page content mapper (#16273)
We can't use `json.dumps` by default, as many types returned by the
Cassandra driver are not serializable. It's safer to use `str` and let
users define their own custom `page_content_mapper` if needed.
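A hedged sketch of opting back in to JSON page content via a custom mapper (assumes a Cassandra session is already configured, e.g. through cassio):
```python
import json

from langchain_community.document_loaders import CassandraLoader

loader = CassandraLoader(
    table="my_table",  # hypothetical table name
    page_content_mapper=lambda row: json.dumps(row._asdict()),
)
docs = loader.load()
```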
2024-01-27 11:23:02 -08:00
Nuno Campos
e86fd946c8 In stream_event and stream_log handle closed streams (#16661)
If, e.g., the stream iterator is interrupted, then adding more events to the
send_stream will raise an exception that we should catch (and handle
where appropriate).
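A minimal sketch of the consumer-side interruption this change tolerates:
```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    chain = RunnableLambda(lambda x: x + 1)
    async for _ in chain.astream_log(1):
        break  # closing the iterator early no longer crashes the producer


asyncio.run(main())
```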

2024-01-27 08:09:29 -08:00
Jarod Stewart
0bc397957b docs: document Ionic Tool (#16649)
- **Description:** Documentation for the Ionic Tool, a shopping
assistant tool that effortlessly adds e-commerce capabilities to your
Agent.
2024-01-26 16:02:07 -08:00
Nuno Campos
52ccae3fb1 Accept message-like things in Chat models, LLMs and MessagesPlaceholder (#16418) 2024-01-26 15:44:28 -08:00
Seungwoo Ryu
570b4f8e66 docs: Update openai_tools.ipynb (#16618)
typo
2024-01-26 15:26:27 -08:00
Pasha
4e189cd89a community[patch]: youtube loader transcript format (#16625)
- **Description**: YoutubeLoader right now returns one document that
contains the entire transcript. I think it would be useful to add an
option to return multiple documents, where each document would contain
one line of transcript with the start time and duration in the metadata.
For example,
[AssemblyAIAudioTranscriptLoader](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/document_loaders/assemblyai.py)
is implemented in a similar way: it lets you choose which format the
document loader uses.
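A minimal sketch of the new option (video URL is illustrative):
```python
from langchain_community.document_loaders import YoutubeLoader
from langchain_community.document_loaders.youtube import TranscriptFormat

loader = YoutubeLoader.from_youtube_url(
    "https://www.youtube.com/watch?v=QsYGlZkevEg",  # hypothetical video
    transcript_format=TranscriptFormat.LINES,
)
# One Document per transcript line, with start time and duration in metadata.
docs = loader.load()
```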
2024-01-26 15:26:09 -08:00
yin1991
a936472512 docs: Update documentation to use 'model_id' rather than 'model_name' to match actual API (#16615)
- **Description:** Replace 'model_name' with 'model_id' for accuracy 
- **Issue:**
[link-to-issue](https://github.com/langchain-ai/langchain/issues/16577)
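A minimal sketch matching the corrected docs (cluster setup taken from the docstring being fixed):
```python
import runhouse as rh

from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings

gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
hf = SelfHostedHuggingFaceEmbeddings(
    model_id="sentence-transformers/all-mpnet-base-v2",  # `model_id`, not `model_name`
    hardware=gpu,
)
```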
2024-01-26 15:01:12 -08:00
Micah Parker
6543e585a5 community[patch]: Added support for Ollama's num_predict option in ChatOllama (#16633)
Just a simple default addition to the options payload for an Ollama
generate call to support a `max_new_tokens`-style parameter.

Should fix issue: https://github.com/langchain-ai/langchain/issues/14715
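A minimal sketch: capping generation length through the new option:
```python
from langchain_community.chat_models import ChatOllama

chat = ChatOllama(model="llama2", num_predict=64)  # emit at most 64 new tokens
print(chat.invoke("Write a haiku about linting.").content)
```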
2024-01-26 15:00:19 -08:00
Callum
6a75ef74ca docs: Fix typo in XML agent documentation (#16645)
This is a tiny PR that just replaces "moduels" with "modules" in the
documentation for XML agents.
2024-01-26 14:59:46 -08:00
baichuan-assistant
70ff54eace community[minor]: Add Baichuan Text Embedding Model and Baichuan Inc introduction (#16568)
- **Description:** Adding the Baichuan Text Embedding Model and an
introduction to Baichuan Inc.

Baichuan Text Embedding ranks #1 on the C-MTEB leaderboard:
https://huggingface.co/spaces/mteb/leaderboard
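A minimal sketch (expects `BAICHUAN_API_KEY` in the environment, or pass `baichuan_api_key=` explicitly):
```python
from langchain_community.embeddings import BaichuanTextEmbeddings

embeddings = BaichuanTextEmbeddings()
vector = embeddings.embed_query("今天天气不错")
print(len(vector))  # 1024-dimensional embeddings
```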

Co-authored-by: BaiChuanHelper <wintergyc@WinterGYCs-MacBook-Pro.local>
2024-01-26 12:57:26 -08:00
Bagatur
5b5115c408 google-vertexai[patch]: streaming bug (#16603)
Fixes errors seen here
https://github.com/langchain-ai/langchain/actions/runs/7661680517/job/20881556592#step:9:229
2024-01-26 09:45:34 -08:00
64 changed files with 1573 additions and 412 deletions

View File

@@ -13,7 +13,7 @@ There are many ways to contribute to LangChain. Here are some common ways people
- [**Documentation**](https://python.langchain.com/docs/contributing/documentation): Help improve our docs, including this one!
- [**Code**](https://python.langchain.com/docs/contributing/code): Help us write code, fix bugs, or improve our infrastructure.
- [**Integrations**](https://python.langchain.com/docs/contributing/integration): Help us integrate with your favorite vendors and tools.
- [**Integrations**](https://python.langchain.com/docs/contributing/integrations): Help us integrate with your favorite vendors and tools.
### 🚩GitHub Issues

View File

@@ -1,13 +0,0 @@
---
name: libs/cli Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/cli
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: libs/community Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/community
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: libs/core Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/core
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: libs/experimental Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/experimental
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: Experimental Test Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_test_release.yml
    with:
      working-directory: libs/experimental
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: libs/core Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/core
    secrets: inherit

View File

@@ -1,27 +0,0 @@
---
name: libs/langchain Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_release.yml
    with:
      working-directory: libs/langchain
    secrets: inherit

  # N.B.: It's possible that PyPI doesn't make the new release visible / available
  # immediately after publishing. If that happens, the docker build might not
  # create a new docker image for the new release, since it won't see it.
  #
  # If this ends up being a problem, add a check to the end of the `_release.yml`
  # workflow that prevents the workflow from finishing until the new release
  # is visible and installable on PyPI.
  release-docker:
    needs:
      - release
    uses: ./.github/workflows/langchain_release_docker.yml
    secrets: inherit

View File

@@ -1,13 +0,0 @@
---
name: Test Release

on:
  workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses: ./.github/workflows/_test_release.yml
    with:
      working-directory: libs/langchain
    secrets: inherit

View File

@@ -0,0 +1,13 @@
# Baichuan

>[Baichuan Inc.](https://www.baichuan-ai.com/) is a Chinese startup in the era of AGI, dedicated to addressing fundamental human needs: Efficiency, Health, and Happiness.

## Visit Us

Visit us at https://www.baichuan-ai.com/.
Register and get an API key if you are trying out our APIs.

## Baichuan Chat Model

An example is available at [example](/docs/integrations/chat/baichuan).

## Baichuan Text Embedding Model

An example is available at [example](/docs/integrations/text_embedding/baichuan).

View File

@@ -1,45 +1,52 @@
# DeepInfra
This page covers how to use the DeepInfra ecosystem within LangChain.
>[DeepInfra](https://deepinfra.com/docs) allows us to run the
> [latest machine learning models](https://deepinfra.com/models) with ease.
> DeepInfra takes care of all the heavy lifting related to running, scaling and monitoring
> the models. Users can focus on their application and integrate the models with simple REST API calls.
>DeepInfra provides [examples](https://deepinfra.com/docs/advanced/langchain) of integration with LangChain.
This page covers how to use the `DeepInfra` ecosystem within `LangChain`.
It is broken into two parts: installation and setup, and then references to specific DeepInfra wrappers.
## Installation and Setup
- Get your DeepInfra api key from this link [here](https://deepinfra.com/).
- Get a DeepInfra API key and set it as an environment variable (`DEEPINFRA_API_TOKEN`)
## Available Models
DeepInfra provides a range of Open Source LLMs ready for deployment.
You can list supported models for
You can see supported models for
[text-generation](https://deepinfra.com/models?type=text-generation) and
[embeddings](https://deepinfra.com/models?type=embeddings).
google/flan\* models can be viewed [here](https://deepinfra.com/models?type=text2text-generation).
You can view a [list of request and response parameters](https://deepinfra.com/meta-llama/Llama-2-70b-chat-hf/api).
Chat models [follow the OpenAI API](https://deepinfra.com/meta-llama/Llama-2-70b-chat-hf/api?example=openai-http)
## Wrappers
### LLM
## LLM
There exists an DeepInfra LLM wrapper, which you can access with
See a [usage example](/docs/integrations/llms/deepinfra).
```python
from langchain_community.llms import DeepInfra
```
### Embeddings
## Embeddings
There is also an DeepInfra Embeddings wrapper, you can access with
See a [usage example](/docs/integrations/text_embedding/deepinfra).
```python
from langchain_community.embeddings import DeepInfraEmbeddings
```
### Chat Models
## Chat Models
There is a chat-oriented wrapper as well, accessible with
See a [usage example](/docs/integrations/chat/deepinfra).
```python
from langchain_community.chat_models import ChatDeepInfra

View File

@@ -0,0 +1,75 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Baichuan Text Embeddings\n",
"\n",
"As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB (Chinese Multi-Task Embedding Benchmark) leaderboard.\n",
"\n",
"Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard\n",
"\n",
"Official Website: https://platform.baichuan-ai.com/docs/text-Embedding\n",
"An API-key is required to use this embedding model. You can get one by registering at https://platform.baichuan-ai.com/docs/text-Embedding.\n",
"BaichuanTextEmbeddings support 512 token window and preduces vectors with 1024 dimensions. \n",
"\n",
"Please NOTE that BaichuanTextEmbeddings only supports Chinese text embedding. Multi-language support is coming soon.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_community.embeddings import BaichuanTextEmbeddings\n",
"\n",
"# Place your Baichuan API-key here.\n",
"embeddings = BaichuanTextEmbeddings(baichuan_api_key=\"sk-*\")\n",
"\n",
"text_1 = \"今天天气不错\"\n",
"text_2 = \"今天阳光很好\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text_1)\n",
"query_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text_1, text_2])\n",
"doc_result"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,160 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ionic\n",
">[Ionic](https://www.ioniccommerce.com/) stands at the forefront of commerce innovation, offering a suite of APIs that serve as the backbone for AI assistants and their developers. With Ionic, you unlock a new realm of possibility where convenience and intelligence converge, enabling users to navigate purchases with unprecedented ease. Experience the synergy of human desire and AI capability, all through Ionic's seamless integration.\n",
"\n",
"By including an `IonicTool` in the list of tools provided to an Agent, you are effortlessly adding e-commerce capabilities to your Agent. For more documetation on setting up your Agent with Ionic, see the [Ionic documentation](https://docs.ioniccommerce.com/guides/langchain).\n",
"\n",
"This Jupyter Notebook demonstrates how to use the `Ionic` tool with an Agent.\n",
"\n",
"First, let's install the `ionic-langchain` package.\n",
"**The `ionic-langchain` package is maintained by the Ionic team, not the LangChain maintainers.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"pip install ionic-langchain > /dev/null"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's create an `IonicTool` instance and initialize an Agent with the tool."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-24T17:33:11.755683Z",
"start_time": "2024-01-24T17:33:11.174044Z"
}
},
"outputs": [],
"source": [
"import os\n",
"\n",
"from ionic_langchain.tool import Ionic, IonicTool\n",
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, Tool, create_react_agent\n",
"from langchain_openai import OpenAI\n",
"\n",
"open_ai_key = os.environ[\"OPENAI_API_KEY\"]\n",
"\n",
"llm = OpenAI(openai_api_key=open_ai_key, temperature=0.5)\n",
"\n",
"tools: list[Tool] = [IonicTool().tool()]\n",
"\n",
"prompt = hub.pull(\"hwchase17/react\") # the example prompt for create_react_agent\n",
"\n",
"agent = create_react_agent(\n",
" llm,\n",
" tools,\n",
" prompt=prompt,\n",
")\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use the Agent to shop for products and get product information from Ionic."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-24T17:34:31.257036Z",
"start_time": "2024-01-24T17:33:45.849440Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m Since the user is looking for a specific product, we should use Ionic Commerce Shopping Tool to find and compare products.\n",
"Action: Ionic Commerce Shopping Tool\n",
"Action Input: 4K Monitor, 5, 100000, 1000000\u001B[0m\u001B[36;1m\u001B[1;3m[{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2a4d88fb366f5880b41eef03.png?odnHeight=100&odnWidth=100&odnBg=ffffff', 'brand_name': 'ASUS', 'upc': '192876749388'}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BHXNL922?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BHXNL922', 'name': 'LG Ultrafine™ OLED Monitor (27EQ850) 27 inch 4K UHD (3840 x 2160) OLED Pro Display with Adobe RGB 99%, DCI-P3 99%, 1M:1 Contrast Ratio, Hardware Calibration, Multi-Interface, USB Type-C™ (PD 90W)', 'price': '$1796.99', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41VEl4V2U4L._SL160_.jpg', 'brand_name': 'LG', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BZR81SQG?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BZR81SQG', 'name': 'ASUS ROG Swift 38” 4K HDMI 2.1 HDR DSC Gaming Monitor (PG38UQ) - UHD (3840 x 2160), 144Hz, 1ms, Fast IPS, G-SYNC Compatible, Speakers, FreeSync Premium Pro, DisplayPort, DisplayHDR600, 98% DCI-P3', 'price': '$1001.42', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41ULH0sb1zL._SL160_.jpg', 'brand_name': 'ASUS', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B0BBSV1LK5?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B0BBSV1LK5', 'name': 'ASUS ROG Swift 41.5\" 4K OLED 138Hz 0.1ms Gaming Monitor PG42UQ', 'price': '$1367.09', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/51ZM41brvHL._SL160_.jpg', 'brand_name': 'ASUS', 'upc': None}, {'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://www.amazon.com/dp/B07K8877Y5?tag=ioniccommer00-20&linkCode=osi&th=1&psc=1'}], 'merchant_name': 'Amazon', 'merchant_product_id': 'B07K8877Y5', 'name': 'LG 32UL950-W 32\" Class Ultrafine 4K UHD LED Monitor with Thunderbolt 3 Connectivity Silver (31.5\" Display)', 'price': '$1149.33', 'status': 'available', 'thumbnail': 'https://m.media-amazon.com/images/I/41Q2OE2NnDL._SL160_.jpg', 'brand_name': 'LG', 'upc': None}], 'query': {'query': '4K Monitor', 'max_price': 1000000, 'min_price': 100000, 'num_results': 5}}]\u001B[0m\u001B[32;1m\u001B[1;3m Since the results are in cents, we should convert them back to dollars before displaying the results to the user.\n",
"Action: Convert prices to dollars\n",
"Action Input: [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197\u001B[0mConvert prices to dollars is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m The results are in a list format, we should display them to the user in a more readable format.\n",
"Action: Display results in readable format\n",
"Action Input: [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197\u001B[0mDisplay results in readable format is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m We should check if the user is satisfied with the results or if they have additional requirements.\n",
"Action: Check user satisfaction\n",
"Action Input: None\u001B[0mCheck user satisfaction is not a valid tool, try one of [Ionic Commerce Shopping Tool].\u001B[32;1m\u001B[1;3m I now know the final answer\n",
"Final Answer: The final answer is [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": "{'input': \"I'm looking for a new 4K Monitor with 1000R under $1000\",\n 'output': \"The final answer is [{'products': [{'links': [{'text': 'Details', 'type': 'pdp', 'url': 'https://goto.walmart.com/c/123456/568844/9383?veh=aff&sourceid=imp_000011112222333344&u=https%3A%2F%2Fwww.walmart.com%2Fip%2F118806626'}], 'merchant_name': 'Walmart', 'merchant_product_id': '118806626', 'name': 'ASUS ProArt Display PA32UCX-PK 32” 4K HDR Mini LED Monitor, 99% DCI-P3 99.5% Adobe RGB, DeltaE<1, 10-bit, IPS, Thunderbolt 3 USB-C HDMI DP, Calman Ready, Dolby Vision, 1200nits, w/ X-rite Calibrator', 'price': '$2299.00', 'status': 'available', 'thumbnail': 'https://i5.walmartimages.com/asr/5ddc6e4a-5197-4f08-b505-83551b541de3.fd51cbae2\"}"
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input = \"I'm looking for a new 4K Monitor under $1000\"\n",
"\n",
"agent_executor.invoke({\"input\": input})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "f85209c3c4c190dca7367d6a1e623da50a9a4392fd53313a7cf9d4bda9c4b85b"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -167,9 +167,9 @@
],
"source": [
"%%time\n",
"URL = 'https://www.conseil-constitutionnel.fr/node/3850/pdf'\n",
"PDF = 'Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf'\n",
"open(PDF, 'wb').write(requests.get(URL).content)"
"URL = \"https://www.conseil-constitutionnel.fr/node/3850/pdf\"\n",
"PDF = \"Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf\"\n",
"open(PDF, \"wb\").write(requests.get(URL).content)"
]
},
{
@@ -208,7 +208,7 @@
],
"source": [
"%%time\n",
"print('Read a PDF...')\n",
"print(\"Read a PDF...\")\n",
"loader = PyPDFLoader(PDF)\n",
"pages = loader.load_and_split()\n",
"len(pages)"
@@ -252,12 +252,14 @@
],
"source": [
"%%time\n",
"print('Create a Vector Database from PDF text...')\n",
"embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')\n",
"print(\"Create a Vector Database from PDF text...\")\n",
"embeddings = OpenAIEmbeddings(model=\"text-embedding-ada-002\")\n",
"texts = [p.page_content for p in pages]\n",
"metadata = pd.DataFrame(index=list(range(len(texts))))\n",
"metadata['tag'] = 'law'\n",
"metadata['title'] = 'Déclaration des Droits de l\\'Homme et du Citoyen de 1789'.encode('utf-8')\n",
"metadata[\"tag\"] = \"law\"\n",
"metadata[\"title\"] = \"Déclaration des Droits de l'Homme et du Citoyen de 1789\".encode(\n",
" \"utf-8\"\n",
")\n",
"vectordb = KDBAI(table, embeddings)\n",
"vectordb.add_texts(texts=texts, metadatas=metadata)"
]
@@ -288,11 +290,13 @@
],
"source": [
"%%time\n",
"print('Create LangChain Pipeline...')\n",
"qabot = RetrievalQA.from_chain_type(chain_type='stuff',\n",
" llm=ChatOpenAI(model='gpt-3.5-turbo-16k', temperature=TEMP), \n",
" retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n",
" return_source_documents=True)"
"print(\"Create LangChain Pipeline...\")\n",
"qabot = RetrievalQA.from_chain_type(\n",
" chain_type=\"stuff\",\n",
" llm=ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=TEMP),\n",
" retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n",
" return_source_documents=True,\n",
")"
]
},
{
@@ -325,9 +329,9 @@
],
"source": [
"%%time\n",
"Q = 'Summarize the document in English:'\n",
"print(f'\\n\\n{Q}\\n')\n",
"print(qabot.invoke(dict(query=Q))['result'])"
"Q = \"Summarize the document in English:\"\n",
"print(f\"\\n\\n{Q}\\n\")\n",
"print(qabot.invoke(dict(query=Q))[\"result\"])"
]
},
{
@@ -362,9 +366,9 @@
],
"source": [
"%%time\n",
"Q = 'Is it a fair law and why ?'\n",
"print(f'\\n\\n{Q}\\n')\n",
"print(qabot.invoke(dict(query=Q))['result'])"
"Q = \"Is it a fair law and why ?\"\n",
"print(f\"\\n\\n{Q}\\n\")\n",
"print(qabot.invoke(dict(query=Q))[\"result\"])"
]
},
{
@@ -414,9 +418,9 @@
],
"source": [
"%%time\n",
"Q = 'What are the rights and duties of the man, the citizen and the society ?'\n",
"print(f'\\n\\n{Q}\\n')\n",
"print(qabot.invoke(dict(query=Q))['result'])"
"Q = \"What are the rights and duties of the man, the citizen and the society ?\"\n",
"print(f\"\\n\\n{Q}\\n\")\n",
"print(qabot.invoke(dict(query=Q))[\"result\"])"
]
},
{
@@ -441,9 +445,9 @@
],
"source": [
"%%time\n",
"Q = 'Is this law practical ?'\n",
"print(f'\\n\\n{Q}\\n')\n",
"print(qabot.invoke(dict(query=Q))['result'])"
"Q = \"Is this law practical ?\"\n",
"print(f\"\\n\\n{Q}\\n\")\n",
"print(qabot.invoke(dict(query=Q))[\"result\"])"
]
},
{

View File

@@ -19,7 +19,7 @@
"\n",
"Newer OpenAI models have been fine-tuned to detect when **one or more** function(s) should be called and respond with the inputs that should be passed to the function(s). In an API call, you can describe functions and have the model intelligently choose to output a JSON object containing arguments to call these functions. The goal of the OpenAI tools APIs is to more reliably return valid and useful function calls than what can be done using a generic text completion or chat API.\n",
"\n",
"OpenAI termed the capability to invoke a **single** function as **functions**, and the capability to invoke **one or more** funcitons as **tools**.\n",
"OpenAI termed the capability to invoke a **single** function as **functions**, and the capability to invoke **one or more** functions as **tools**.\n",
"\n",
":::tip\n",
"\n",

View File

@@ -23,7 +23,7 @@
"\n",
"* Use with regular LLMs, not with chat models.\n",
"* Use only with unstructured tools; i.e., tools that accept a single string input.\n",
"* See [AgentTypes](/docs/moduels/agents/agent_types/) documentation for more agent types.\n",
"* See [AgentTypes](/docs/modules/agents/agent_types/) documentation for more agent types.\n",
":::"
]
},

View File

@@ -430,7 +430,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "64650362",
"metadata": {},
"outputs": [
@@ -452,8 +452,8 @@
"from langchain.prompts import (\n",
" PromptTemplate,\n",
")\n",
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
"from langchain_openai import OpenAI\n",
"from pydantic import BaseModel, Field, validator\n",
"\n",
"\n",
"class Person(BaseModel):\n",
@@ -531,8 +531,8 @@
"from langchain.prompts import (\n",
" PromptTemplate,\n",
")\n",
"from langchain_core.pydantic_v1 import BaseModel, Field, validator\n",
"from langchain_openai import OpenAI\n",
"from pydantic import BaseModel, Field, validator\n",
"\n",
"\n",
"# Define your desired data structure.\n",

View File

@@ -94,15 +94,19 @@ def analyze_text(
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
textstat = import_textstat()
spacy = import_spacy()
text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
try:
textstat = import_textstat()
except ImportError:
pass
else:
text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
@@ -279,9 +283,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None:
"""Initialize callback handler."""
import_pandas()
import_textstat()
import_mlflow()
spacy = import_spacy()
super().__init__()
self.name = name
@@ -303,14 +305,19 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
self.action_records: list = []
self.nlp = None
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.nlp = None
spacy = import_spacy()
except ImportError:
pass
else:
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.metrics = {key: 0 for key in mlflow_callback_metrics()}

View File

@@ -142,9 +142,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta)
run_manager.on_llm_new_token(delta, chunk=chunk)
async def _astream(
self,
@@ -161,9 +162,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
stream_resp = await self.async_client.completions.create(**params, stream=True)
async for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
await run_manager.on_llm_new_token(delta)
await run_manager.on_llm_new_token(delta, chunk=chunk)
def _generate(
self,

View File

@@ -59,6 +59,7 @@ from langchain_community.document_loaders.blob_loaders import (
from langchain_community.document_loaders.blockchain import BlockchainDocumentLoader
from langchain_community.document_loaders.brave_search import BraveSearchLoader
from langchain_community.document_loaders.browserless import BrowserlessLoader
from langchain_community.document_loaders.cassandra import CassandraLoader
from langchain_community.document_loaders.chatgpt import ChatGPTLoader
from langchain_community.document_loaders.chromium import AsyncChromiumLoader
from langchain_community.document_loaders.college_confidential import (
@@ -267,6 +268,7 @@ __all__ = [
"BlockchainDocumentLoader",
"BraveSearchLoader",
"BrowserlessLoader",
"CassandraLoader",
"CSVLoader",
"ChatGPTLoader",
"CoNLLULoader",

View File

@@ -2,12 +2,24 @@ import json
import logging
import threading
from queue import Queue
from typing import Any, Callable, Dict, Iterator, List, Optional
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
)
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB
logger = logging.getLogger(__name__)
@@ -19,7 +31,8 @@ class AstraDBLoader(BaseLoader):
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
astra_db_client: Optional["AstraDB"] = None,
async_astra_db_client: Optional["AsyncAstraDB"] = None,
namespace: Optional[str] = None,
filter_criteria: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
@@ -36,34 +49,60 @@ class AstraDBLoader(BaseLoader):
)
# Conflicting-arg checks:
if astra_db_client is not None:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDB if passing 'token' and 'api_endpoint'."
)
self.collection_name = collection_name
self.filter = filter_criteria
self.projection = projection
self.find_options = find_options or {}
self.nb_prefetched = nb_prefetched
self.extraction_function = extraction_function
if astra_db_client is not None:
astra_db = astra_db_client
else:
astra_db = astra_db_client
async_astra_db = async_astra_db_client
if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
self.collection = astra_db.collection(collection_name)
try:
from astrapy.db import AsyncAstraDB
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
except (ImportError, ModuleNotFoundError):
pass
if not astra_db and not async_astra_db:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' "
"and 'api_endpoint'"
)
self.collection = astra_db.collection(collection_name) if astra_db else None
if async_astra_db:
from astrapy.db import AsyncAstraDBCollection
self.async_collection = AsyncAstraDBCollection(
astra_db=async_astra_db, collection_name=collection_name
)
else:
self.async_collection = None
def load(self) -> List[Document]:
"""Eagerly load the content."""
return list(self.lazy_load())
def lazy_load(self) -> Iterator[Document]:
if not self.collection:
raise ValueError("Missing AstraDB client")
queue = Queue(self.nb_prefetched)
t = threading.Thread(target=self.fetch_results, args=(queue,))
t.start()
@@ -74,6 +113,29 @@ class AstraDBLoader(BaseLoader):
yield doc
t.join()
async def aload(self) -> List[Document]:
"""Load data into Document objects."""
return [doc async for doc in self.alazy_load()]
async def alazy_load(self) -> AsyncIterator[Document]:
if not self.async_collection:
raise ValueError("Missing AsyncAstraDB client")
async for doc in self.async_collection.paginated_find(
filter=self.filter,
options=self.find_options,
projection=self.projection,
sort=None,
prefetched=True,
):
yield Document(
page_content=self.extraction_function(doc),
metadata={
"namespace": self.async_collection.astra_db.namespace,
"api_endpoint": self.async_collection.astra_db.base_url,
"collection": self.collection_name,
},
)
def fetch_results(self, queue: Queue):
self.fetch_page_result(queue)
while self.find_options.get("pageState"):

View File

@@ -2,10 +2,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
from typing import TYPE_CHECKING, Iterator, List, Optional
from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor
from langchain_community.document_loaders.blob_loaders import Blob
@@ -53,22 +52,14 @@ class BaseLoader(ABC):
# Attention: This method will be upgraded into an abstractmethod once it's
# implemented in all the existing subclasses.
def lazy_load(self) -> Iterator[Document]:
def lazy_load(
self,
) -> Iterator[Document]:
"""A lazy loader for Documents."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement lazy_load()"
)
async def alazy_load(self) -> AsyncIterator[Document]:
"""A lazy loader for Documents."""
iterator = await run_in_executor(None, self.lazy_load)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc
class BaseBlobParser(ABC):
"""Abstract interface for blob parsers.

View File

@@ -1,4 +1,3 @@
import json
from typing import (
TYPE_CHECKING,
Any,
@@ -14,13 +13,6 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
def default_page_content_mapper(row: Any) -> str:
if hasattr(row, "_asdict"):
return json.dumps(row._asdict())
return json.dumps(row)
_NOT_SET = object()
if TYPE_CHECKING:
@@ -36,7 +28,7 @@ class CassandraLoader(BaseLoader):
session: Optional["Session"] = None,
keyspace: Optional[str] = None,
query: Optional[Union[str, "Statement"]] = None,
page_content_mapper: Callable[[Any], str] = default_page_content_mapper,
page_content_mapper: Callable[[Any], str] = str,
metadata_mapper: Callable[[Any], dict] = lambda _: {},
*,
query_parameters: Union[dict, Sequence] = None,
@@ -61,6 +53,7 @@ class CassandraLoader(BaseLoader):
query: The query used to load the data.
(do not use together with the table parameter)
page_content_mapper: a function to convert a row to string page content.
Defaults to the str representation of the row.
query_parameters: The query parameters used when calling session.execute .
query_timeout: The query timeout used when calling session.execute .
query_custom_payload: The query custom_payload used when calling

View File

@@ -1,4 +1,4 @@
from typing import AsyncIterator, Iterator, List
from typing import Iterator, List
from langchain_core.documents import Document
@@ -26,9 +26,3 @@ class MergedDataLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load docs."""
return list(self.lazy_load())
async def alazy_load(self) -> AsyncIterator[Document]:
"""Lazy load docs from each individual loader."""
for loader in self.loaders:
async for document in loader.alazy_load():
yield document

View File

@@ -132,6 +132,7 @@ class WebBaseLoader(BaseLoader):
url,
headers=self.session.headers,
ssl=None if self.session.verify else False,
cookies=self.session.cookies.get_dict(),
) as response:
return await response.text()
except aiohttp.ClientConnectionError as e:

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
@@ -139,6 +140,11 @@ def _parse_video_id(url: str) -> Optional[str]:
return video_id
class TranscriptFormat(Enum):
TEXT = "text"
LINES = "lines"
class YoutubeLoader(BaseLoader):
"""Load `YouTube` transcripts."""
@@ -148,6 +154,7 @@ class YoutubeLoader(BaseLoader):
add_video_info: bool = False,
language: Union[str, Sequence[str]] = "en",
translation: Optional[str] = None,
transcript_format: TranscriptFormat = TranscriptFormat.TEXT,
continue_on_failure: bool = False,
):
"""Initialize with YouTube video ID."""
@@ -159,6 +166,7 @@ class YoutubeLoader(BaseLoader):
else:
self.language = language
self.translation = translation
self.transcript_format = transcript_format
self.continue_on_failure = continue_on_failure
@staticmethod
@@ -214,9 +222,19 @@ class YoutubeLoader(BaseLoader):
transcript_pieces = transcript.fetch()
transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
return [Document(page_content=transcript, metadata=metadata)]
if self.transcript_format == TranscriptFormat.TEXT:
transcript = " ".join([t["text"].strip(" ") for t in transcript_pieces])
return [Document(page_content=transcript, metadata=metadata)]
elif self.transcript_format == TranscriptFormat.LINES:
return [
Document(
page_content=t["text"].strip(" "),
metadata=dict((key, t[key]) for key in t if key != "text"),
)
for t in transcript_pieces
]
else:
raise ValueError("Unknown transcript format.")
def _get_video_info(self) -> dict:
"""Get important video information.

View File

@@ -20,6 +20,7 @@ from langchain_community.embeddings.aleph_alpha import (
)
from langchain_community.embeddings.awa import AwaEmbeddings
from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings
from langchain_community.embeddings.baidu_qianfan_endpoint import (
QianfanEmbeddingsEndpoint,
)
@@ -92,6 +93,7 @@ logger = logging.getLogger(__name__)
__all__ = [
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"BaichuanTextEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"DatabricksEmbeddings",

View File

@@ -0,0 +1,113 @@
from typing import Any, Dict, List, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
# BaichuanTextEmbeddings is an embedding model provided by Baichuan Inc. (https://www.baichuan-ai.com/home).
# As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 on the C-MTEB
# (Chinese Multi-Task Embedding Benchmark) leaderboard.
# Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard
# Official Website: https://platform.baichuan-ai.com/docs/text-Embedding
# An API key is required to use this embedding model. You can get one by registering
# at https://platform.baichuan-ai.com/docs/text-Embedding.
# BaichuanTextEmbeddings supports a 512-token window and produces vectors with
# 1024 dimensions.
# NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding.
# Multi-language support is coming soon.
class BaichuanTextEmbeddings(BaseModel, Embeddings):
"""Baichuan Text Embedding models."""
session: Any #: :meta private:
model_name: str = "Baichuan-Text-Embedding"
baichuan_api_key: Optional[SecretStr] = None
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that auth token exists in environment."""
try:
baichuan_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
)
except ValueError as original_exc:
try:
baichuan_api_key = convert_to_secret_str(
get_from_dict_or_env(
values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
)
)
except ValueError:
raise original_exc
session = requests.Session()
session.headers.update(
{
"Authorization": f"Bearer {baichuan_api_key.get_secret_value()}",
"Accept-Encoding": "identity",
"Content-type": "application/json",
}
)
values["session"] = session
return values
def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
"""Internal method to call Baichuan Embedding API and return embeddings.
Args:
texts: A list of texts to embed.
Returns:
A list of list of floats representing the embeddings, or None if an
error occurs.
"""
try:
response = self.session.post(
BAICHUAN_API_URL, json={"input": texts, "model": self.model_name}
)
# Check if the response status code indicates success
if response.status_code == 200:
resp = response.json()
embeddings = resp.get("data", [])
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0))
# Return just the embeddings
return [result.get("embedding", []) for result in sorted_embeddings]
else:
# Log error or handle unsuccessful response appropriately
print(
f"""Error: Received status code {response.status_code} from
embedding API"""
)
return None
except Exception as e:
# Log the exception or handle it as needed
print(f"Exception occurred while trying to get embeddings: {str(e)}")
return None
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:
"""Public method to get embeddings for a list of documents.
Args:
texts: The list of texts to embed.
Returns:
A list of embeddings, one for each text, or None if an error occurs.
"""
return self._embed(texts)
def embed_query(self, text: str) -> Optional[List[float]]:
"""Public method to get embedding for a single query text.
Args:
text: The text to embed.
Returns:
Embeddings for the text, or None if an error occurs.
"""
result = self._embed([text])
return result[0] if result is not None else None

View File

@@ -71,9 +71,9 @@ class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings):
from langchain_community.embeddings import SelfHostedHuggingFaceEmbeddings
import runhouse as rh
model_name = "sentence-transformers/all-mpnet-base-v2"
model_id = "sentence-transformers/all-mpnet-base-v2"
gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
hf = SelfHostedHuggingFaceEmbeddings(model_name=model_name, hardware=gpu)
hf = SelfHostedHuggingFaceEmbeddings(model_id=model_id, hardware=gpu)
"""
client: Any #: :meta private:

View File

@@ -64,6 +64,10 @@ class _OllamaCommon(BaseLanguageModel):
It is recommended to set this value to the number of physical
CPU cores your system has (as opposed to the logical number of cores)."""
num_predict: Optional[int] = None
"""Maximum number of tokens to predict when generating text.
(Default: 128, -1 = infinite generation, -2 = fill context)"""
repeat_last_n: Optional[int] = None
"""Sets how far back for the model to look back to prevent
repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
@@ -126,6 +130,7 @@ class _OllamaCommon(BaseLanguageModel):
"num_ctx": self.num_ctx,
"num_gpu": self.num_gpu,
"num_thread": self.num_thread,
"num_predict": self.num_predict,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
@@ -279,7 +284,10 @@ class _OllamaCommon(BaseLanguageModel):
async with aiohttp.ClientSession() as session:
async with session.post(
url=api_url,
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
json=request_payload,
timeout=self.timeout,
) as response:

View File

@@ -109,6 +109,12 @@ class Bagel(VectorStore):
import bagel # noqa: F401
except ImportError:
raise ImportError("Please install bagel `pip install betabageldb`.")
if self._embedding_function and query_embeddings is None and query_texts:
texts = list(query_texts)
query_embeddings = self._embedding_function.embed_documents(texts)
query_texts = None
return self._cluster.find(
query_texts=query_texts,
query_embeddings=query_embeddings,

View File

@@ -13,11 +13,18 @@ Required to run this test:
import json
import os
import uuid
from typing import TYPE_CHECKING
import pytest
from langchain_community.document_loaders.astradb import AstraDBLoader
if TYPE_CHECKING:
from astrapy.db import (
AstraDBCollection,
AsyncAstraDBCollection,
)
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")
@@ -28,7 +35,7 @@ def _has_env_vars() -> bool:
@pytest.fixture
def astra_db_collection():
def astra_db_collection() -> "AstraDBCollection":
from astrapy.db import AstraDB
astra_db = AstraDB(
@@ -38,21 +45,41 @@ def astra_db_collection():
)
collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
collection = astra_db.create_collection(collection_name)
collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
collection.insert_many(
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
)
yield collection
astra_db.delete_collection(collection_name)
@pytest.fixture
async def async_astra_db_collection() -> "AsyncAstraDBCollection":
from astrapy.db import AsyncAstraDB
astra_db = AsyncAstraDB(
token=ASTRA_DB_APPLICATION_TOKEN,
api_endpoint=ASTRA_DB_API_ENDPOINT,
namespace=ASTRA_DB_KEYSPACE,
)
collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
collection = await astra_db.create_collection(collection_name)
await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
await collection.insert_many(
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
)
yield collection
await astra_db.delete_collection(collection_name)
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDB:
def test_astradb_loader(self, astra_db_collection) -> None:
astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
astra_db_collection.insert_many(
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
)
def test_astradb_loader(self, astra_db_collection: "AstraDBCollection") -> None:
loader = AstraDBLoader(
astra_db_collection.collection_name,
token=ASTRA_DB_APPLICATION_TOKEN,
@@ -79,9 +106,9 @@ class TestAstraDB:
"collection": astra_db_collection.collection_name,
}
def test_extraction_function(self, astra_db_collection) -> None:
astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
def test_extraction_function(
self, astra_db_collection: "AstraDBCollection"
) -> None:
loader = AstraDBLoader(
astra_db_collection.collection_name,
token=ASTRA_DB_APPLICATION_TOKEN,
@@ -94,3 +121,51 @@ class TestAstraDB:
doc = next(docs)
assert doc.page_content == "bar"
async def test_astradb_loader_async(
self, async_astra_db_collection: "AsyncAstraDBCollection"
) -> None:
await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
await async_astra_db_collection.insert_many(
[{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
)
loader = AstraDBLoader(
async_astra_db_collection.collection_name,
token=ASTRA_DB_APPLICATION_TOKEN,
api_endpoint=ASTRA_DB_API_ENDPOINT,
namespace=ASTRA_DB_KEYSPACE,
nb_prefetched=1,
projection={"foo": 1},
find_options={"limit": 22},
filter_criteria={"foo": "bar"},
)
docs = await loader.aload()
assert len(docs) == 22
ids = set()
for doc in docs:
content = json.loads(doc.page_content)
assert content["foo"] == "bar"
assert "baz" not in content
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": async_astra_db_collection.astra_db.namespace,
"api_endpoint": async_astra_db_collection.astra_db.base_url,
"collection": async_astra_db_collection.collection_name,
}
async def test_extraction_function_async(
self, async_astra_db_collection: "AsyncAstraDBCollection"
) -> None:
loader = AstraDBLoader(
async_astra_db_collection.collection_name,
token=ASTRA_DB_APPLICATION_TOKEN,
api_endpoint=ASTRA_DB_API_ENDPOINT,
namespace=ASTRA_DB_KEYSPACE,
find_options={"limit": 30},
extraction_function=lambda x: x["foo"],
)
doc = await anext(loader.alazy_load())
assert doc.page_content == "bar"
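For reference, a minimal usage sketch of the async loader API exercised by these tests, assuming the same env vars and a pre-existing collection named "my_collection" (hypothetical):

```python
import asyncio
import os

from langchain_community.document_loaders.astradb import AstraDBLoader

async def main() -> None:
    loader = AstraDBLoader(
        "my_collection",  # hypothetical collection name
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        find_options={"limit": 5},
    )
    # Eager: collect everything at once.
    docs = await loader.aload()
    print(len(docs))
    # Lazy: stream documents one at a time.
    async for doc in loader.alazy_load():
        print(doc.page_content)
        break

asyncio.run(main())
```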

View File

@@ -59,11 +59,11 @@ def test_loader_table(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE)
assert loader.load() == [
Document(
page_content='{"row_id": "id1", "body_blob": "text1"}',
page_content="Row(row_id='id1', body_blob='text1')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
Document(
page_content='{"row_id": "id2", "body_blob": "text2"}',
page_content="Row(row_id='id2', body_blob='text2')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
]
@@ -74,8 +74,8 @@ def test_loader_query(keyspace: str) -> None:
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
)
assert loader.load() == [
Document(page_content='{"body_blob": "text1"}'),
Document(page_content='{"body_blob": "text2"}'),
Document(page_content="Row(body_blob='text1')"),
Document(page_content="Row(body_blob='text2')"),
]
@@ -103,7 +103,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
assert loader.load() == [
Document(
page_content='{"row_id": "id1", "body_blob": "text1"}',
page_content="Row(row_id='id1', body_blob='text1')",
metadata={
"table": CASSANDRA_TABLE,
"keyspace": keyspace,
@@ -111,7 +111,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
},
),
Document(
page_content='{"row_id": "id2", "body_blob": "text2"}',
page_content="Row(row_id='id2', body_blob='text2')",
metadata={
"table": CASSANDRA_TABLE,
"keyspace": keyspace,

View File

@@ -0,0 +1,19 @@
"""Test Baichuan Text Embedding."""
from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings
def test_baichuan_embedding_documents() -> None:
"""Test Baichuan Text Embedding for documents."""
documents = ["今天天气不错", "今天阳光灿烂"]
embedding = BaichuanTextEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024
def test_baichuan_embedding_query() -> None:
"""Test Baichuan Text Embedding for query."""
document = "所有的小学生都会学过只因兔同笼问题。"
embedding = BaichuanTextEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024

View File

@@ -1,9 +1,9 @@
"""Test Base Schema of documents."""
from typing import Iterator, List
from typing import Iterator
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.blob_loaders import Blob
@@ -27,25 +27,3 @@ def test_base_blob_parser() -> None:
docs = parser.parse(Blob(data="who?"))
assert len(docs) == 1
assert docs[0].page_content == "foo"
async def test_default_aload() -> None:
class FakeLoader(BaseLoader):
def load(self) -> List[Document]:
return list(self.lazy_load())
def lazy_load(self) -> Iterator[Document]:
yield from [
Document(page_content="foo"),
Document(page_content="bar")
]
loader = FakeLoader()
docs = loader.load()
assert docs == [Document(page_content="foo"), Document(page_content="bar")]
# Test that async lazy loading works
docs = [doc async for doc in loader.alazy_load()]
assert docs == [Document(page_content="foo"), Document(page_content="bar")]

View File

@@ -37,6 +37,7 @@ EXPECTED_ALL = [
"BlockchainDocumentLoader",
"BraveSearchLoader",
"BrowserlessLoader",
"CassandraLoader",
"CSVLoader",
"ChatGPTLoader",
"CoNLLULoader",

View File

@@ -3,6 +3,7 @@ from langchain_community.embeddings import __all__
EXPECTED_ALL = [
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"BaichuanTextEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"DatabricksEmbeddings",

View File

@@ -88,6 +88,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"num_ctx": None,
"num_gpu": None,
"num_thread": None,
"num_predict": None,
"repeat_last_n": None,
"repeat_penalty": None,
"stop": [],
@@ -133,6 +134,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"num_ctx": None,
"num_gpu": None,
"num_thread": None,
"num_predict": None,
"repeat_last_n": None,
"repeat_penalty": None,
"stop": [],

View File

@@ -16,7 +16,12 @@ from typing import (
from typing_extensions import TypeAlias
from langchain_core._api import deprecated
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)

View File

@@ -34,6 +34,7 @@ from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
HumanMessage,
convert_to_messages,
message_chunk_to_message,
)
from langchain_core.outputs import (
@@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=input)
return ChatPromptValue(messages=convert_to_messages(input))
else:
raise ValueError(
f"Invalid input type {type(input)}. "

View File

@@ -48,7 +48,12 @@ from langchain_core.callbacks import (
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_messages,
get_buffer_string,
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
@@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=input)
return ChatPromptValue(messages=convert_to_messages(input))
else:
raise ValueError(
f"Invalid input type {type(input)}. "

View File

@@ -1,4 +1,4 @@
from typing import List, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
@@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
)
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
def _create_message_from_message_type(
message_type: str,
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
Args:
message_type: str the type of the message (e.g., "human", "ai", etc.)
content: str the content string.
name: optional name for the message.
tool_call_id: optional tool call id (for tool messages).
**additional_kwargs: extra keyword arguments, stored on the message's additional_kwargs.
Returns:
a message of the appropriate type.
"""
kwargs: Dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
kwargs["tool_call_id"] = tool_call_id
if additional_kwargs:
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"):
message = AIMessage(content=content, **kwargs)
elif message_type == "system":
message = SystemMessage(content=content, **kwargs)
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
message = ToolMessage(content=content, **kwargs)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessage
- 2-tuple of (role string, content); e.g., ("human", "hello")
- dict: a message dict with role and content keys
- string: shorthand for ("human", content); e.g., "hello"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message
"""
if isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_message_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
_message = _create_message_from_message_type(message_type_str, template)
elif isinstance(message, dict):
msg_kwargs = message.copy()
try:
msg_type = msg_kwargs.pop("role")
msg_content = msg_kwargs.pop("content")
except KeyError:
raise ValueError(
f"Message dict must contain 'role' and 'content' keys, got {message}"
)
_message = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
messages: Sequence of messages to convert.
Returns:
List of messages (BaseMessages).
"""
return [_convert_to_message(m) for m in messages]
__all__ = [
"AIMessage",
"AIMessageChunk",
@@ -133,6 +237,7 @@ __all__ = [
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"convert_to_messages",
"get_buffer_string",
"message_chunk_to_message",
"messages_from_dict",

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence
from typing_extensions import TypedDict
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
@@ -82,6 +84,30 @@ class ChatPromptValue(PromptValue):
return ["langchain", "prompts", "chat"]
class ImageURL(TypedDict, total=False):
detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image."""
url: str
"""Either a URL of the image or the base64 encoded image data."""
class ImagePromptValue(PromptValue):
"""Image prompt value."""
image_url: ImageURL
"""Prompt image."""
type: Literal["ImagePromptValue"] = "ImagePromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.image_url["url"]
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=[self.image_url])]
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
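A quick sketch of the new `ImagePromptValue`, assuming it is importable from `langchain_core.prompt_values` as defined above:

```python
from langchain_core.prompt_values import ImagePromptValue

value = ImagePromptValue(
    image_url={"url": "https://www.langchain.com/image.png", "detail": "low"}
)
print(value.to_string())
# https://www.langchain.com/image.png
print(value.to_messages())
# [HumanMessage(content=[{'url': 'https://www.langchain.com/image.png', 'detail': 'low'}])]
```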

View File

@@ -8,10 +8,12 @@ from typing import (
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
@@ -30,7 +32,12 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@@ -142,7 +149,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
def format(self, **kwargs: Any) -> FormatOutputType:
"""Format the prompt with the inputs.
Args:
@@ -210,7 +217,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
@@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
Example:
.. code-block:: python
from langchain_core import Document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
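`BasePromptTemplate` is now generic in its `format()` return type, which is what lets `ImagePromptTemplate` (below) return an `ImageURL` dict instead of a string. A hedged sketch of a custom template with a non-string output (`ShoutTemplate` is hypothetical):

```python
from typing import Any, Dict, List

from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate

class ShoutTemplate(BasePromptTemplate[Dict[str, str]]):
    """Hypothetical template whose format() returns a dict, not a str."""

    input_variables: List[str] = ["text"]

    def format(self, **kwargs: Any) -> Dict[str, str]:
        # The generic parameter tracks this non-string return type.
        return {"text": kwargs["text"].upper()}

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        return StringPromptValue(text=self.format(**kwargs)["text"])

assert ShoutTemplate().format(text="hi") == {"text": "HI"}
```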

View File

@@ -13,8 +13,10 @@ from typing import (
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
@@ -27,12 +29,14 @@ from langchain_core.messages import (
ChatMessage,
HumanMessage,
SystemMessage,
convert_to_messages,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -126,7 +130,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in value:
for v in convert_to_messages(value):
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"
@@ -287,14 +291,153 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
)
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, Dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
_msg_class: Type[BaseMessage]
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@classmethod
def from_template(
cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
if isinstance(template, str):
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
template, template_format=template_format
)
return cls(prompt=prompt, **kwargs)
elif isinstance(template, list):
prompt = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
prompt.append(
PromptTemplate.from_template(
text, template_format=template_format
)
)
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
if vars:
if len(vars) > 1:
raise ValueError(
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f"\nFrom: {tmpl}"
)
input_variables = [vars[0]]
else:
input_variables = None
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
input_variables = get_template_variables(
img_template["url"], "f-string"
)
else:
input_variables = None
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
else:
raise ValueError(f"Invalid image template: {tmpl}")
prompt.append(img_template_obj)
else:
raise ValueError(f"Invalid template part: {tmpl}")
return cls(prompt=prompt, **kwargs)
else:
raise ValueError(f"Invalid template type: {type(template)}")
@classmethod
def from_template_file(
cls: Type[_StringImageMessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
return input_variables
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@@ -304,53 +447,55 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
if isinstance(self.prompt, StringPromptTemplate):
text = self.prompt.format(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
_msg_class: Type[BaseMessage] = HumanMessage
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
_msg_class: Type[BaseMessage] = AIMessage
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
_msg_class: Type[BaseMessage] = SystemMessage
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
@@ -404,8 +549,7 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]
@@ -737,7 +881,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type(
message_type: str, template: str
message_type: str, template: Union[str, list]
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
@@ -753,9 +897,9 @@ def _create_template_from_message_type(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
message = AIMessagePromptTemplate.from_template(cast(str, template))
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
message = SystemMessagePromptTemplate.from_template(cast(str, template))
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
@@ -798,7 +942,9 @@ def _convert_to_message(
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
_message = message_type_str(prompt=PromptTemplate.from_template(template))
_message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
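With `_StringImageMessagePromptTemplate`, `from_template` also accepts a list mixing text and `image_url` parts, and `format()` returns a single multipart message. A sketch (the template values are placeholders):

```python
from langchain_core.prompts.chat import HumanMessagePromptTemplate

tmpl = HumanMessagePromptTemplate.from_template(
    [
        {"text": "Describe {thing}:"},                    # text part
        {"image_url": "data:image/jpeg;base64,{img}"},    # image part
    ]
)
msg = tmpl.format(thing="this picture", img="abcd123")
print(msg.content)
# [{'type': 'text', 'text': 'Describe this picture:'},
#  {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,abcd123'}}]
```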

View File

@@ -0,0 +1,76 @@
from typing import Any
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import image as image_utils
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
"""An image prompt template for a multimodal model."""
template: dict = Field(default_factory=dict)
"""Template for the prompt."""
def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs:
kwargs["input_variables"] = []
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
if overlap:
raise ValueError(
"input_variables for the image template cannot contain"
" any of 'url', 'path', or 'detail'."
f" Found: {overlap}"
)
super().__init__(**kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "image-prompt"
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return ImagePromptValue(image_url=self.format(**kwargs))
def format(
self,
**kwargs: Any,
) -> ImageURL:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted ImageURL dictionary.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
formatted = {}
for k, v in self.template.items():
if isinstance(v, str):
formatted[k] = v.format(**kwargs)
else:
formatted[k] = v
url = kwargs.get("url") or formatted.get("url")
path = kwargs.get("path") or formatted.get("path")
detail = kwargs.get("detail") or formatted.get("detail")
if not url and not path:
raise ValueError("Must provide either url or path.")
if not url:
if not isinstance(path, str):
raise ValueError("path must be a string.")
url = image_utils.image_to_data_url(path)
if not isinstance(url, str):
raise ValueError("url must be a string.")
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail # type: ignore[typeddict-item]
return output
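A usage sketch for the new `ImagePromptTemplate` defined above:

```python
from langchain_core.prompts.image import ImagePromptTemplate

prompt = ImagePromptTemplate(
    input_variables=["my_image"],
    template={"url": "data:image/jpeg;base64,{my_image}", "detail": "low"},
)
print(prompt.format(my_image="abcd123"))
# {'url': 'data:image/jpeg;base64,abcd123', 'detail': 'low'}
print(prompt.format_prompt(my_image="abcd123").to_string())
# data:image/jpeg;base64,abcd123
```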

View File

@@ -20,7 +20,7 @@ from typing import (
from uuid import UUID
import jsonpatch # type: ignore[import]
from anyio import create_memory_object_stream
from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream
from typing_extensions import NotRequired, TypedDict
from langchain_core.load import dumps
@@ -223,6 +223,14 @@ class LogStreamCallbackHandler(BaseTracer):
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__()
def send(self, *ops: Dict[str, Any]) -> bool:
"""Send a patch to the stream, return False if the stream is closed."""
try:
self.send_stream.send_nowait(RunLogPatch(*ops))
return True
except (ClosedResourceError, BrokenResourceError):
return False
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
@@ -233,15 +241,14 @@ class LogStreamCallbackHandler(BaseTracer):
# if we can't find the run silently ignore
# eg. because this run wasn't included in the log
if key := self._key_map_by_run_id.get(run_id):
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{key}/streamed_output/-",
"value": chunk,
}
)
)
if not self.send(
{
"op": "add",
"path": f"/logs/{key}/streamed_output/-",
"value": chunk,
}
):
break
yield chunk
@@ -285,22 +292,21 @@ class LogStreamCallbackHandler(BaseTracer):
"""Start a run."""
if self.root_id is None:
self.root_id = run.id
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "replace",
"path": "",
"value": RunState(
id=str(run.id),
streamed_output=[],
final_output=None,
logs={},
name=run.name,
type=run.run_type,
),
}
)
)
if not self.send(
{
"op": "replace",
"path": "",
"value": RunState(
id=str(run.id),
streamed_output=[],
final_output=None,
logs={},
name=run.name,
type=run.run_type,
),
}
):
return
if not self.include_run(run):
return
@@ -331,14 +337,12 @@ class LogStreamCallbackHandler(BaseTracer):
entry["inputs"] = _get_standardized_inputs(run, self._schema_format)
# Add the run to the stream
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
"value": entry,
}
)
self.send(
{
"op": "add",
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
"value": entry,
}
)
def _on_run_update(self, run: Run) -> None:
@@ -382,7 +386,7 @@ class LogStreamCallbackHandler(BaseTracer):
]
)
self.send_stream.send_nowait(RunLogPatch(*ops))
self.send(*ops)
finally:
if run.id == self.root_id:
if self.auto_close:
@@ -400,21 +404,19 @@ class LogStreamCallbackHandler(BaseTracer):
if index is None:
return
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{index}/streamed_output_str/-",
"value": token,
},
{
"op": "add",
"path": f"/logs/{index}/streamed_output/-",
"value": chunk.message
if isinstance(chunk, ChatGenerationChunk)
else token,
},
)
self.send(
{
"op": "add",
"path": f"/logs/{index}/streamed_output_str/-",
"value": token,
},
{
"op": "add",
"path": f"/logs/{index}/streamed_output/-",
"value": chunk.message
if isinstance(chunk, ChatGenerationChunk)
else token,
},
)
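The new `send()` helper converts a closed or broken stream into a `False` return so producers can stop cleanly instead of raising. A standalone sketch of the same pattern with anyio memory streams:

```python
from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream

send_stream, receive_stream = create_memory_object_stream(max_buffer_size=16)

def send(item: dict) -> bool:
    """Return False instead of raising once the stream is closed."""
    try:
        send_stream.send_nowait(item)
        return True
    except (ClosedResourceError, BrokenResourceError):
        return False

assert send({"op": "add", "path": "/logs"}) is True
send_stream.close()
assert send({"op": "add", "path": "/logs"}) is False
```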

View File

@@ -4,6 +4,7 @@
These functions do not depend on any other LangChain module.
"""
from langchain_core.utils import image
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
@@ -41,6 +42,7 @@ __all__ = [
"xor_args",
"try_load_from_hub",
"build_extra_kwargs",
"image",
"get_from_env",
"get_from_dict_or_env",
"stringify_dict",

View File

@@ -0,0 +1,14 @@
import base64
import mimetypes
def encode_image(image_path: str) -> str:
"""Get base64 string from image URI."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_to_data_url(image_path: str) -> str:
encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}"

View File

@@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "generic-fake-chat-model"
class ParrotFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests
"""
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return ChatResult(generations=[ChatGeneration(message=messages[-1])])
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"

View File

@@ -5,8 +5,9 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:
@@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
AIMessageChunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == HumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")

View File

@@ -16,6 +16,7 @@ EXPECTED_ALL = [
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"convert_to_messages",
"get_buffer_string",
"message_chunk_to_message",
"messages_from_dict",

View File

@@ -3,6 +3,9 @@ from typing import Any, List, Union
import pytest
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
@@ -243,14 +246,15 @@ def test_chat_valid_infer_variables() -> None:
def test_chat_from_role_strings() -> None:
"""Test instantiation of chat template from role strings."""
template = ChatPromptTemplate.from_role_strings(
[
("system", "You are a bot."),
("assistant", "hello!"),
("human", "{question}"),
("other", "{quack}"),
]
)
with pytest.warns(LangChainPendingDeprecationWarning):
template = ChatPromptTemplate.from_role_strings(
[
("system", "You are a bot."),
("assistant", "hello!"),
("human", "{question}"),
("other", "{quack}"),
]
)
messages = template.format_messages(question="How are you?", quack="duck")
assert messages == [
@@ -363,9 +367,145 @@ def test_chat_message_partial() -> None:
assert template2.format(input="hello") == get_buffer_string(expected)
def test_chat_tmpl_from_messages_multipart_text() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
],
),
]
)
messages = template.format_messages(name="R2D2")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
]
),
]
assert messages == expected
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this {object_name}?"},
{"type": "text", "text": "Oh nvm"},
],
),
]
)
messages = template.format_messages(name="R2D2", object_name="image")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
]
),
]
assert messages == expected
def test_chat_tmpl_from_messages_multipart_image() -> None:
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {"url": "{my_other_image}", "detail": "medium"},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,foobar"},
},
],
),
]
)
messages = template.format_messages(
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
)
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{other_base64_image}"
},
},
{
"type": "image_url",
"image_url": {"url": f"{other_base64_image}"},
},
{
"type": "image_url",
"image_url": {
"url": f"{other_base64_image}",
"detail": "medium",
},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,foobar"},
},
]
),
]
assert messages == expected
def test_messages_placeholder() -> None:
prompt = MessagesPlaceholder("history")
with pytest.raises(KeyError):
prompt.format_messages()
prompt = MessagesPlaceholder("history", optional=True)
assert prompt.format_messages() == []
assert prompt.format_messages(
history=[("system", "You are an AI assistant."), "Hello!"]
) == [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello!"),
]

View File

@@ -14,6 +14,7 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
ToolMessage,
convert_to_messages,
get_buffer_string,
message_chunk_to_message,
messages_from_dict,
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
]
},
)
def test_convert_to_messages() -> None:
# dicts
assert convert_to_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "ai", "content": "Hi!"},
{"role": "human", "content": "Hello!", "name": "Jane"},
{
"role": "assistant",
"content": "Hi!",
"name": "JaneBot",
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
},
{"role": "function", "name": "greet", "content": "Hi!"},
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
]
) == [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello!"),
AIMessage(content="Hi!"),
HumanMessage(content="Hello!", name="Jane"),
AIMessage(
content="Hi!",
name="JaneBot",
additional_kwargs={
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
},
),
FunctionMessage(name="greet", content="Hi!"),
ToolMessage(tool_call_id="tool_id", content="Hi!"),
]
# tuples
assert convert_to_messages(
[
("system", "You are a helpful assistant."),
"hello!",
("ai", "Hi!"),
("human", "Hello!"),
("assistant", "Hi!"),
]
) == [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="hello!"),
AIMessage(content="Hi!"),
HumanMessage(content="Hello!"),
AIMessage(content="Hi!"),
]

View File

@@ -16,6 +16,7 @@ EXPECTED_ALL = [
"xor_args",
"try_load_from_hub",
"build_extra_kwargs",
"image",
"get_from_dict_or_env",
"get_from_env",
"stringify_dict",

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
from typing import Dict
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike
@@ -48,13 +51,31 @@ def create_history_aware_retriever(
chain.invoke({"input": "...", "chat_history": })
"""
if "input" not in prompt.input_variables:
input_vars = prompt.input_variables
if "input" not in input_vars and "messages" not in input_vars:
raise ValueError(
"Expected `input` to be a prompt variable, "
f"but got {prompt.input_variables}"
"Expected either `input` or `messages` to be prompt variables, "
f"but got {input_vars}"
)
def messages_param_is_message_list(x: Dict) -> bool:
return (
isinstance(x.get("messages", []), list)
and len(x.get("messages", [])) > 0
and all(isinstance(i, BaseMessage) for i in x.get("messages", []))
)
retrieve_documents: RetrieverOutputLike = RunnableBranch(
(
lambda x: messages_param_is_message_list(x)
and len(x.get("messages", [])) > 1,
prompt | llm | StrOutputParser() | retriever,
),
(
lambda x: messages_param_is_message_list(x)
and len(x.get("messages", [])) == 1,
(lambda x: x["messages"][-1].content) | retriever,
),
(
# Both empty string and empty list evaluate to False
lambda x: not x.get("chat_history", False),

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, Union
from langchain_core.messages import BaseMessage
from langchain_core.retrievers import (
BaseRetriever,
RetrieverOutput,
@@ -55,10 +56,30 @@ def create_retrieval_chain(
chain.invoke({"input": "..."})
"""
def messages_param_is_message_list(x: Dict) -> bool:
return (
isinstance(x.get("messages", []), list)
and len(x.get("messages", [])) > 0
and all(isinstance(i, BaseMessage) for i in x.get("messages", []))
)
def extract_retriever_input_string(x: Dict) -> str:
if not x.get("input"):
if messages_param_is_message_list(x):
return x["messages"][-1].content
else:
raise ValueError(
"If `input` is not provided, "
"the `messages` parameter must be a list of messages."
)
else:
return x["input"]
if not isinstance(retriever, BaseRetriever):
retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
else:
retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_docs = extract_retriever_input_string | retriever
retrieval_chain = (
RunnablePassthrough.assign(

View File

@@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
async def aindex(
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vector_store: VectorStore,
*,
@@ -469,17 +469,16 @@ async def aindex(
# implementation which just raises a NotImplementedError
raise ValueError("Vectorstore has not implemented the delete method")
async_doc_iterator: AsyncIterator[Document]
if isinstance(docs_source, BaseLoader):
try:
async_doc_iterator = docs_source.alazy_load()
except NotImplementedError:
async_doc_iterator = _to_async_iterator(await docs_source.aload())
raise NotImplementedError(
"Not supported yet. Please pass an async iterator of documents."
)
async_doc_iterator: AsyncIterator[Document]
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
else:
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
else:
async_doc_iterator = _to_async_iterator(docs_source)
async_doc_iterator = _to_async_iterator(docs_source)
source_id_assigner = _get_source_id_assigner(source_id_key)
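`aindex` now expects an (async) iterator of documents rather than a `BaseLoader`. A minimal sketch of the wrapping dispatch (`as_async_iterator` is a hypothetical mirror of the internal logic):

```python
import asyncio
from typing import AsyncIterator, Iterable, Union

from langchain_core.documents import Document

async def _to_async(it: Iterable[Document]) -> AsyncIterator[Document]:
    for doc in it:
        yield doc

def as_async_iterator(
    src: Union[Iterable[Document], AsyncIterator[Document]],
) -> AsyncIterator[Document]:
    # Pass async iterators through; wrap plain iterables.
    return src if hasattr(src, "__aiter__") else _to_async(src)

async def main() -> None:
    docs = [Document(page_content="a"), Document(page_content="b")]
    out = [d.page_content async for d in as_async_iterator(docs)]
    assert out == ["a", "b"]

asyncio.run(main())
```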

View File

@@ -1,7 +1,12 @@
"""Test conversation chain and memory."""
from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
)
from langchain.chains import create_retrieval_chain
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever
@@ -22,3 +27,31 @@ def test_create() -> None:
}
output = chain.invoke({"input": "What is the answer?", "chat_history": "foo"})
assert output == expected_output
def test_create_with_chat_history_messages_only() -> None:
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = FakeParrotRetriever()
question_gen_prompt = ChatPromptTemplate.from_messages(
[
MessagesPlaceholder(variable_name="messages"),
]
)
chain = create_retrieval_chain(retriever, question_gen_prompt | llm)
expected_output = {
"answer": "I know the answer!",
"messages": [
HumanMessage(content="What is the answer?"),
],
"context": [Document(page_content="What is the answer?")],
}
output = chain.invoke(
{
"messages": [
HumanMessage(content="What is the answer?"),
],
}
)
assert output == expected_output

View File

@@ -40,6 +40,19 @@ class ToyLoader(BaseLoader):
"""Load the documents from the source."""
return list(self.lazy_load())
async def alazy_load(
self,
) -> AsyncIterator[Document]:
async def async_generator() -> AsyncIterator[Document]:
for document in self.documents:
yield document
return async_generator()
async def aload(self) -> List[Document]:
"""Load the documents from the source."""
return [doc async for doc in await self.alazy_load()]
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary."""
@@ -219,7 +232,7 @@ async def test_aindexing_same_content(
]
)
assert await aindex(loader, arecord_manager, vector_store) == {
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -230,7 +243,9 @@ async def test_aindexing_same_content(
for _ in range(2):
# Run the indexing again
assert await aindex(loader, arecord_manager, vector_store) == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -332,7 +347,9 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -342,7 +359,9 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -363,7 +382,9 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 1,
"num_deleted": 1,
"num_skipped": 1,
@@ -381,7 +402,9 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -450,7 +473,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
loader,
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -459,7 +482,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
loader,
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -570,7 +593,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
loader,
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
@@ -587,7 +610,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
loader,
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
@@ -617,7 +640,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
loader,
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
@@ -756,7 +779,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
loader.lazy_load(),
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -780,7 +803,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
loader.lazy_load(),
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -815,7 +838,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
):
assert await aindex(
loader.lazy_load(),
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -860,7 +883,9 @@ async def test_aindexing_with_no_docs(
"""Check edge case when loader returns no new docs."""
loader = ToyLoader(documents=[])
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 0,

View File

@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
def get_generation_info(
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
candidate: Union[TextGenerationResponse, Candidate],
is_gemini: bool,
*,
stream: bool = False,
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
info = {
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
"blocked": rating.blocked,
}
for rating in candidate.safety_ratings
],
"citation_metadata": candidate.citation_metadata,
}
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
candidate_dc = dataclasses.asdict(candidate)
candidate_dc.pop("text")
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
else:
info = dataclasses.asdict(candidate)
info.pop("text")
info = {k: v for k, v in info.items() if not k.startswith("_")}
if stream:
# Remove non-streamable types, like bools.
info.pop("is_blocked")
return info

View File

@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
return result.total_tokens
def _response_to_generation(
self, response: TextGenerationResponse
self, response: TextGenerationResponse, *, stream: bool = False
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
generation_info = get_generation_info(response, self._is_gemini_model)
generation_info = get_generation_info(
response, self._is_gemini_model, stream=stream
)
try:
text = response.text
except AttributeError:
@@ -401,7 +403,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
run_manager=run_manager,
**params,
):
chunk = self._response_to_generation(stream_resp)
# Gemini models return GenerationResponse even when streaming, which has a
# candidates field.
stream_resp = (
stream_resp
if isinstance(stream_resp, TextGenerationResponse)
else stream_resp.candidates[0]
)
chunk = self._response_to_generation(stream_resp, stream=True)
yield chunk
if run_manager:
run_manager.on_llm_new_token(

View File

@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
"model_name",
model_names_to_test_with_default,
)
def test_vertex_call(model_name: str) -> None:
def test_vertex_invoke(model_name: str) -> None:
llm = (
VertexAI(model_name=model_name, temperature=0)
if model_name
else VertexAI(temperature=0.0)
)
output = llm("Say foo:")
output = llm.invoke("Say foo:")
assert isinstance(output, str)
@pytest.mark.parametrize(
"model_name",
model_names_to_test_with_default,
)
def test_vertex_generate(model_name: str) -> None:
llm = (
VertexAI(model_name=model_name, temperature=0)
if model_name
else VertexAI(temperature=0.0)
)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
def test_vertex_generate() -> None:
def test_vertex_generate_multiple_candidates() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)