Compare commits

...

80 Commits

Author SHA1 Message Date
Erick Friis
63f380d349 ignore ones 2023-11-22 14:47:04 -05:00
Erick Friis
9c55d65606 document transformers gone 2023-11-22 14:32:33 -05:00
Bagatur
fabebb2042 more 2023-11-22 11:31:51 -08:00
Bagatur
b2b8122249 Merge remote-tracking branch 'origin/erick/core-namespace-same' into bagatur/fix_core_namespace 2023-11-22 11:20:07 -08:00
Bagatur
5c6973a1b5 more 2023-11-22 11:19:41 -08:00
Bagatur
4449481cd3 BUGFIX: backwards compatible core namespacing 2023-11-22 11:10:27 -08:00
Erick Friis
70f6be32e0 number 2023-11-22 13:44:10 -05:00
Erick Friis
1548f32f3d lint 2023-11-22 13:42:59 -05:00
Erick Friis
6c59db482b unit test for core 2023-11-22 13:38:00 -05:00
Bagatur
32d087fcb8 REFACTOR: combine core documents files (#13733) 2023-11-22 10:10:26 -08:00
h3l
14d4fb98fc DOCS: Fix typo/line break in python code (#13708) 2023-11-22 09:10:07 -08:00
William FH
5b90fe5b1c Fix locking (#13725) 2023-11-22 07:37:25 -08:00
Bagatur
16af282429 BUGFIX: add prompt imports for backwards compat (#13702) 2023-11-21 23:04:20 -08:00
Erick Friis
78da34153e TEMPLATES Metadata (#13691)
Co-authored-by: Lance Martin <lance@langchain.dev>
2023-11-22 01:41:12 -05:00
Bagatur
e327bb4ba4 IMPROVEMENT: Conditionally import core type hints (#13700) 2023-11-21 21:38:49 -08:00
dandanwei
d47ee1ae79 BUGFIX: redis vector store overwrites falsey metadata (#13652)
- **Description:** This commit fixed the problem that Redis vector store
will change the value of a metadata from 0 to empty when saving the
document, which should be an un-intended behavior.
  - **Issue:** N/A
  - **Dependencies:** N/A
2023-11-21 20:16:23 -08:00
Bagatur
a21e84faf7 BUGFIX: llm backwards compat imports (#13698) 2023-11-21 20:12:35 -08:00
Yujie Qian
ace9e64d62 IMPROVEMENT: VoyageEmbeddings embed_general_texts (#13620)
- **Description:** add method embed_general_texts in VoyageEmebddings to
support input_type
  - **Issue:** 
  - **Dependencies:** 
  - **Tag maintainer:** 
  - **Twitter handle:** @Voyage_AI_
2023-11-21 18:33:07 -08:00
tanujtiwari-at
5064890fcf BUGFIX: handle tool message type when converting to string (#13626)
**Description:** Currently, if we pass in a ToolMessage back to the
chain, it crashes with error

`Got unsupported message type: `

This fixes it. 

Tested locally

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-21 18:20:58 -08:00
Josep Pon Farreny
143049c90f Added partial_variables to BaseStringMessagePromptTemplate.from_template(...) (#13645)
**Description:** BaseStringMessagePromptTemplate.from_template was
passing the value of partial_variables into cls(...) via **kwargs,
rather than passing it to PromptTemplate.from_template. Which resulted
in those *partial_variables being* lost and becoming required
*input_variables*.

Co-authored-by: Josep Pon Farreny <josep.pon-farreny@siemens.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-21 17:48:38 -08:00
Erick Friis
c5ae9f832d INFRA: Lint for imports (#13632)
- Adds pydantic/import linting to core
- Adds a check for `langchain_experimental` imports to langchain
2023-11-21 17:42:56 -08:00
Erick Friis
131db4ba68 BUGFIX: anthropic models on bedrock (#13629)
Introduced in #13403
2023-11-21 17:40:29 -08:00
David Ruan
04bddbaba4 BUGFIX: Update bedrock.py to fix provider bug (#13646)
Provider check was incorrectly failing for anything other than "meta"
2023-11-21 17:28:38 -08:00
Guangya Liu
aec8715073 DOCS: remove openai api key from cookbook (#13633) 2023-11-21 17:25:06 -08:00
Guangya Liu
bb18b0266e DOCS: fixed import error for BashOutputParser (#13680) 2023-11-21 16:33:40 -08:00
Bagatur
dc53523837 IMPROVEMENT: bump core dep 0.0.3 (#13690) 2023-11-21 15:50:19 -08:00
Bagatur
a208abe6b7 add callback import test (#13689) 2023-11-21 15:28:49 -08:00
Bagatur
083afba697 BUG: Add core utils imports (#13688) 2023-11-21 15:25:47 -08:00
Bagatur
c61e30632e BUG: more core fixes (#13665)
Fix some circular deps:
- move PromptValue into top level module bc both PromptTemplates and
OutputParsers import
- move tracer context vars to `tracers.context` and import them in
functions in `callbacks.manager`
- add core import tests
2023-11-21 15:15:48 -08:00
William FH
59df16ab92 Update name (#13676) 2023-11-21 13:39:30 -08:00
Erick Friis
bfb980b968 CLI 0.0.19 (#13677) 2023-11-21 12:34:38 -08:00
Taqi Jaffri
d65c36d60a docugami cookbook (#13183)
Adds a cookbook for semi-structured RAG via Docugami. This follows the
same outline as the semi-structured RAG with Unstructured cookbook:
https://github.com/langchain-ai/langchain/blob/master/cookbook/Semi_Structured_RAG.ipynb

The main change is this cookbook uses Docugami instead of Unstructured
to find text and tables, and shows how XML markup in the output helps
with retrieval and generation.

We are \@docugami on twitter, I am \@tjaffri

---------

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
2023-11-21 12:02:20 -08:00
jakerachleff
249c796785 update langserve to v0.0.30 (#13673)
Upgrade langserve template version to 0.0.30 to include new improvements
2023-11-21 11:17:47 -08:00
jakerachleff
c6937a2eb4 fix templates dockerfile (#13672)
- **Description:** We need to update the Dockerfile for templates to
also copy your README.md. This is because poetry requires that a readme
exists if it is specified in the pyproject.toml
2023-11-21 11:09:55 -08:00
Bagatur
11614700a4 bump 0.0.339rc0 (#13664) 2023-11-21 08:41:59 -08:00
Bagatur
d32e511826 REFACTOR: Refactor langchain_core (#13627)
Changes:
- remove langchain_core/schema since no clear distinction b/n schema and
non-schema modules
- make every module that doesn't end in -y plural
- where easy have 1-2 classes per file
- no more than one level of nesting in directories
- only import from top level core modules in langchain
2023-11-21 08:35:29 -08:00
William FH
17c6551c18 Add error rate (#13568)
To the in-memory outputs. Separate it out from the outputs so it's
present in the dataframe.describe() results
2023-11-21 07:51:30 -08:00
Nuno Campos
8329f81072 Use pytest asyncio auto mode (#13643)
<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
2023-11-21 15:00:13 +00:00
Lance Martin
611e1e0ca4 Add template for gpt-crawler (#13625)
Template for RAG using
[gpt-crawler](https://github.com/BuilderIO/gpt-crawler).

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-20 21:32:57 -08:00
Bagatur
99b4f46cbe REFACTOR: Add core as dep (#13623) 2023-11-20 14:38:10 -08:00
Harrison Chase
d82cbf5e76 Separate out langchain_core package (#13577)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-20 13:09:30 -08:00
Bagatur
4eec47b191 DOCS: update rag use case images (#13615) 2023-11-20 10:14:52 -08:00
Bagatur
e620347a83 RELEASE: bump 339 (#13613) 2023-11-20 09:56:43 -08:00
Ofer Mendelevitch
52e23e50b1 BUG: Fix search_kwargs in Vectara retriever (#13299)
- **Description:** fix a bug that prevented as_retriever() in Vectara to
use the desired input arguments
  - **Issue:** as_retriever did not pass the arguments properly
  - **Tag maintainer:** @baskaryan
  - **Twitter handle:** @ofermend
2023-11-20 09:44:43 -08:00
Holt Skinner
1c08dbfb33 IMPROVEMENT: Reduce post-processing time for DocAIParser (#13210)
- Remove `WrappedDocument` introduced in
https://github.com/langchain-ai/langchain/pull/11413
- https://github.com/googleapis/python-documentai-toolbox/issues/198 in
Document AI Toolbox to improve initialization time for `WrappedDocument`
object.

@lkuligin

@baskaryan

@hwchase17
2023-11-20 09:41:44 -08:00
Leonid Kuligin
f3fcdea574 fixed an UnboundLocalError when no documents are found (#12995)
Replace this entire comment with:
  - **Description:** fixed a bug
  - **Issue:** the issue # #12780
2023-11-20 09:41:14 -08:00
Stijn Tratsaert
b6f70d776b VertexAI LLM count_tokens method requires list of prompts (#13451)
I encountered this during summarization with VertexAI. I was receiving
an INVALID_ARGUMENT error, as it was trying to send a list of about
17000 single characters.

The [count_tokens
method](https://github.com/googleapis/python-aiplatform/blob/main/vertexai/language_models/_language_models.py#L658)
made available by Google takes in a list of prompts. It does not fail
for small texts, but it does for longer documents because the argument
list will be exceeding Googles allowed limit. Enforcing the list type
makes it work successfully.

This change will cast the input text to count to a list of that single
text so that the input format is always correct.

[Twitter](https://www.x.com/stijn_tratsaert)
2023-11-20 09:40:48 -08:00
Wang Wei
fe7b40cb2a feat: add ERNIE-Bot-4 Function Calling (#13320)
- **Description:** ERNIE-Bot-Chat-4 Large Language Model adds the
ability of `Function Calling` by passing parameters through the
`functions` parameter in the request. To simplify function calling for
ERNIE-Bot-Chat-4, the `create_ernie_fn_chain()` function has been added.
The definition and usage of the `create_ernie_fn_chain()` function is
similar to that of the `create_openai_fn_chain()` function.

Examples as the follows:

```
import json

from langchain.chains.ernie_functions import (
    create_ernie_fn_chain,
)
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate

def get_current_news(location: str) -> str:
    """Get the current news based on the location.'

    Args:
        location (str): The location to query.
    
    Returs:
        str: Current news based on the location.
    """

    news_info = {
        "location": location,
        "news": [
            "I have a Book.",
            "It's a nice day, today."
        ]
    }

    return json.dumps(news_info)

def get_current_weather(location: str, unit: str="celsius") -> str:
    """Get the current weather in a given location

    Args:
        location (str): location of the weather.
        unit (str): unit of the tempuature.
    
    Returns:
        str: weather in the given location.
    """

    weather_info = {
        "location": location,
        "temperature": "27",
        "unit": unit,
        "forecast": ["sunny", "windy"],
    }
    return json.dumps(weather_info)

llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{query}"),
    ]
)

chain = create_ernie_fn_chain([get_current_weather, get_current_news], llm, prompt, verbose=True)
res = chain.run("北京今天的新闻是什么?")
print(res)
```

The running results of the above program are shown below:
```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 北京今天的新闻是什么?



> Finished chain.
{'name': 'get_current_news', 'thoughts': '用户想要知道北京今天的新闻。我可以使用get_current_news工具来获取这些信息。', 'arguments': {'location': '北京'}}
```
2023-11-19 22:36:12 -08:00
Adilkhan Sarsen
10418ab0c1 DeepLake Backwards compatibility fix (#13388)
- **Description:** during search with DeepLake some people are facing
backwards compatibility issues, this PR fixes it by making search
accessible for the older datasets

---------

Co-authored-by: adolkhan <adilkhan.sarsen@alumni.nu.edu.kz>
2023-11-19 21:46:01 -08:00
Tyler Hutcherson
190952fe76 IMPROVEMENT: Minor redis improvements (#13381)
- **Description:**
- Fixes a `key_prefix` bug where passing it in on
`Redis.from_existing(...)` did not work properly. Updates doc strings
accordingly.
- Updates Redis filter classes logic with best practices on typing,
string formatting, and handling "empty" filters.
- Fixes a bug that would prevent multiple tag filters from being applied
together in some scenarios.
- Added a whole new filter unit testing module. Also updated code
formatting for a number of modules that were failing the `make`
commands.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Tag maintainer:** @baskaryan 
  - **Twitter handle:** @tchutch94
2023-11-19 19:15:45 -08:00
Sijun He
674bd90a47 DOCS: Fix typo in MongoDB memory docs (#13588)
- **Description:** Fix typo in MongoDB memory docs
  - **Tag maintainer:** @eyurtsev

<!-- Thank you for contributing to LangChain!

  - **Description:** Fix typo in MongoDB memory docs
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
  - **Tag maintainer:** @baskaryan
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
2023-11-19 19:13:35 -08:00
Sergey Kozlov
df03267edf Fix tool arguments formatting in StructuredChatAgent (#10480)
In the `FORMAT_INSTRUCTIONS` template, 4 curly braces (escaping) are
used to get single curly brace after formatting:

```
"{{{ ... }}}}" -> format_instructions.format() ->  "{{ ... }}" -> template.format() -> "{ ... }".
```

Tool's `args_schema` string contains single braces `{ ... }`, and is
also transformed to `{{{{ ... }}}}` form. But this is not really correct
since there is only one `format()` call:

```
"{{{{ ... }}}}" -> template.format() -> "{{ ... }}".
```

As a result we get double curly braces in the prompt:
````
Respond to the human as helpfully and accurately as possible. You have access to the following tools:

foo: Test tool FOO, args: {{'tool_input': {{'type': 'string'}}}}    # <--- !!!
...
Provide only ONE action per $JSON_BLOB, as shown:

```
{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}
```
````

This PR fixes curly braces escaping in the `args_schema` to have single
braces in the final prompt:
````
Respond to the human as helpfully and accurately as possible. You have access to the following tools:

foo: Test tool FOO, args: {'tool_input': {'type': 'string'}}    # <--- !!!
...
Provide only ONE action per $JSON_BLOB, as shown:

```
{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}
```
````

---------

Co-authored-by: Sergey Kozlov <sergey.kozlov@ludditelabs.io>
2023-11-19 18:45:43 -08:00
Wouter Durnez
ef7802b325 Add llama2-13b-chat-v1 support to chat_models.BedrockChat (#13403)
Hi 👋 We are working with Llama2 on Bedrock, and would like to add it to
Langchain. We saw a [pull
request](https://github.com/langchain-ai/langchain/pull/13322) to add it
to the `llm.Bedrock` class, but since it concerns a chat model, we would
like to add it to `BedrockChat` as well.

- **Description:** Add support for Llama2 to `BedrockChat` in
`chat_models`
- **Issue:** the issue # it fixes (if applicable)
[#13316](https://github.com/langchain-ai/langchain/issues/13316)
  - **Dependencies:** any dependencies required for this change `None`
  - **Tag maintainer:** /
  - **Twitter handle:** `@SimonBockaert @WouterDurnez`

---------

Co-authored-by: wouter.durnez <wouter.durnez@showpad.com>
Co-authored-by: Simon Bockaert <simon.bockaert@showpad.com>
2023-11-19 18:44:58 -08:00
jwbeck97
a93616e972 FEAT: Add azure cognitive health tool (#13448)
- **Description:** This change adds an agent to the Azure Cognitive
Services toolkit for identifying healthcare entities
  - **Dependencies:** azure-ai-textanalytics (Optional)

---------

Co-authored-by: James Beck <James.Beck@sa.gov.au>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-19 18:44:01 -08:00
Massimiliano Pronesti
6bf9b2cb51 BUG: Limit Azure OpenAI embeddings chunk size (#13425)
Hi! 
This short PR aims at:
* Fixing `OpenAIEmbeddings`' check on `chunk_size` when used with Azure
OpenAI (thus with openai < 1.0). Azure OpenAI embeddings support at most
16 chunks per batch, I believe we are supposed to take the min between
the passed value/default value and 16, not the max - which, I suppose,
was introduced by accident while refactoring the previous version of
this check from this other PR of mine: #10707
* Porting this fix to the newest class (`AzureOpenAIEmbeddings`) for
openai >= 1.0

This fixes #13539 (closed but the issue persists).  

@baskaryan @hwchase17
2023-11-19 18:34:51 -08:00
Zeyang Lin
e53f59f01a DOCS: doc-string - langchain.vectorstores.dashvector.DashVector (#13502)
- **Description:** There are several mistakes in the sample code in the
doc-string of `DashVector` class, and this pull request aims to correct
them.
The correction code has been tested against latest version (at the time
of creation of this pull request) of: `langchain==0.0.336`
`dashvector==1.0.6` .
- **Issue:** No issue is created for this.
- **Dependencies:** No dependency is required for this change,
<!-- - **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below), -->
- **Twitter handle:** `zeyanglin`

<!-- Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
-->
2023-11-19 18:24:05 -08:00
John Mai
16f7912e1b BUG: fix hunyuan appid type (#13496)
- **Description: fix hunyuan appid type
- **Issue:
https://github.com/langchain-ai/langchain/pull/12022#issuecomment-1815627855
2023-11-19 18:23:45 -08:00
Leonid Ganeline
43972be632 docs updating AzureML notebooks (#13492)
- Added/updated descriptions and links

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-19 18:07:12 -08:00
Nicolò Boschi
8362bd729b AstraDB: use includeSimilarity option instead of $similarity (#13512)
- **Description:** AstraDB is going to deprecate the `$similarity`
projection property in favor of the ´includeSimilarity´ option flag. I
moved all the queries to the new format.
- **Tag maintainer:** @hemidactylus 
- **Twitter handle:** nicoloboschi
2023-11-19 17:54:35 -08:00
shumpei
7100d586ef Introduce search_kwargs for Custom Parameters in BingSearchAPIWrapper (#13525)
Added a `search_kwargs` field to BingSearchAPIWrapper in
`bing_search.py,` enabling users to include extra keyword arguments in
Bing search queries. This update, like specifying language preferences,
adds more customization to searches. The `search_kwargs` seamlessly
merge with standard parameters in `_bing_search_results` method.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-19 17:51:02 -08:00
Nicolò Boschi
ad0c3b9479 Fix Astra integration tests (#13520)
- **Description:** Fix Astra integration tests that are failing. The
`delete` always return True as the deletion is successful if no errors
are thrown. I aligned the test to verify this behaviour
  - **Tag maintainer:** @hemidactylus 
  - **Twitter handle:** nicoloboschi

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-19 17:50:49 -08:00
umair mehmood
69d39e2173 fix: VLLMOpenAI -- create() got an unexpected keyword argument 'api_key' (#13517)
The issue was accuring because of `openai` update in Completions. its
not accepting `api_key` and 'api_base' args.

The fix is we check for the openai version and if ats v1 then remove
these keys from args before passing them to `Compilation.create(...)`
when sending from `VLLMOpenAI`

Fixed: #13507 

@eyu
@efriis 
@hwchase17

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-19 17:49:55 -08:00
Manuel Alemán Cueto
6bc08266e0 Fix for oracle schema parsing stated on the issue #7928 (#13545)
- **Description:** In this pull request, we address an issue related to
assigning a schema to the SQLDatabase class when utilizing an Oracle
database. The current implementation encounters a bug where, upon
attempting to execute a query, the alter session parse is not
appropriately defined for Oracle, leading to an error,
  - **Issue:** #7928,
  - **Dependencies:** No dependencies,
  - **Tag maintainer:** @baskaryan,

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-19 17:35:27 -08:00
Andrew Teeter
325bdac673 feat: load all namespaces (#13549)
- **Description:** This change allows for the `MWDumpLoader` to load all
namespaces including custom by default instead of only loading the
[default
namespaces](https://www.mediawiki.org/wiki/Help:Namespaces#Localisation).
  - **Tag maintainer:** @hwchase17
2023-11-19 17:35:17 -08:00
Taranjeet Singh
47451764a7 Add embedchain retriever (#13553)
**Description:**

This commit adds embedchain retriever along with tests and docs.
Embedchain is a RAG framework to create data pipelines.

**Twitter handle:**
- [Taranjeet's twitter](https://twitter.com/taranjeetio) and
[Embedchain's twitter](https://twitter.com/embedchain)

**Reviewer**
@hwchase17

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-19 17:35:03 -08:00
rafly lesmana
420a17542d fix: Make YoutubeLoader support on demand language translation (#13583)
**Description:**
Enhance the functionality of YoutubeLoader to enable the translation of
available transcripts by refining the existing logic.

**Issue:**
Encountering a problem with YoutubeLoader (#13523) where the translation
feature is not functioning as expected.

Tag maintainers/contributors who might be interested:
@eyurtsev

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-11-19 17:34:48 -08:00
Leonid Ganeline
cc50e023d1 DOCS langchain decorators update (#13535)
added disclaimer

---------

Co-authored-by: Erick Friis <erickfriis@gmail.com>
2023-11-19 17:30:05 -08:00
Brace Sproul
02a13030c0 DOCS: updated langchain stack img to be svg (#13540) 2023-11-19 16:26:53 -08:00
Bagatur
78a1f4b264 bump 338, exp 42 (#13564) 2023-11-18 15:12:07 -08:00
Bagatur
790ed8be69 update multi index templates (#13569) 2023-11-18 14:42:22 -08:00
Harrison Chase
f4c0e3cc15 move streaming stdout (#13559) 2023-11-18 12:24:49 -05:00
Leonid Ganeline
43dad6cb91 BUG fixed openai_assistant namespace (#13543)
BUG: langchain.agents.openai_assistant has a reference as
`from langchain_experimental.openai_assistant.base import
OpenAIAssistantRunnable`
should be 
`from langchain.agents.openai_assistant.base import
OpenAIAssistantRunnable`

This prevents building of the API Reference docs
2023-11-17 17:15:33 -08:00
Bassem Yacoube
ff382b7b1b IMPROVEMENT Adds support for new OctoAI endpoints (#13521)
small fix to add support for new OctoAI LLM endpoints
2023-11-17 17:15:21 -08:00
Mark Silverberg
cda1b33270 Fix typo/line break in the middle of a word (#13314)
- **Description:** a simple typo/extra line break fix
  - **Dependencies:** none
2023-11-17 16:43:42 -08:00
William FH
cac849ae86 Use random seed (#13544)
For default eval llm
2023-11-17 16:33:31 -08:00
Martin Krasser
79ed66f870 EXPERIMENTAL Generic LLM wrapper to support chat model interface with configurable chat prompt format (#8295)
## Update 2023-09-08

This PR now supports further models in addition to Lllama-2 chat models.
See [this comment](#issuecomment-1668988543) for further details. The
title of this PR has been updated accordingly.

## Original PR description

This PR adds a generic `Llama2Chat` model, a wrapper for LLMs able to
serve Llama-2 chat models (like `LlamaCPP`,
`HuggingFaceTextGenInference`, ...). It implements `BaseChatModel`,
converts a list of chat messages into the [required Llama-2 chat prompt
format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2) and
forwards the formatted prompt as `str` to the wrapped `LLM`. Usage
example:

```python
# uses a locally hosted Llama2 chat model
llm = HuggingFaceTextGenInference(
    inference_server_url="http://127.0.0.1:8080/",
    max_new_tokens=512,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
)

# Wrap llm to support Llama2 chat prompt format.
# Resulting model is a chat model
model = Llama2Chat(llm=llm)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    MessagesPlaceholder(variable_name="chat_history"),
    HumanMessagePromptTemplate.from_template("{text}"),
]

prompt = ChatPromptTemplate.from_messages(messages)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
chain = LLMChain(llm=model, prompt=prompt, memory=memory)

# use chat model in a conversation
# ...
```

Also part of this PR are tests and a demo notebook.

- Tag maintainer: @hwchase17
- Twitter handle: `@mrt1nz`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2023-11-17 16:32:13 -08:00
William FH
c56faa6ef1 Add execution time (#13542)
And warn instead of raising an error, since the chain API is too
inconsistent.
2023-11-17 16:04:16 -08:00
pedro-inf-custodio
0fb5f857f9 IMPROVEMENT WebResearchRetriever error handling in urls with connection error (#13401)
- **Description:** Added a method `fetch_valid_documents` to
`WebResearchRetriever` class that will test the connection for every url
in `new_urls` and remove those that raise a `ConnectionError`.
- **Issue:** [Previous
PR](https://github.com/langchain-ai/langchain/pull/13353),
  - **Dependencies:** None,
  - **Tag maintainer:** @efriis 

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
2023-11-17 14:02:26 -08:00
Piyush Jain
d2335d0114 IMPROVEMENT Neptune graph updates (#13491)
## Description
This PR adds an option to allow unsigned requests to the Neptune
database when using the `NeptuneGraph` class.

```python
graph = NeptuneGraph(
    host='<my-cluster>',
    port=8182,
    sign=False
)
```

Also, added is an option in the `NeptuneOpenCypherQAChain` to provide
additional domain instructions to the graph query generation prompt.
This will be injected in the prompt as-is, so you should include any
provider specific tags, for example `<instructions>` or `<INSTR>`.

```python
chain = NeptuneOpenCypherQAChain.from_llm(
    llm=llm,
    graph=graph,
    extra_instructions="""
    Follow these instructions to build the query:
    1. Countries contain airports, not the other way around
    2. Use the airport code for identifying airports
    """
)
```
2023-11-17 13:49:31 -08:00
William FH
5a28dc3210 Override Keys Option (#13537)
Should be able to override the global key if you want to evaluate
different outputs in a single run
2023-11-17 13:32:43 -08:00
1421 changed files with 55996 additions and 26044 deletions

View File

@@ -7,6 +7,10 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@@ -40,6 +44,14 @@ jobs:
shell: bash
run: poetry install --with=test_integration
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Check integration tests compile
shell: bash
run: poetry run pytest -m compile tests/integration_tests

View File

@@ -11,6 +11,10 @@ on:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@@ -76,7 +80,15 @@ jobs:
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
pip install -e "$LANGCHAIN_LOCATION"
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Get .mypy_cache to speed up mypy
uses: actions/cache@v3

View File

@@ -7,6 +7,14 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-location:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@@ -40,6 +48,22 @@ jobs:
shell: bash
run: poetry install
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-location }}
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Install the opposite major version of pydantic
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
shell: bash

View File

@@ -7,6 +7,14 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-location:
required: false
type: string
description: "Relative path to the langchain library folder"
langchain-core-location:
required: false
type: string
description: "Relative path to the langchain core library folder"
env:
POETRY_VERSION: "1.6.1"
@@ -40,9 +48,26 @@ jobs:
shell: bash
run: poetry install
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-location }}
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install langchain core editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-core-location }}
env:
LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
run: |
poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"
- name: Run core tests
shell: bash
run: make test
run: |
make test
- name: Ensure the tests did not create any additional files
shell: bash

View File

@@ -36,6 +36,7 @@ jobs:
./.github/workflows/_lint.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
test:
@@ -43,6 +44,7 @@ jobs:
./.github/workflows/_test.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
compile-integration-tests:
@@ -50,6 +52,7 @@ jobs:
./.github/workflows/_compile_integration_test.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
pydantic-compatibility:
@@ -57,8 +60,49 @@ jobs:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: libs/langchain
langchain-core-location: ../core
secrets: inherit
# It's possible that langchain works fine with the latest *published* langchain-core,
# but is broken with the langchain-core on `master`.
#
# We want to catch situations like that *before* releasing a new langchain-core, hence this test.
test-with-latest-langchain-core:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ${{ env.WORKDIR }}
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
name: test with unpublished langchain-core - Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
working-directory: ${{ env.WORKDIR }}
cache-key: unpublished-langchain-core
- name: Install dependencies
shell: bash
run: |
echo "Running tests with unpublished langchain, installing dependencies with poetry..."
poetry install
echo "Editably installing langchain-core outside of poetry, to avoid messing up lockfile..."
poetry run pip install -e ../core
- name: Run tests
run: make test
extended-tests:
runs-on: ubuntu-latest
defaults:
@@ -89,6 +133,11 @@ jobs:
echo "Running extended tests, installing dependencies with poetry..."
poetry install -E extended_testing
- name: Install langchain core editable
shell: bash
run: |
poetry run pip install -e ../core
- name: Run extended tests
run: make extended_tests

52
.github/workflows/langchain_core_ci.yml vendored Normal file
View File

@@ -0,0 +1,52 @@
---
name: libs/langchain core CI
on:
push:
branches: [ master ]
pull_request:
paths:
- '.github/actions/poetry_setup/action.yml'
- '.github/tools/**'
- '.github/workflows/_lint.yml'
- '.github/workflows/_test.yml'
- '.github/workflows/_pydantic_compatibility.yml'
- '.github/workflows/langchain_core_ci.yml'
- 'libs/core/**'
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
POETRY_VERSION: "1.6.1"
WORKDIR: "libs/core"
jobs:
lint:
uses:
./.github/workflows/_lint.yml
with:
working-directory: libs/core
secrets: inherit
test:
uses:
./.github/workflows/_test.yml
with:
working-directory: libs/core
secrets: inherit
pydantic-compatibility:
uses:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: libs/core
secrets: inherit

View File

@@ -0,0 +1,13 @@
---
name: libs/core Release
on:
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
jobs:
release:
uses:
./.github/workflows/_release.yml
with:
working-directory: libs/core
secrets: inherit

View File

@@ -36,6 +36,7 @@ jobs:
with:
working-directory: libs/experimental
langchain-location: ../langchain
langchain-core-location: ../core
secrets: inherit
test:
@@ -43,6 +44,8 @@ jobs:
./.github/workflows/_test.yml
with:
working-directory: libs/experimental
langchain-location: ../langchain
langchain-core-location: ../core
secrets: inherit
compile-integration-tests:
@@ -88,6 +91,7 @@ jobs:
echo "Editably installing langchain outside of poetry, to avoid messing up lockfile..."
poetry run pip install -e ../langchain
poetry run pip install -e ../core
- name: Run tests
run: make test

View File

@@ -648,7 +648,7 @@
{
"data": {
"text/plain": [
"OpenAIEmbeddings(client=<class 'openai.api_resources.embedding.Embedding'>, model='text-embedding-ada-002', deployment='text-embedding-ada-002', openai_api_version='', openai_api_base='', openai_api_type='', openai_proxy='', embedding_ctx_length=8191, openai_api_key='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', openai_organization='', allowed_special=set(), disallowed_special='all', chunk_size=1000, max_retries=6, request_timeout=None, headers=None, tiktoken_model_name=None, show_progress_bar=False, model_kwargs={})"
"OpenAIEmbeddings(client=<class 'openai.api_resources.embedding.Embedding'>, model='text-embedding-ada-002', deployment='text-embedding-ada-002', openai_api_version='', openai_api_base='', openai_api_type='', openai_proxy='', embedding_ctx_length=8191, openai_api_key='', openai_organization='', allowed_special=set(), disallowed_special='all', chunk_size=1000, max_retries=6, request_timeout=None, headers=None, tiktoken_model_name=None, show_progress_bar=False, model_kwargs={})"
]
},
"execution_count": 13,

File diff suppressed because one or more lines are too long

View File

@@ -69,8 +69,8 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.llm_bash.prompt import BashOutputParser\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"from langchain_experimental.llm_bash.prompt import BashOutputParser\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n",
"Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n",

View File

@@ -13,8 +13,10 @@ HERE = Path(__file__).parent
PKG_DIR = ROOT_DIR / "libs" / "langchain" / "langchain"
EXP_DIR = ROOT_DIR / "libs" / "experimental" / "langchain_experimental"
CORE_DIR = ROOT_DIR / "libs" / "core" / "langchain_core"
WRITE_FILE = HERE / "api_reference.rst"
EXP_WRITE_FILE = HERE / "experimental_api_reference.rst"
CORE_WRITE_FILE = HERE / "core_api_reference.rst"
ClassKind = Literal["TypedDict", "Regular", "Pydantic", "enum"]
@@ -292,6 +294,17 @@ def _document_langchain_experimental() -> None:
def _document_langchain_core() -> None:
"""Document the langchain_core package."""
# Generate core_api_reference.rst
core_members = _load_package_modules(EXP_DIR)
core_doc = ".. _core_api_reference:\n\n" + _construct_doc(
"langchain_core", core_members
)
with open(CORE_WRITE_FILE, "w") as f:
f.write(core_doc)
def _document_langchain() -> None:
"""Document the main langchain package."""
# load top level module members
lc_members = _load_package_modules(PKG_DIR)
@@ -306,7 +319,6 @@ def _document_langchain_core() -> None:
"agents.output_parsers": agents["output_parsers"],
"agents.format_scratchpad": agents["format_scratchpad"],
"tools.render": tools["render"],
"schema.runnable": schema["runnable"],
}
)
@@ -318,8 +330,9 @@ def _document_langchain_core() -> None:
def main() -> None:
"""Generate the reference.rst file for each package."""
_document_langchain_core()
_document_langchain()
_document_langchain_experimental()
_document_langchain_core()
if __name__ == "__main__":

View File

@@ -34,6 +34,9 @@
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('api_reference') }}">API</a>
</li>
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('core_api_reference') }}">Core</a>
</li>
<li class="nav-item">
<a class="sk-nav-link nav-link" href="{{ pathto('experimental_api_reference') }}">Experimental</a>
</li>

View File

@@ -14,7 +14,7 @@ This framework consists of several parts.
- **[LangServe](/docs/langserve)**: A library for deploying LangChain chains as a REST API.
- **[LangSmith](/docs/langsmith)**: A developer platform that lets you debug, test, evaluate, and monitor chains built on any LLM framework and seamlessly integrates with LangChain.
![LangChain Diagram](/img/langchain_stack.png)
![LangChain Diagram](/svg/langchain_stack.svg)
Together, these products simplify the entire application lifecycle:
- **Develop**: Write your applications in LangChain/LangChain.js. Hit the ground running using Templates for reference.

View File

@@ -7,7 +7,9 @@
"source": [
"# Azure OpenAI\n",
"\n",
"This notebook goes over how to connect to an Azure hosted OpenAI endpoint. We recommend having version `openai>=1` installed."
">[Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview) provides REST API access to OpenAI's powerful language models including the GPT-4, GPT-3.5-Turbo, and Embeddings model series. These models can be easily adapted to your specific task including but not limited to content generation, summarization, semantic search, and natural language to code translation. Users can access the service through REST APIs, Python SDK, or a web-based interface in the Azure OpenAI Studio.\n",
"\n",
"This notebook goes over how to connect to an Azure-hosted OpenAI endpoint. We recommend having version `openai>=1` installed."
]
},
{
@@ -162,7 +164,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -4,11 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# AzureML Chat Online Endpoint\n",
"# Azure ML Endpoint\n",
"\n",
"[AzureML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. `Azure Foundation Models` include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
">\n",
">[Azure Machine Learning Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints). After you train machine learning models or pipelines, you need to deploy them to production so that others can use them for inference. Inference is the process of applying new input data to the machine learning model or pipeline to generate outputs. While these outputs are typically referred to as \"predictions,\" inferencing can be used to generate outputs for other machine learning tasks, such as classification and clustering. In `Azure Machine Learning`, you perform inferencing by using endpoints and deployments. `Endpoints` and `Deployments` allow you to decouple the interface of your production workload from the implementation that serves it.\n",
"\n",
"This notebook goes over how to use a chat model hosted on an `AzureML online endpoint`"
"This notebook goes over how to use a chat model hosted on an `Azure Machine Learning Endpoint`."
]
},
{
@@ -91,7 +93,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -36,7 +36,7 @@
"outputs": [],
"source": [
"chat = ChatHunyuan(\n",
" hunyuan_app_id=\"YOUR_APP_ID\",\n",
" hunyuan_app_id=111111111,\n",
" hunyuan_secret_id=\"YOUR_SECRET_ID\",\n",
" hunyuan_secret_key=\"YOUR_SECRET_KEY\",\n",
")"

View File

@@ -0,0 +1,729 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "90a1faf2",
"metadata": {},
"source": [
"# Llama-2 Chat\n",
"\n",
"This notebook shows how to augment Llama-2 `LLM`s with the `Llama2Chat` wrapper to support the [Llama-2 chat prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2). Several `LLM` implementations in LangChain can be used as interface to Llama-2 chat models. These include [HuggingFaceTextGenInference](https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference), [LlamaCpp](https://python.langchain.com/docs/use_cases/question_answering/how_to/local_retrieval_qa), [GPT4All](https://python.langchain.com/docs/integrations/llms/gpt4all), ..., to mention a few examples. \n",
"\n",
"`Llama2Chat` is a generic wrapper that implements `BaseChatModel` and can therefore be used in applications as [chat model](https://python.langchain.com/docs/modules/model_io/models/chat/). `Llama2Chat` converts a list of [chat messages](https://python.langchain.com/docs/modules/model_io/models/chat/#messages) into the [required chat prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2) and forwards the formatted prompt as `str` to the wrapped `LLM`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "36c03540",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain_experimental.chat_models import Llama2Chat"
]
},
{
"cell_type": "markdown",
"id": "5c76910f",
"metadata": {},
"source": [
"For the chat application examples below, we'll use the following chat `prompt_template`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9bbfaf3a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
" MessagesPlaceholder,\n",
")\n",
"from langchain.schema import SystemMessage\n",
"\n",
"template_messages = [\n",
" SystemMessage(content=\"You are a helpful assistant.\"),\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" HumanMessagePromptTemplate.from_template(\"{text}\"),\n",
"]\n",
"prompt_template = ChatPromptTemplate.from_messages(template_messages)"
]
},
{
"cell_type": "markdown",
"id": "2f3343b7",
"metadata": {},
"source": [
"## Chat with Llama-2 via `HuggingFaceTextGenInference` LLM"
]
},
{
"cell_type": "markdown",
"id": "2ff99380",
"metadata": {},
"source": [
"A [HuggingFaceTextGenInference](https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference) LLM encapsulates access to a [text-generation-inference](https://github.com/huggingface/text-generation-inference) server. In the following example, the inference server serves a [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) model. It can be started locally with:\n",
"\n",
"```bash\n",
"docker run \\\n",
" --rm \\\n",
" --gpus all \\\n",
" --ipc=host \\\n",
" -p 8080:80 \\\n",
" -v ~/.cache/huggingface/hub:/data \\\n",
" -e HF_API_TOKEN=${HF_API_TOKEN} \\\n",
" ghcr.io/huggingface/text-generation-inference:0.9 \\\n",
" --hostname 0.0.0.0 \\\n",
" --model-id meta-llama/Llama-2-13b-chat-hf \\\n",
" --quantize bitsandbytes \\\n",
" --num-shard 4\n",
"```\n",
"\n",
"This works on a machine with 4 x RTX 3080ti cards, for example. Adjust the `--num_shard` value to the number of GPUs available. The `HF_API_TOKEN` environment variable holds the Hugging Face API token."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "238095fd",
"metadata": {},
"outputs": [],
"source": [
"# !pip3 install text-generation"
]
},
{
"cell_type": "markdown",
"id": "79c4ace9",
"metadata": {},
"source": [
"Create a `HuggingFaceTextGenInference` instance that connects to the local inference server and wrap it into `Llama2Chat`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7a9f6de2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import HuggingFaceTextGenInference\n",
"\n",
"llm = HuggingFaceTextGenInference(\n",
" inference_server_url=\"http://127.0.0.1:8080/\",\n",
" max_new_tokens=512,\n",
" top_k=50,\n",
" temperature=0.1,\n",
" repetition_penalty=1.03,\n",
")\n",
"\n",
"model = Llama2Chat(llm=llm)"
]
},
{
"cell_type": "markdown",
"id": "4f646a2b",
"metadata": {},
"source": [
"Then you are ready to use the chat `model` together with `prompt_template` and conversation `memory` in an `LLMChain`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "54b5d1d1",
"metadata": {},
"outputs": [],
"source": [
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"chain = LLMChain(llm=model, prompt=prompt_template, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e6717947",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Sure, I'd be happy to help! Here are a few popular locations to consider visiting in Vienna:\n",
"\n",
"1. Schönbrunn Palace\n",
"2. St. Stephen's Cathedral\n",
"3. Hofburg Palace\n",
"4. Belvedere Palace\n",
"5. Prater Park\n",
"6. Vienna State Opera\n",
"7. Albertina Museum\n",
"8. Museum of Natural History\n",
"9. Kunsthistorisches Museum\n",
"10. Ringstrasse\n"
]
}
],
"source": [
"print(\n",
" chain.run(\n",
" text=\"What can I see in Vienna? Propose a few locations. Names only, no details.\"\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "17bf10d5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Certainly! St. Stephen's Cathedral (Stephansdom) is one of the most recognizable landmarks in Vienna and a must-see attraction for visitors. This stunning Gothic cathedral is located in the heart of the city and is known for its intricate stone carvings, colorful stained glass windows, and impressive dome.\n",
"\n",
"The cathedral was built in the 12th century and has been the site of many important events throughout history, including the coronation of Holy Roman emperors and the funeral of Mozart. Today, it is still an active place of worship and offers guided tours, concerts, and special events. Visitors can climb up the south tower for panoramic views of the city or attend a service to experience the beautiful music and chanting.\n"
]
}
],
"source": [
"print(chain.run(text=\"Tell me more about #2.\"))"
]
},
{
"cell_type": "markdown",
"id": "2a297e09",
"metadata": {},
"source": [
"## Chat with Llama-2 via `LlamaCPP` LLM"
]
},
{
"cell_type": "markdown",
"id": "52c1a0b9",
"metadata": {},
"source": [
"For using a Llama-2 chat model with a [LlamaCPP](https://python.langchain.com/docs/integrations/llms/llamacpp) `LMM`, install the `llama-cpp-python` library using [these installation instructions](https://python.langchain.com/docs/integrations/llms/llamacpp#installation). The following example uses a quantized [llama-2-7b-chat.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf) model stored locally at `~/Models/llama-2-7b-chat.Q4_0.gguf`. \n",
"\n",
"After creating a `LlamaCpp` instance, the `llm` is again wrapped into `Llama2Chat`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "07c0d04e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from /home/martin/Models/llama-2-7b-chat.Q4_0.gguf (version GGUF V2)\n",
"llama_model_loader: - tensor 0: token_embd.weight q4_0 [ 4096, 32000, 1, 1 ]\n",
"llama_model_loader: - tensor 1: blk.0.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 2: blk.0.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 3: blk.0.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 4: blk.0.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 5: blk.0.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 6: blk.0.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 7: blk.0.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 8: blk.0.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 9: blk.0.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 10: blk.1.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 11: blk.1.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 12: blk.1.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 13: blk.1.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 14: blk.1.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 15: blk.1.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 16: blk.1.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 17: blk.1.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 18: blk.1.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 19: blk.10.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 20: blk.10.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 21: blk.10.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 22: blk.10.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 23: blk.10.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 24: blk.10.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 25: blk.10.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 26: blk.10.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 27: blk.10.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 28: blk.11.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 29: blk.11.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 30: blk.11.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 31: blk.11.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 32: blk.11.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 33: blk.11.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 34: blk.11.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 35: blk.11.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 36: blk.11.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 37: blk.12.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 38: blk.12.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 39: blk.12.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 40: blk.12.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 41: blk.12.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 42: blk.12.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 43: blk.12.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 44: blk.12.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 45: blk.12.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 46: blk.13.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 47: blk.13.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 48: blk.13.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 49: blk.13.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 50: blk.13.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 51: blk.13.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 52: blk.13.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 53: blk.13.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 54: blk.13.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 55: blk.14.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 56: blk.14.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 57: blk.14.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 58: blk.14.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 59: blk.14.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 60: blk.14.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 61: blk.14.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 62: blk.14.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 63: blk.14.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 64: blk.15.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 65: blk.15.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 66: blk.15.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 67: blk.15.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 68: blk.15.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 69: blk.15.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 70: blk.15.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 71: blk.15.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 72: blk.15.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 73: blk.16.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 74: blk.16.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 75: blk.16.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 76: blk.16.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 77: blk.16.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 78: blk.16.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 79: blk.16.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 80: blk.16.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 81: blk.16.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 82: blk.17.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 83: blk.17.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 84: blk.17.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 85: blk.17.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 86: blk.17.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 87: blk.17.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 88: blk.17.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 89: blk.17.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 90: blk.17.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 91: blk.18.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 92: blk.18.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 93: blk.18.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 94: blk.18.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 95: blk.18.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 96: blk.18.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 97: blk.18.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 98: blk.18.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 99: blk.18.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 100: blk.19.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 101: blk.19.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 102: blk.19.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 103: blk.19.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 104: blk.19.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 105: blk.19.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 106: blk.19.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 107: blk.19.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 108: blk.19.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 109: blk.2.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 110: blk.2.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 111: blk.2.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 112: blk.2.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 113: blk.2.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 114: blk.2.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 115: blk.2.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 116: blk.2.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 117: blk.2.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 118: blk.20.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 119: blk.20.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 120: blk.20.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 121: blk.20.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 122: blk.20.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 123: blk.20.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 124: blk.20.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 125: blk.20.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 126: blk.20.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 127: blk.21.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 128: blk.21.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 129: blk.21.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 130: blk.21.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 131: blk.21.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 132: blk.21.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 133: blk.21.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 134: blk.21.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 135: blk.21.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 136: blk.22.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 137: blk.22.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 138: blk.22.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 139: blk.22.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 140: blk.22.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 141: blk.22.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 142: blk.22.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 143: blk.22.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 144: blk.22.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 145: blk.23.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 146: blk.23.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 147: blk.23.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 148: blk.23.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 149: blk.23.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 150: blk.23.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 151: blk.23.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 152: blk.23.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 153: blk.23.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 154: blk.3.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 155: blk.3.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 156: blk.3.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 157: blk.3.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 158: blk.3.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 159: blk.3.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 160: blk.3.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 161: blk.3.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 162: blk.3.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 163: blk.4.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 164: blk.4.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 165: blk.4.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 166: blk.4.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 167: blk.4.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 168: blk.4.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 169: blk.4.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 170: blk.4.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 171: blk.4.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 172: blk.5.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 173: blk.5.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 174: blk.5.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 175: blk.5.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 176: blk.5.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 177: blk.5.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 178: blk.5.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 179: blk.5.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 180: blk.5.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 181: blk.6.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 182: blk.6.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 183: blk.6.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 184: blk.6.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 185: blk.6.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 186: blk.6.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 187: blk.6.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 188: blk.6.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 189: blk.6.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 190: blk.7.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 191: blk.7.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 192: blk.7.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 193: blk.7.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 194: blk.7.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 195: blk.7.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 196: blk.7.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 197: blk.7.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 198: blk.7.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 199: blk.8.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 200: blk.8.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 201: blk.8.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 202: blk.8.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 203: blk.8.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 204: blk.8.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 205: blk.8.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 206: blk.8.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 207: blk.8.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 208: blk.9.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 209: blk.9.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 210: blk.9.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 211: blk.9.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 212: blk.9.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 213: blk.9.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 214: blk.9.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 215: blk.9.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 216: blk.9.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 217: output.weight q6_K [ 4096, 32000, 1, 1 ]\n",
"llama_model_loader: - tensor 218: blk.24.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 219: blk.24.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 220: blk.24.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 221: blk.24.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 222: blk.24.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 223: blk.24.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 224: blk.24.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 225: blk.24.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 226: blk.24.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 227: blk.25.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 228: blk.25.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 229: blk.25.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 230: blk.25.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 231: blk.25.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 232: blk.25.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 233: blk.25.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 234: blk.25.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 235: blk.25.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 236: blk.26.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 237: blk.26.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 238: blk.26.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 239: blk.26.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 240: blk.26.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 241: blk.26.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 242: blk.26.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 243: blk.26.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 244: blk.26.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 245: blk.27.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 246: blk.27.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 247: blk.27.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 248: blk.27.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 249: blk.27.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 250: blk.27.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 251: blk.27.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 252: blk.27.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 253: blk.27.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 254: blk.28.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 255: blk.28.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 256: blk.28.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 257: blk.28.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 258: blk.28.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 259: blk.28.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 260: blk.28.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 261: blk.28.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 262: blk.28.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 263: blk.29.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 264: blk.29.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 265: blk.29.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 266: blk.29.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 267: blk.29.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 268: blk.29.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 269: blk.29.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 270: blk.29.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 271: blk.29.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 272: blk.30.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 273: blk.30.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 274: blk.30.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 275: blk.30.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 276: blk.30.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 277: blk.30.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 278: blk.30.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 279: blk.30.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 280: blk.30.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 281: blk.31.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 282: blk.31.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 283: blk.31.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 284: blk.31.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
"llama_model_loader: - tensor 285: blk.31.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 286: blk.31.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 287: blk.31.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 288: blk.31.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 289: blk.31.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
"llama_model_loader: - tensor 290: output_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - kv 0: general.architecture str \n",
"llama_model_loader: - kv 1: general.name str \n",
"llama_model_loader: - kv 2: llama.context_length u32 \n",
"llama_model_loader: - kv 3: llama.embedding_length u32 \n",
"llama_model_loader: - kv 4: llama.block_count u32 \n",
"llama_model_loader: - kv 5: llama.feed_forward_length u32 \n",
"llama_model_loader: - kv 6: llama.rope.dimension_count u32 \n",
"llama_model_loader: - kv 7: llama.attention.head_count u32 \n",
"llama_model_loader: - kv 8: llama.attention.head_count_kv u32 \n",
"llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 \n",
"llama_model_loader: - kv 10: general.file_type u32 \n",
"llama_model_loader: - kv 11: tokenizer.ggml.model str \n",
"llama_model_loader: - kv 12: tokenizer.ggml.tokens arr \n",
"llama_model_loader: - kv 13: tokenizer.ggml.scores arr \n",
"llama_model_loader: - kv 14: tokenizer.ggml.token_type arr \n",
"llama_model_loader: - kv 15: tokenizer.ggml.bos_token_id u32 \n",
"llama_model_loader: - kv 16: tokenizer.ggml.eos_token_id u32 \n",
"llama_model_loader: - kv 17: tokenizer.ggml.unknown_token_id u32 \n",
"llama_model_loader: - kv 18: general.quantization_version u32 \n",
"llama_model_loader: - type f32: 65 tensors\n",
"llama_model_loader: - type q4_0: 225 tensors\n",
"llama_model_loader: - type q6_K: 1 tensors\n",
"llm_load_vocab: special tokens definition check successful ( 259/32000 ).\n",
"llm_load_print_meta: format = GGUF V2\n",
"llm_load_print_meta: arch = llama\n",
"llm_load_print_meta: vocab type = SPM\n",
"llm_load_print_meta: n_vocab = 32000\n",
"llm_load_print_meta: n_merges = 0\n",
"llm_load_print_meta: n_ctx_train = 4096\n",
"llm_load_print_meta: n_embd = 4096\n",
"llm_load_print_meta: n_head = 32\n",
"llm_load_print_meta: n_head_kv = 32\n",
"llm_load_print_meta: n_layer = 32\n",
"llm_load_print_meta: n_rot = 128\n",
"llm_load_print_meta: n_gqa = 1\n",
"llm_load_print_meta: f_norm_eps = 0.0e+00\n",
"llm_load_print_meta: f_norm_rms_eps = 1.0e-06\n",
"llm_load_print_meta: f_clamp_kqv = 0.0e+00\n",
"llm_load_print_meta: f_max_alibi_bias = 0.0e+00\n",
"llm_load_print_meta: n_ff = 11008\n",
"llm_load_print_meta: rope scaling = linear\n",
"llm_load_print_meta: freq_base_train = 10000.0\n",
"llm_load_print_meta: freq_scale_train = 1\n",
"llm_load_print_meta: n_yarn_orig_ctx = 4096\n",
"llm_load_print_meta: rope_finetuned = unknown\n",
"llm_load_print_meta: model type = 7B\n",
"llm_load_print_meta: model ftype = mostly Q4_0\n",
"llm_load_print_meta: model params = 6.74 B\n",
"llm_load_print_meta: model size = 3.56 GiB (4.54 BPW) \n",
"llm_load_print_meta: general.name = LLaMA v2\n",
"llm_load_print_meta: BOS token = 1 '<s>'\n",
"llm_load_print_meta: EOS token = 2 '</s>'\n",
"llm_load_print_meta: UNK token = 0 '<unk>'\n",
"llm_load_print_meta: LF token = 13 '<0x0A>'\n",
"llm_load_tensors: ggml ctx size = 0.11 MB\n",
"llm_load_tensors: mem required = 3647.97 MB\n",
"..................................................................................................\n",
"llama_new_context_with_model: n_ctx = 512\n",
"llama_new_context_with_model: freq_base = 10000.0\n",
"llama_new_context_with_model: freq_scale = 1\n",
"llama_new_context_with_model: kv self size = 256.00 MB\n",
"llama_build_graph: non-view tensors processed: 740/740\n",
"llama_new_context_with_model: compute buffer total size = 2.66 MB\n",
"AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | \n"
]
}
],
"source": [
"from os.path import expanduser\n",
"\n",
"from langchain.llms import LlamaCpp\n",
"\n",
"model_path = expanduser(\"~/Models/llama-2-7b-chat.Q4_0.gguf\")\n",
"\n",
"llm = LlamaCpp(\n",
" model_path=model_path,\n",
" streaming=False,\n",
")\n",
"model = Llama2Chat(llm=llm)"
]
},
{
"cell_type": "markdown",
"id": "50498d96",
"metadata": {},
"source": [
"and used in the same way as in the previous example."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "90782b96",
"metadata": {},
"outputs": [],
"source": [
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"chain = LLMChain(llm=model, prompt=prompt_template, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2160b26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Of course! Vienna is a beautiful city with a rich history and culture. Here are some of the top tourist attractions you might want to consider visiting:\n",
"1. Schönbrunn Palace\n",
"2. St. Stephen's Cathedral\n",
"3. Hofburg Palace\n",
"4. Belvedere Palace\n",
"5. Prater Park\n",
"6. MuseumsQuartier\n",
"7. Ringstrasse\n",
"8. Vienna State Opera\n",
"9. Kunsthistorisches Museum\n",
"10. Imperial Palace\n",
"\n",
"These are just a few of the many amazing places to see in Vienna. Each one has its own unique history and charm, so I hope you enjoy exploring this beautiful city!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 250.46 ms\n",
"llama_print_timings: sample time = 56.40 ms / 144 runs ( 0.39 ms per token, 2553.37 tokens per second)\n",
"llama_print_timings: prompt eval time = 1444.25 ms / 47 tokens ( 30.73 ms per token, 32.54 tokens per second)\n",
"llama_print_timings: eval time = 8832.02 ms / 143 runs ( 61.76 ms per token, 16.19 tokens per second)\n",
"llama_print_timings: total time = 10645.94 ms\n"
]
}
],
"source": [
"print(\n",
" chain.run(\n",
" text=\"What can I see in Vienna? Propose a few locations. Names only, no details.\"\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d9ce06e3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Of course! St. Stephen's Cathedral (also known as Stephansdom) is a stunning Gothic-style cathedral located in the heart of Vienna, Austria. It is one of the most recognizable landmarks in the city and is considered a symbol of Vienna.\n",
"Here are some interesting facts about St. Stephen's Cathedral:\n",
"1. History: The construction of St. Stephen's Cathedral began in the 12th century on the site of a former Romanesque church, and it took over 600 years to complete. The cathedral has been renovated and expanded several times throughout its history, with the most significant renovation taking place in the 19th century.\n",
"2. Architecture: St. Stephen's Cathedral is built in the Gothic style, characterized by its tall spires, pointed arches, and intricate stone carvings. The cathedral features a mix of Romanesque, Gothic, and Baroque elements, making it a unique blend of styles.\n",
"3. Design: The cathedral's design is based on the plan of a cross with a long nave and two shorter arms extending from it. The main altar is\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 250.46 ms\n",
"llama_print_timings: sample time = 100.60 ms / 256 runs ( 0.39 ms per token, 2544.73 tokens per second)\n",
"llama_print_timings: prompt eval time = 5128.71 ms / 160 tokens ( 32.05 ms per token, 31.20 tokens per second)\n",
"llama_print_timings: eval time = 16193.02 ms / 255 runs ( 63.50 ms per token, 15.75 tokens per second)\n",
"llama_print_timings: total time = 21988.57 ms\n"
]
}
],
"source": [
"print(chain.run(text=\"Tell me more about #2.\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -99,7 +99,7 @@
"\n",
"Language param : It's a list of language codes in a descending priority, `en` by default.\n",
"\n",
"translation param : It's a translate preference when the youtube does'nt have your select language, `en` by default."
"translation param : It's a translate preference, you can translate available transcript to your preferred language."
]
},
{

View File

@@ -5,7 +5,7 @@
"id": "91c6a7ef",
"metadata": {},
"source": [
"# MongodDB\n",
"# MongoDB\n",
"\n",
">`MongoDB` is a source-available cross-platform document-oriented database program. Classified as a NoSQL database program, `MongoDB` uses `JSON`-like documents with optional schemas.\n",
">\n",

View File

@@ -1,10 +1,13 @@
# LangChain Decorators ✨
lanchchain decorators is a layer on the top of LangChain that provides syntactic sugar 🍭 for writing custom langchain prompts and chains
For Feedback, Issues, Contributions - please raise an issue here:
[ju-bezdek/langchain-decorators](https://github.com/ju-bezdek/langchain-decorators)
~~~
Disclaimer: `LangChain decorators` is not created by the LangChain team and is not supported by it.
~~~
>`LangChain decorators` is a layer on the top of LangChain that provides syntactic sugar 🍭 for writing custom langchain prompts and chains
>
>For Feedback, Issues, Contributions - please raise an issue here:
>[ju-bezdek/langchain-decorators](https://github.com/ju-bezdek/langchain-decorators)
Main principles and benefits:
@@ -17,7 +20,6 @@ Main principles and benefits:
- easily share parameters between the prompts by binding them to one class
Here is a simple example of a code written with **LangChain Decorators ✨**
``` python

View File

@@ -0,0 +1,255 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2f0f85ac-9c49-4111-a320-e53bccc99b13",
"metadata": {},
"source": [
"# Embedchain\n",
"\n",
"Embedchain is a RAG framework to create data pipelines. It loads, indexes, retrieves and syncs all the data.\n",
"\n",
"It is available as an [open source package](https://github.com/embedchain/embedchain) and as a [hosted platform solution](https://app.embedchain.ai/).\n",
"\n",
"This notebook shows how to use a retriever that uses Embedchain."
]
},
{
"cell_type": "markdown",
"id": "e48de822-307b-4284-96e7-c91f11ce005b",
"metadata": {},
"source": [
"# Installation\n",
"\n",
"First you will need to install the [`embedchain` package](https://pypi.org/project/embedchain/). \n",
"\n",
"You can install the package by running "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c690a78c-5999-4072-b4e1-2712ff73f950",
"metadata": {},
"outputs": [],
"source": [
"#!pip install --upgrade embedchain"
]
},
{
"cell_type": "markdown",
"id": "bc89ba12-6ebd-4cd6-8c85-7410531579ff",
"metadata": {},
"source": [
"# Create New Retriever\n",
"\n",
"`EmbedchainRetriever` has a static `.create()` factory method that takes the following arguments:\n",
"\n",
"* `yaml_path: string` optional -- Path to the YAML configuration file. If not provided, a default configuration is used. You can browse the [docs](https://docs.embedchain.ai/) to explore various customization options."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e639bd4-2e60-487b-b7aa-f7e6b921b069",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"# Setup API Key\n",
"\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "223fbc76-91ad-4504-87e9-980fb0e027fc",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import EmbedchainRetriever\n",
"\n",
"# create retriever with default options\n",
"retriever = EmbedchainRetriever.create()\n",
"\n",
"# or if you want to customize, pass the yaml config path\n",
"# retriever = EmbedchainRetiever.create(yaml_path=\"config.yaml\")"
]
},
{
"cell_type": "markdown",
"id": "536f3a1d-3491-45b5-9f25-869bd6fb6d6a",
"metadata": {},
"source": [
"# Add Data\n",
"\n",
"In embedchain, you can as many supported data types as possible. You can browse our [docs](https://docs.embedchain.ai/) to see the data types supported.\n",
"\n",
"Embedchain automatically deduces the types of the data. So you can add a string, URL or local file path."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "31262be3-7d0d-42e8-9253-052160576dc7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00, 2.22s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully saved https://en.wikipedia.org/wiki/Elon_Musk (DataType.WEB_PAGE). New chunks count: 378\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.17s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully saved https://www.forbes.com/profile/elon-musk (DataType.WEB_PAGE). New chunks count: 13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.25s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully saved https://www.youtube.com/watch?v=RcYjXbSJBN8 (DataType.YOUTUBE_VIDEO). New chunks count: 53\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"['1eab8dd1ffa92906f7fc839862871ca5',\n",
" '8cf46026cabf9b05394a2658bd1fe890',\n",
" 'da3227cdbcedb018e05c47b774d625f6']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.add_texts(\n",
" [\n",
" \"https://en.wikipedia.org/wiki/Elon_Musk\",\n",
" \"https://www.forbes.com/profile/elon-musk\",\n",
" \"https://www.youtube.com/watch?v=RcYjXbSJBN8\",\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e1f34a62-7f8e-4c03-8e10-c317ed3296aa",
"metadata": {},
"source": [
"# Use Retriever\n",
"\n",
"You can now use the retrieve to find relevant documents given a query"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6106baf9-652a-4a94-b2d7-d6a5d2917975",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\n",
" \"How many companies does Elon Musk run and name those?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1deae5d0-e0fa-431d-b164-e9680ef3e69b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='Views Filmography Companies Zip2 X.com PayPal SpaceX Starlink Tesla, Inc. Energycriticismlitigation OpenAI Neuralink The Boring Company Thud X Corp. Twitteracquisitiontenure as CEO xAI In popular culture Elon Musk (Isaacson) Elon Musk (Vance) Ludicrous Power Play \"Members Only\" \"The Platonic Permutation\" \"The Musk Who Fell to Earth\" \"One Crew over the Crewcoo\\'s Morty\" Elon Musk\\'s Crash Course Related Boring Test Tunnel Hyperloop Musk family Musk vs. Zuckerberg SolarCity Tesla Roadster in space', metadata={'source': 'https://en.wikipedia.org/wiki/Elon_Musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3342161a0fbc19e91f6bf387204aa30fbb2cea05abc81882502476bde37b9392'}),\n",
" Document(page_content='Elon Musk PROFILEElon MuskCEO, Tesla$241.2B$508M (0.21%)Real Time Net Worthas of 11/18/23Reflects change since 5 pm ET of prior trading day. 1 in the world todayPhoto by Martin Schoeller for ForbesAbout Elon MuskElon Musk cofounded six companies, including electric car maker Tesla, rocket producer SpaceX and tunneling startup Boring Company.He owns about 21% of Tesla between stock and options, but has pledged more than half his shares as collateral for personal loans of up to $3.5', metadata={'source': 'https://www.forbes.com/profile/elon-musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3c8573134c575fafc025e9211413723e1f7a725b5936e8ee297fb7fb63bdd01a'}),\n",
" Document(page_content='to form PayPal. In October 2002, eBay acquired PayPal for $1.5 billion, and that same year, with $100 million of the money he made, Musk founded SpaceX, a spaceflight services company. In 2004, he became an early investor in electric vehicle manufacturer Tesla Motors, Inc. (now Tesla, Inc.). He became its chairman and product architect, assuming the position of CEO in 2008. In 2006, Musk helped create SolarCity, a solar-energy company that was acquired by Tesla in 2016 and became Tesla Energy.', metadata={'source': 'https://en.wikipedia.org/wiki/Elon_Musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3342161a0fbc19e91f6bf387204aa30fbb2cea05abc81882502476bde37b9392'})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3f26c2b-048d-4588-90a0-50f5c9c35837",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -12,7 +12,8 @@
"- AzureCogsImageAnalysisTool: used to extract caption, objects, tags, and text from images. (Note: this tool is not available on Mac OS yet, due to the dependency on `azure-ai-vision` package, which is only supported on Windows and Linux currently.)\n",
"- AzureCogsFormRecognizerTool: used to extract text, tables, and key-value pairs from documents.\n",
"- AzureCogsSpeech2TextTool: used to transcribe speech to text.\n",
"- AzureCogsText2SpeechTool: used to synthesize text to speech."
"- AzureCogsText2SpeechTool: used to synthesize text to speech.\n",
"- AzureCogsTextAnalyticsHealthTool: used to extract healthcare entities."
]
},
{
@@ -32,6 +33,7 @@
"source": [
"# !pip install --upgrade azure-ai-formrecognizer > /dev/null\n",
"# !pip install --upgrade azure-cognitiveservices-speech > /dev/null\n",
"# !pip install --upgrade azure-ai-textanalytics > /dev/null\n",
"\n",
"# For Windows/Linux\n",
"# !pip install --upgrade azure-ai-vision > /dev/null"
@@ -60,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@@ -101,7 +103,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -111,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -240,6 +242,65 @@
"display.display(audio)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction:\n",
"```\n",
"{\n",
" \"action\": \"azure_cognitive_services_text_analyics_health\",\n",
" \"action_input\": \"The patient is a 54-year-old gentleman with a history of progressive angina over the past several months. The patient had a cardiac catheterization in July of this year revealing total occlusion of the RCA and 50% left main disease, with a strong family history of coronary artery disease with a brother dying at the age of 52 from a myocardial infarction and another brother who is status post coronary artery bypass grafting. The patient had a stress echocardiogram done on July, 2001, which showed no wall motion abnormalities, but this was a difficult study due to body habitus. The patient went for six minutes with minimal ST depressions in the anterior lateral leads, thought due to fatigue and wrist pain, his anginal equivalent. Due to the patient's increased symptoms and family history and history left main disease with total occasional of his RCA was referred for revascularization with open heart surgery.\"\n",
"}\n",
"```\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mThe text conatins the following healthcare entities: 54-year-old is a healthcare entity of type Age, gentleman is a healthcare entity of type Gender, progressive angina is a healthcare entity of type Diagnosis, past several months is a healthcare entity of type Time, cardiac catheterization is a healthcare entity of type ExaminationName, July of this year is a healthcare entity of type Time, total is a healthcare entity of type ConditionQualifier, occlusion is a healthcare entity of type SymptomOrSign, RCA is a healthcare entity of type BodyStructure, 50 is a healthcare entity of type MeasurementValue, % is a healthcare entity of type MeasurementUnit, left main is a healthcare entity of type BodyStructure, disease is a healthcare entity of type Diagnosis, family is a healthcare entity of type FamilyRelation, coronary artery disease is a healthcare entity of type Diagnosis, brother is a healthcare entity of type FamilyRelation, dying is a healthcare entity of type Diagnosis, 52 is a healthcare entity of type Age, myocardial infarction is a healthcare entity of type Diagnosis, brother is a healthcare entity of type FamilyRelation, coronary artery bypass grafting is a healthcare entity of type TreatmentName, stress echocardiogram is a healthcare entity of type ExaminationName, July, 2001 is a healthcare entity of type Time, wall motion abnormalities is a healthcare entity of type SymptomOrSign, body habitus is a healthcare entity of type SymptomOrSign, six minutes is a healthcare entity of type Time, minimal is a healthcare entity of type ConditionQualifier, ST depressions in the anterior lateral leads is a healthcare entity of type SymptomOrSign, fatigue is a healthcare entity of type SymptomOrSign, wrist pain is a healthcare entity of type SymptomOrSign, anginal equivalent is a healthcare entity of type SymptomOrSign, increased is a healthcare entity of type Course, symptoms is a healthcare entity of type SymptomOrSign, family is a healthcare entity of type FamilyRelation, left is a healthcare entity of type Direction, main is a healthcare entity of type BodyStructure, disease is a healthcare entity of type Diagnosis, occasional is a healthcare entity of type Course, RCA is a healthcare entity of type BodyStructure, revascularization is a healthcare entity of type TreatmentName, open heart surgery is a healthcare entity of type TreatmentName\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I know what to respond\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"The text contains the following diagnoses: progressive angina, coronary artery disease, myocardial infarction, and coronary artery bypass grafting.\"\n",
"}\n",
"```\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The text contains the following diagnoses: progressive angina, coronary artery disease, myocardial infarction, and coronary artery bypass grafting.'"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\n",
" \"\"\"The patient is a 54-year-old gentleman with a history of progressive angina over the past several months.\n",
"The patient had a cardiac catheterization in July of this year revealing total occlusion of the RCA and 50% left main disease ,\n",
"with a strong family history of coronary artery disease with a brother dying at the age of 52 from a myocardial infarction and\n",
"another brother who is status post coronary artery bypass grafting. The patient had a stress echocardiogram done on July , 2001 ,\n",
"which showed no wall motion abnormalities , but this was a difficult study due to body habitus. The patient went for six minutes with\n",
"minimal ST depressions in the anterior lateral leads , thought due to fatigue and wrist pain , his anginal equivalent. Due to the patient's\n",
"increased symptoms and family history and history left main disease with total occasional of his RCA was referred for revascularization with open heart surgery.\n",
"\n",
"List all the diagnoses.\n",
"\"\"\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -264,7 +325,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.8.10"
}
},
"nbformat": 4,

View File

@@ -57,11 +57,14 @@
"1. **Load**: First we need to load our data. We'll use [DocumentLoaders](/docs/modules/data_connection/document_loaders/) for this.\n",
"2. **Split**: [Text splitters](/docs/modules/data_connection/document_transformers/) break large `Documents` into smaller chunks. This is useful both for indexing data and for passing it in to a model, since large chunks are harder to search over and won't in a model's finite context window.\n",
"3. **Store**: We need somewhere to store and index our splits, so that they can later be searched over. This is often done using a [VectorStore](/docs/modules/data_connection/vectorstores/) and [Embeddings](/docs/modules/data_connection/text_embedding/) model.\n",
"\n",
"![index_diagram](/img/rag_indexing.png)\n",
"\n",
"#### Retrieval and generation\n",
"4. **Retrieve**: Given a user input, relevant splits are retrieved from storage using a [Retriever](/docs/modules/data_connection/retrievers/).\n",
"5. **Generate**: A [ChatModel](/docs/modules/model_io/chat_models) / [LLM](/docs/modules/model_io/llms/) produces an answer using a prompt that includes the question and the retrieved data\n",
"\n",
"![flow.jpeg](/img/qa_flow.jpeg)"
"![retrieval_diagram](/img/rag_retrieval_generation.png)"
]
},
{

Binary file not shown.

Before

Width:  |  Height:  |  Size: 173 KiB

BIN
docs/static/img/rag_indexing.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

16095
docs/static/svg/langchain_stack.svg vendored Normal file

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 956 KiB

View File

@@ -7,7 +7,7 @@ from langchain_cli.namespaces import app as app_namespace
from langchain_cli.namespaces import template as template_namespace
from langchain_cli.utils.packages import get_langserve_export, get_package_root
__version__ = "0.0.18"
__version__ = "0.0.19"
app = typer.Typer(no_args_is_help=True, add_completion=False)
app.add_typer(

View File

@@ -6,7 +6,7 @@ RUN poetry config virtualenvs.create false
WORKDIR /code
COPY ./pyproject.toml ./poetry.lock* ./
COPY ./pyproject.toml ./README.md ./poetry.lock* ./
COPY ./packages ./packages

View File

@@ -11,7 +11,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.11"
uvicorn = "^0.23.2"
langserve = {extras = ["server"], version = ">=0.0.22"}
langserve = {extras = ["server"], version = ">=0.0.30"}
pydantic = "<2"

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-cli"
version = "0.0.18"
version = "0.0.19"
description = "CLI for interacting with LangChain"
authors = ["Erick Friis <erick@langchain.dev>"]
readme = "README.md"

56
libs/core/Makefile Normal file
View File

@@ -0,0 +1,56 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/experimental --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint lint_diff:
./scripts/check_pydantic.sh .
./scripts/check_imports.sh
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'test_watch - run unit tests in watch mode'

1
libs/core/README.md Normal file
View File

@@ -0,0 +1 @@
# langchain-core

View File

@@ -0,0 +1,7 @@
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""

View File

@@ -0,0 +1,29 @@
"""Helper functions for managing the LangChain API.
This module is only relevant for LangChain developers, not for users.
.. warning::
This module and its submodules are for internal use only. Do not use them
in your own code. We may change the API at any time with no warning.
"""
from .deprecation import (
LangChainDeprecationWarning,
deprecated,
suppress_langchain_deprecation_warning,
surface_langchain_deprecation_warnings,
warn_deprecated,
)
from .path import as_import_path, get_relative_path
__all__ = [
"as_import_path",
"deprecated",
"get_relative_path",
"LangChainDeprecationWarning",
"suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings",
"warn_deprecated",
]

View File

@@ -0,0 +1,341 @@
"""Helper functions for deprecating parts of the LangChain API.
This module was adapted from matplotlibs _api/deprecation.py module:
https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py
.. warning::
This module is for internal use only. Do not use it in your own code.
We may change the API at any time with no warning.
"""
import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar
class LangChainDeprecationWarning(DeprecationWarning):
"""A class for issuing deprecation warnings for LangChain users."""
class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
"""A class for issuing deprecation warnings for LangChain users."""
# PUBLIC API
T = TypeVar("T", Type, Callable)
def deprecated(
since: str,
*,
message: str = "",
name: str = "",
alternative: str = "",
pending: bool = False,
obj_type: str = "",
addendum: str = "",
removal: str = "",
) -> Callable[[T], T]:
"""Decorator to mark a function, a class, or a property as deprecated.
When deprecating a classmethod, a staticmethod, or a property, the
``@deprecated`` decorator should go *under* ``@classmethod`` and
``@staticmethod`` (i.e., `deprecated` should directly decorate the
underlying callable), but *over* ``@property``.
When deprecating a class ``C`` intended to be used as a base class in a
multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method
(if ``C`` instead inherited its ``__init__`` from its own base class, then
``@deprecated`` would mess up ``__init__`` inheritance when installing its
own (deprecation-emitting) ``C.__init__``).
Parameters are the same as for `warn_deprecated`, except that *obj_type*
defaults to 'class' if decorating a class, 'attribute' if decorating a
property, and 'function' otherwise.
Arguments:
since : str
The release at which this API became deprecated.
message : str, optional
Override the default deprecation message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
and %(removal)s format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the deprecated object.
alternative : str, optional
An alternative API that the user may use in place of the
deprecated API. The deprecation warning will tell the user
about this alternative if provided.
pending : bool, optional
If True, uses a PendingDeprecationWarning instead of a
DeprecationWarning. Cannot be used together with removal.
obj_type : str, optional
The object type being deprecated.
addendum : str, optional
Additional text appended directly to the final message.
removal : str, optional
The expected removal version. With the default (an empty
string), a removal version is automatically computed from
since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending.
Examples
--------
.. code-block:: python
@deprecated('1.4.0')
def the_function_to_deprecate():
pass
"""
def deprecate(
obj: T,
*,
_obj_type: str = obj_type,
_name: str = name,
_message: str = message,
_alternative: str = alternative,
_pending: bool = pending,
_addendum: str = addendum,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
if isinstance(obj, type):
if not _obj_type:
_obj_type = "class"
wrapped = obj.__init__ # type: ignore
_name = _name or obj.__name__
old_doc = obj.__doc__
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the deprecation of a class."""
try:
obj.__doc__ = new_doc
except AttributeError: # Can't set on some extension objects.
pass
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
wrapper
)
return obj
elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__name__
old_doc = obj.__doc__
class _deprecated_property(type(obj)): # type: ignore
"""A deprecated property."""
def __get__(self, instance, owner=None): # type: ignore
if instance is not None or owner is not None:
emit_warning()
return super().__get__(instance, owner)
def __set__(self, instance, value): # type: ignore
if instance is not None:
emit_warning()
return super().__set__(instance, value)
def __delete__(self, instance): # type: ignore
if instance is not None:
emit_warning()
return super().__delete__(instance)
def __set_name__(self, owner, set_name): # type: ignore
nonlocal _name
if _name == "<lambda>":
_name = set_name
def finalize(_: Any, new_doc: str) -> Any: # type: ignore
"""Finalize the property."""
return _deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
)
else:
if not _obj_type:
_obj_type = "function"
wrapped = obj
_name = _name or obj.__name__ # type: ignore
old_doc = wrapped.__doc__
def finalize( # type: ignore
wrapper: Callable[..., Any], new_doc: str
) -> T:
"""Wrap the wrapped function using the wrapper and update the docstring.
Args:
wrapper: The wrapper function.
new_doc: The new docstring.
Returns:
The wrapped function.
"""
wrapper = functools.wraps(wrapped)(wrapper)
wrapper.__doc__ = new_doc
return wrapper
def emit_warning() -> None:
"""Emit the warning."""
warn_deprecated(
since,
message=_message,
name=_name,
alternative=_alternative,
pending=_pending,
obj_type=_obj_type,
addendum=_addendum,
removal=removal,
)
def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the original wrapped callable that emits a warning.
Args:
*args: The positional arguments to the function.
**kwargs: The keyword arguments to the function.
Returns:
The return value of the function being wrapped.
"""
emit_warning()
return wrapped(*args, **kwargs)
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
if not old_doc:
new_doc = "[*Deprecated*]"
else:
new_doc = f"[*Deprecated*] {old_doc}"
# Modify the docstring to include a deprecation notice.
notes_header = "\nNotes\n-----"
components = [
message,
f"Use {alternative} instead." if alternative else "",
addendum,
]
details = " ".join([component.strip() for component in components if component])
new_doc += (
f"[*Deprecated*] {old_doc}\n"
f"{notes_header if notes_header not in old_doc else ''}\n"
f".. deprecated:: {since}\n"
f" {details}"
)
return finalize(warning_emitting_wrapper, new_doc)
return deprecate
@contextlib.contextmanager
def suppress_langchain_deprecation_warning() -> Generator[None, None, None]:
"""Context manager to suppress LangChainDeprecationWarning."""
with warnings.catch_warnings():
warnings.simplefilter("ignore", LangChainDeprecationWarning)
warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
yield
def warn_deprecated(
since: str,
*,
message: str = "",
name: str = "",
alternative: str = "",
pending: bool = False,
obj_type: str = "",
addendum: str = "",
removal: str = "",
) -> None:
"""Display a standardized deprecation.
Arguments:
since : str
The release at which this API became deprecated.
message : str, optional
Override the default deprecation message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
and %(removal)s format specifiers will be replaced by the
values of the respective arguments passed to this function.
name : str, optional
The name of the deprecated object.
alternative : str, optional
An alternative API that the user may use in place of the
deprecated API. The deprecation warning will tell the user
about this alternative if provided.
pending : bool, optional
If True, uses a PendingDeprecationWarning instead of a
DeprecationWarning. Cannot be used together with removal.
obj_type : str, optional
The object type being deprecated.
addendum : str, optional
Additional text appended directly to the final message.
removal : str, optional
The expected removal version. With the default (an empty
string), a removal version is automatically computed from
since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending.
"""
if pending and removal:
raise ValueError("A pending deprecation cannot have a scheduled removal")
if not pending:
if not removal:
removal = f"in {removal}" if removal else "within ?? minor releases"
raise NotImplementedError(
f"Need to determine which default deprecation schedule to use. "
f"{removal}"
)
else:
removal = f"in {removal}"
if not message:
message = ""
if obj_type:
message += f"The {obj_type} `{name}`"
else:
message += f"`{name}`"
if pending:
message += " will be deprecated in a future version"
else:
message += f" was deprecated in LangChain {since}"
if removal:
message += f" and will be removed {removal}"
if alternative:
message += f". Use {alternative} instead."
if addendum:
message += f" {addendum}"
warning_cls = (
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
)
warning = warning_cls(message)
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
def surface_langchain_deprecation_warnings() -> None:
"""Unmute LangChain deprecation warnings."""
warnings.filterwarnings(
"default",
category=LangChainPendingDeprecationWarning,
)
warnings.filterwarnings(
"default",
category=LangChainDeprecationWarning,
)

View File

@@ -0,0 +1,36 @@
import os
from pathlib import Path
from typing import Optional, Union
HERE = Path(__file__).parent
# Get directory of langchain package
PACKAGE_DIR = HERE.parent
SEPARATOR = os.sep
def get_relative_path(
file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
) -> str:
"""Get the path of the file as a relative path to the package directory."""
if isinstance(file, str):
file = Path(file)
return str(file.relative_to(relative_to))
def as_import_path(
file: Union[Path, str],
*,
suffix: Optional[str] = None,
relative_to: Path = PACKAGE_DIR,
) -> str:
"""Path of the file as a LangChain import exclude langchain top namespace."""
if isinstance(file, str):
file = Path(file)
path = get_relative_path(file, relative_to=relative_to)
if file.is_file():
path = path[: -len(file.suffix)]
import_path = path.replace(SEPARATOR, ".")
if suffix:
import_path += "." + suffix
return import_path

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from typing import Any, List, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.messages import BaseMessage
class AgentAction(Serializable):
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action.
This log can be used in a few ways. First, it can be used to audit
what exactly the LLM predicted to lead to this (tool, tool_input).
Second, it can be used in future iterations to show the LLMs prior
thoughts. This is useful when (tool, tool_input) does not contain
full information about the LLM prediction (for example, any `thought`
before the tool/tool_input)."""
type: Literal["AgentAction"] = "AgentAction"
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "agent"]
class AgentActionMessageLog(AgentAction):
message_log: Sequence[BaseMessage]
"""Similar to log, this can be used to pass along extra
information about what exact messages were predicted by the LLM
before parsing out the (tool, tool_input). This is again useful
if (tool, tool_input) cannot be used to fully recreate the LLM
prediction, and you need that LLM prediction (for future agent iteration).
Compared to `log`, this is useful when the underlying LLM is a
ChatModel (and therefore returns messages rather than a string)."""
# Ignoring type because we're overriding the type from AgentAction.
# And this is the correct thing to do in this case.
# The type literal is used for serialization purposes.
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "agent"]
class AgentFinish(Serializable):
"""The final return value of an ActionAgent."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value.
This is used to pass along the full LLM prediction, not just the parsed out
return value. For example, if the full LLM prediction was
`Final Answer: 2` you may want to just return `2` as a return value, but pass
along the full string as a `log` (for debugging or observability purposes).
"""
type: Literal["AgentFinish"] = "AgentFinish"
def __init__(self, return_values: dict, log: str, **kwargs: Any):
"""Override init to support instantiation by position for backward compat."""
super().__init__(return_values=return_values, log=log, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "agent"]

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain_core.outputs import Generation
RETURN_VAL_TYPE = Sequence[Generation]
class BaseCache(ABC):
"""Base interface for cache."""
@abstractmethod
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@abstractmethod
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""

View File

@@ -0,0 +1,65 @@
from langchain_core.callbacks.base import (
AsyncCallbackHandler,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManagerMixin,
Callbacks,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainGroup,
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForLLMRun,
AsyncCallbackManagerForRetrieverRun,
AsyncCallbackManagerForToolRun,
AsyncParentRunManager,
AsyncRunManager,
BaseRunManager,
CallbackManager,
CallbackManagerForChainGroup,
CallbackManagerForChainRun,
CallbackManagerForLLMRun,
CallbackManagerForRetrieverRun,
CallbackManagerForToolRun,
ParentRunManager,
RunManager,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
__all__ = [
"RetrieverManagerMixin",
"LLMManagerMixin",
"ChainManagerMixin",
"ToolManagerMixin",
"Callbacks",
"CallbackManagerMixin",
"RunManagerMixin",
"BaseCallbackHandler",
"AsyncCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"RunManager",
"ParentRunManager",
"AsyncRunManager",
"AsyncParentRunManager",
"CallbackManagerForLLMRun",
"AsyncCallbackManagerForLLMRun",
"CallbackManagerForChainRun",
"AsyncCallbackManagerForChainRun",
"CallbackManagerForToolRun",
"AsyncCallbackManagerForToolRun",
"CallbackManagerForRetrieverRun",
"AsyncCallbackManagerForRetrieverRun",
"CallbackManager",
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
]

View File

@@ -0,0 +1,599 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID
from tenacity import RetryCallState
if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
class LLMManagerMixin:
"""Mixin for LLM callbacks."""
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
"""
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM errors."""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
class ToolManagerMixin:
"""Mixin for tool callbacks."""
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
class CallbackManagerMixin:
"""Mixin for callback manager."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class RunManagerMixin:
"""Mixin for run manager."""
def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
"""Base callback handler that handles callbacks from LangChain."""
raise_error: bool = False
run_inline: bool = False
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return False
@property
def ignore_retry(self) -> bool:
"""Whether to ignore retry callbacks."""
return False
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return False
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return False
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return False
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that handles callbacks from LangChain."""
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
async def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
async def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain."""
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,102 @@
"""Callback Handler that prints to std out."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.utils import print_text
if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color or self.color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text(text, color=color or self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color or self.color, end="\n")

View File

@@ -0,0 +1,72 @@
"""Callback Handler streams to stdout on new llm token."""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, Dict, List
from langchain_core.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""

View File

@@ -0,0 +1,13 @@
from typing import Sequence, TypedDict
from langchain_core.messages import BaseMessage
class ChatSession(TypedDict, total=False):
"""Chat Session represents a single
conversation, channel, or other group of messages."""
messages: Sequence[BaseMessage]
"""The LangChain chat messages loaded from the source."""
functions: Sequence[dict]
"""The function calling specs for the messages."""

View File

@@ -0,0 +1,4 @@
from langchain_core.documents.base import Document
from langchain_core.documents.transformers import BaseDocumentTransformer
__all__ = ["Document", "BaseDocumentTransformer"]

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import List, Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
type: Literal["Document"] = "Document"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "document"]

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from langchain_core.documents import Document
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.
A document transformation system takes a sequence of Documents and returns a
sequence of transformed Documents.
Example:
.. code-block:: python
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95
class Config:
arbitrary_types_allowed = True
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings, stateful_documents
)
included_idxs = _filter_similar_embeddings(
embedded_documents, self.similarity_fn, self.similarity_threshold
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)

View File

@@ -0,0 +1,27 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)

View File

@@ -0,0 +1,17 @@
import platform
from functools import lru_cache
@lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
"""Get information about the LangChain runtime environment."""
# Lazy import to avoid circular imports
from langchain_core import __version__
return {
"library_version": __version__,
"library": "langchain",
"platform": platform.platform(),
"runtime": "python",
"runtime_version": platform.python_version(),
}

View File

@@ -0,0 +1,18 @@
"""Logic for selecting examples to include in prompts."""
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.example_selectors.length_based import (
LengthBasedExampleSelector,
)
from langchain_core.example_selectors.semantic_similarity import (
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
sorted_values,
)
__all__ = [
"BaseExampleSelector",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"SemanticSimilarityExampleSelector",
"sorted_values",
]

View File

@@ -0,0 +1,15 @@
"""Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
class BaseExampleSelector(ABC):
"""Interface for selecting examples to include in prompts."""
@abstractmethod
def add_example(self, example: Dict[str, str]) -> Any:
"""Add new example to store for a key."""
@abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the inputs."""

View File

@@ -0,0 +1,63 @@
"""Select examples based on length."""
import re
from typing import Callable, Dict, List
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator
def _get_length_based(text: str) -> int:
return len(re.split("\n| ", text))
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
"""Select examples based on length."""
examples: List[dict]
"""A list of the examples that the prompt template expects."""
example_prompt: PromptTemplate
"""Prompt template used to format the examples."""
get_text_length: Callable[[str], int] = _get_length_based
"""Function to measure prompt length. Defaults to word count."""
max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut."""
example_text_lengths: List[int] = [] #: :meta private:
def add_example(self, example: Dict[str, str]) -> None:
"""Add new example to list."""
self.examples.append(example)
string_example = self.example_prompt.format(**example)
self.example_text_lengths.append(self.get_text_length(string_example))
@validator("example_text_lengths", always=True)
def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
"""Calculate text lengths if they don't exist."""
# Check if text lengths were passed in
if v:
return v
# If they were not, calculate them
example_prompt = values["example_prompt"]
get_text_length = values["get_text_length"]
string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
return [get_text_length(eg) for eg in string_examples]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the input lengths."""
inputs = " ".join(input_variables.values())
remaining_length = self.max_length - self.get_text_length(inputs)
i = 0
examples = []
while remaining_length > 0 and i < len(self.examples):
new_length = remaining_length - self.example_text_lengths[i]
if new_length < 0:
break
else:
examples.append(self.examples[i])
remaining_length = new_length
i += 1
return examples

View File

@@ -0,0 +1,167 @@
"""Example selector that selects examples based on SemanticSimilarity."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings
def sorted_values(values: Dict[str, str]) -> List[Any]:
"""Return a list of values in dict sorted by key."""
return [values[val] for val in sorted(values)]
class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
"""Example selector that selects examples based on SemanticSimilarity."""
vectorstore: VectorStore
"""VectorStore than contains information about examples."""
k: int = 4
"""Number of examples to select."""
example_keys: Optional[List[str]] = None
"""Optional keys to filter examples to."""
input_keys: Optional[List[str]] = None
"""Optional keys to filter input to. If provided, the search is based on
the input variables instead of all variables."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def add_example(self, example: Dict[str, str]) -> str:
"""Add new example to vectorstore."""
if self.input_keys:
string_example = " ".join(
sorted_values({key: example[key] for key in self.input_keys})
)
else:
string_example = " ".join(sorted_values(example))
ids = self.vectorstore.add_texts([string_example], metadatas=[example])
return ids[0]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.similarity_search(query, k=self.k)
# Get the examples from the metadata.
# This assumes that examples are stored in metadata.
examples = [dict(e.metadata) for e in example_docs]
# If example keys are provided, filter examples to those keys.
if self.example_keys:
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
return examples
@classmethod
def from_examples(
cls,
examples: List[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings.
Reshuffles examples dynamically based on query similarity.
Args:
examples: List of examples to use in the prompt.
embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
"""ExampleSelector that selects examples based on Max Marginal Relevance.
This was shown to improve performance in this paper:
https://arxiv.org/pdf/2211.13892.pdf
"""
fetch_k: int = 20
"""Number of examples to fetch to rerank."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
# Get the docs with the highest similarity.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
query = " ".join(sorted_values(input_variables))
example_docs = self.vectorstore.max_marginal_relevance_search(
query, k=self.k, fetch_k=self.fetch_k
)
# Get the examples from the metadata.
# This assumes that examples are stored in metadata.
examples = [dict(e.metadata) for e in example_docs]
# If example keys are provided, filter examples to those keys.
if self.example_keys:
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
return examples
@classmethod
def from_examples(
cls,
examples: List[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
fetch_k: int = 20,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:
"""Create k-shot example selector using example list and embeddings.
Reshuffles examples dynamically based on query similarity.
Args:
examples: List of examples to use in the prompt.
embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
vectorstore_cls: A vector store DB interface class, e.g. FAISS.
k: Number of examples to select
input_keys: If provided, the search is based on the input variables
instead of all variables.
vectorstore_cls_kwargs: optional kwargs containing url for vector store
Returns:
The ExampleSelector instantiated, backed by a vector store.
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
for eg in examples
]
else:
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)

View File

@@ -0,0 +1,48 @@
from typing import Any, Optional
class LangChainException(Exception):
"""General LangChain exception."""
class TracerException(LangChainException):
"""Base class for exceptions in tracers module."""
class OutputParserException(ValueError, LangChainException):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output which is error-ing.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm

View File

@@ -0,0 +1,197 @@
# flake8: noqa
"""Global values and configuration that apply to all of LangChain."""
import warnings
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from langchain_core.caches import BaseCache
# DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below,
# or else your code may behave unexpectedly with other uses of these global settings:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
_verbose: bool = False
_debug: bool = False
_llm_cache: Optional["BaseCache"] = None
def set_verbose(value: bool) -> None:
"""Set a new value for the `verbose` global setting."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing verbose from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.verbose` is no longer supported, and once all users
# have migrated to using `set_verbose()` here.
langchain.verbose = value
except ImportError:
pass
global _verbose
_verbose = value
def get_verbose() -> bool:
"""Get the value of the `verbose` global setting."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing verbose from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.verbose` is no longer supported, and once all users
# have migrated to using `set_verbose()` here.
#
# In the meantime, the `verbose` setting is considered True if either the old
# or the new value are True. This accommodates users who haven't migrated
# to using `set_verbose()` yet. Those users are getting deprecation warnings
# directing them to use `set_verbose()` when they import `langhchain.verbose`.
old_verbose = langchain.verbose
except ImportError:
old_verbose = False
global _verbose
return _verbose or old_verbose
def set_debug(value: bool) -> None:
"""Set a new value for the `debug` global setting."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Importing debug from langchain_core root module is no longer supported",
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.debug` is no longer supported, and once all users
# have migrated to using `set_debug()` here.
langchain.debug = value
except ImportError:
pass
global _debug
_debug = value
def get_debug() -> bool:
"""Get the value of the `debug` global setting."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Importing debug from langchain_core root module is no longer supported",
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.debug` is no longer supported, and once all users
# have migrated to using `set_debug()` here.
#
# In the meantime, the `debug` setting is considered True if either the old
# or the new value are True. This accommodates users who haven't migrated
# to using `set_debug()` yet. Those users are getting deprecation warnings
# directing them to use `set_debug()` when they import `langhchain.debug`.
old_debug = langchain.debug
except ImportError:
old_debug = False
global _debug
return _debug or old_debug
def set_llm_cache(value: Optional["BaseCache"]) -> None:
"""Set a new LLM cache, overwriting the previous value, if any."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing llm_cache from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.llm_cache` is no longer supported, and
# once all users have migrated to using `set_llm_cache()` here.
langchain.llm_cache = value
except ImportError:
pass
global _llm_cache
_llm_cache = value
def get_llm_cache() -> "BaseCache":
"""Get the value of the `llm_cache` global setting."""
try:
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=(
"Importing llm_cache from langchain_core root module is no longer supported"
),
)
# N.B.: This is a workaround for an unfortunate quirk of Python's
# module-level `__getattr__()` implementation:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
#
# Remove it once `langchain.llm_cache` is no longer supported, and
# once all users have migrated to using `set_llm_cache()` here.
#
# In the meantime, the `llm_cache` setting returns whichever of
# its two backing sources is truthy (not `None` and non-empty),
# or the old value if both are falsy. This accommodates users
# who haven't migrated to using `set_llm_cache()` yet.
# Those users are getting deprecation warnings directing them
# to use `set_llm_cache()` when they import `langhchain.llm_cache`.
old_llm_cache = langchain.llm_cache
except ImportError:
old_llm_cache = None
global _llm_cache
return _llm_cache or old_llm_cache

View File

@@ -0,0 +1,17 @@
from langchain_core.language_models.base import (
BaseLanguageModel,
LanguageModelInput,
get_tokenizer,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.llms import LLM, BaseLLM
__all__ = [
"BaseLanguageModel",
"BaseChatModel",
"SimpleChatModel",
"BaseLLM",
"LLM",
"LanguageModelInput",
"get_tokenizer",
]

View File

@@ -0,0 +1,293 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
from typing_extensions import TypeAlias
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
try:
from transformers import GPT2TokenizerFast # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`."
)
# create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2")
def _get_token_ids_default_method(text: str) -> List[int]:
"""Encode the text into token IDs."""
# get the cached tokenizer
tokenizer = get_tokenizer()
# tokenize the text using the GPT-2 tokenizer
return tokenizer.encode(text)
LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel.
Exposes three main methods:
- generate_prompt: generate language model outputs for a sequence of prompt
values. A prompt value is a model input that can be converted to any language
model input format (string or messages).
- predict: pass in a single string to a language model and return a string
prediction.
- predict_messages: pass in a sequence of BaseMessages (corresponding to a single
model call) to a language model and return a BaseMessage prediction.
Each of these has an equivalent asynchronous method.
"""
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to the model and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you want to:
1. take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you want to:
1. take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Pass a single string input to the model and return a string prediction.
Use this method when passing in raw text. If you want to pass in specific
types of chat messages, use predict_messages.
Args:
text: String input to pass to the model.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a string.
"""
@abstractmethod
def predict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Pass a message sequence to the model and return a message prediction.
Use this method when passing in chat messages. If you want to pass in raw text,
use predict.
Args:
messages: A sequence of chat messages corresponding to a single model input.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a message.
"""
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Asynchronously pass a string to the model and return a string prediction.
Use this method when calling pure text generation models and only the top
candidate generation is needed.
Args:
text: String input to pass to the model.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a string.
"""
@abstractmethod
async def apredict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Asynchronously pass messages to the model and return a message prediction.
Use this method when calling chat models and only the top
candidate generation is needed.
Args:
messages: A sequence of chat messages corresponding to a single model input.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
Top model prediction as a message.
"""
def get_token_ids(self, text: str) -> List[int]:
"""Return the ordered ids of the tokens in a text.
Args:
text: The string input to tokenize.
Returns:
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
"""
return _get_token_ids_default_method(text)
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def _all_required_field_names(cls) -> Set:
"""DEPRECATED: Kept for backwards compatibility.
Use get_pydantic_field_names.
"""
return get_pydantic_field_names(cls)

View File

@@ -0,0 +1,738 @@
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
BaseCallbackManager,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
async def _agenerate_from_stream(
stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
async for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Callback manager to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
# --- Runnable methods ---
@property
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
return cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except BaseException as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
params["stop"] = stop
return {**params, **kwargs}
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = []
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)
except BaseException as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = await callback_manager.on_chat_model_start(
dumpd(self),
messages,
invocation_params=params,
options=options,
name=run_name,
)
results = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
for i, m in enumerate(messages)
],
return_exceptions=True,
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return await self.agenerate(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
def __call__(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
generation = self.generate(
[messages], stop=stop, callbacks=callbacks, **kwargs
).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
async def _call_async(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
result = await self.agenerate(
[messages], stop=stop, callbacks=callbacks, **kwargs
)
generation = result.generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")
def predict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(messages, stop=_stop, **kwargs)
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = await self._call_async(
[HumanMessage(content=text)], stop=_stop, **kwargs
)
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")
async def apredict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(messages, stop=_stop, **kwargs)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {}
@property
@abstractmethod
def _llm_type(self) -> str:
"""Return type of chat model."""
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict
class SimpleChatModel(BaseChatModel):
"""Simple Chat Model."""
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Simpler interface."""
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""Serialization and deserialization."""
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.load import load, loads
from langchain_core.load.serializable import Serializable
__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"]

View File

@@ -0,0 +1,26 @@
import json
from typing import Any, Dict
from langchain_core.load.serializable import Serializable, to_json_not_implemented
def default(obj: Any) -> Any:
"""Return a default value for a Serializable object or
a SerializedNotImplemented object."""
if isinstance(obj, Serializable):
return obj.to_json()
else:
return to_json_not_implemented(obj)
def dumps(obj: Any, *, pretty: bool = False) -> str:
"""Return a json string representation of an object."""
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
def dumpd(obj: Any) -> Dict[str, Any]:
"""Return a json dict representation of an object."""
return json.loads(dumps(obj))

View File

@@ -0,0 +1,130 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.load.serializable import Serializable
DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
class Reviver:
"""Reviver for JSON objects."""
def __init__(
self,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> None:
self.secrets_map = secrets_map or dict()
# By default only support langchain, but user can pass in additional namespaces
self.valid_namespaces = (
[*DEFAULT_NAMESPACES, *valid_namespaces]
if valid_namespaces
else DEFAULT_NAMESPACES
)
def __call__(self, value: Dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
and value.get("id", None) is not None
):
[key] = value["id"]
if key in self.secrets_map:
return self.secrets_map[key]
else:
if key in os.environ and os.environ[key]:
return os.environ[key]
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
if (
value.get("lc", None) == 1
and value.get("type", None) == "not_implemented"
and value.get("id", None) is not None
):
raise NotImplementedError(
"Trying to load an object that doesn't implement "
f"serialization: {value}"
)
if (
value.get("lc", None) == 1
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
[*namespace, name] = value["id"]
if namespace[0] not in self.valid_namespaces:
raise ValueError(f"Invalid namespace: {value}")
# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
raise ValueError(f"Invalid namespace: {value}")
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", dict())
return cls(**kwargs)
return value
def loads(
text: str,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
Args:
text: The string to load.
secrets_map: A map of secrets to load.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
Returns:
Revived LangChain objects.
"""
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
def load(
obj: Any,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
Args:
obj: The object to load.
secrets_map: A map of secrets to load.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
Returns:
Revived LangChain objects.
"""
reviver = Reviver(secrets_map, valid_namespaces)
def _load(obj: Any) -> Any:
if isinstance(obj, dict):
# Need to revive leaf nodes before reviving this node
loaded_obj = {k: _load(v) for k, v in obj.items()}
return reviver(loaded_obj)
if isinstance(obj, list):
return [_load(o) for o in obj]
return obj
return _load(obj)

View File

@@ -0,0 +1,207 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
class BaseSerialized(TypedDict):
"""Base class for serialized objects."""
lc: int
id: List[str]
class SerializedConstructor(BaseSerialized):
"""Serialized constructor."""
type: Literal["constructor"]
kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
"""Serialized secret."""
type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
"""Serialized not implemented."""
type: Literal["not_implemented"]
repr: Optional[str]
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
try:
return model.__fields__[key].get_default() != value
except Exception:
return True
class Serializable(BaseModel, ABC):
"""Serializable base class."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Is this class serializable?"""
return False
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the
namespace is ["langchain", "llms", "openai"]
"""
return cls.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
"""A map of constructor argument names to secret ids.
For example,
{"openai_api_key": "OPENAI_API_KEY"}
"""
return dict()
@property
def lc_attributes(self) -> Dict:
"""List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor.
"""
return {}
@classmethod
def lc_id(cls) -> List[str]:
"""A unique identifier for this class for serialization purposes.
The unique identifier is a list of strings that describes the path
to the object.
"""
return [*cls.get_lc_namespace(), cls.__name__]
class Config:
extra = "ignore"
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or try_neq_default(v, k, self))
]
_lc_kwargs = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.is_lc_serializable():
return self.to_json_not_implemented()
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self._lc_kwargs.items()
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
}
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
# Once we get to Serializable, we're done
if cls is Serializable:
break
if cls:
deprecated_attributes = [
"lc_namespace",
"lc_serializable",
]
for attr in deprecated_attributes:
if hasattr(cls, attr):
raise ValueError(
f"Class {self.__class__} has a deprecated "
f"attribute {attr}. Please use the corresponding "
f"classmethod instead."
)
# Get a reference to self bound to each class in the MRO
this = cast(Serializable, self if cls is None else super(cls, self))
secrets.update(this.lc_secrets)
lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
# as these secrets may be passed as an environment variable instead
for key in secrets.keys():
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
if secret_value is not None:
lc_kwargs.update({key: secret_value})
return {
"lc": 1,
"type": "constructor",
"id": self.lc_id(),
"kwargs": lc_kwargs
if not secrets
else _replace_secrets(lc_kwargs, secrets),
}
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
current = result
for part in parts:
if part not in current:
break
current[part] = current[part].copy()
current = current[part]
if last in current:
current[last] = {
"lc": 1,
"type": "secret",
"id": [secret_id],
}
return result
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
"""Serialize a "not implemented" object.
Args:
obj: object to serialize
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]
elif hasattr(obj, "__class__"):
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
except Exception:
pass
result: SerializedNotImplemented = {
"lc": 1,
"type": "not_implemented",
"id": _id,
"repr": None,
}
try:
result["repr"] = repr(obj)
except Exception:
pass
return result

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain_core.load.serializable import Serializable
class BaseMemory(Serializable, ABC):
"""Abstract base class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""

View File

@@ -0,0 +1,122 @@
from typing import List, Sequence, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
message_to_dict,
messages_to_dict,
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: THe prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ToolMessage):
role = "Tool"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
__all__ = [
"AIMessage",
"AIMessageChunk",
"AnyMessage",
"BaseMessage",
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"SystemMessage",
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"get_buffer_string",
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"merge_content",
]

View File

@@ -0,0 +1,47 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["ai"] = "ai"
AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)

View File

@@ -0,0 +1,131 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: Union[str, List[Union[str, Dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
type: str
class Config:
extra = Extra.allow
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "messages"]
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
# If first chunk is a string
if isinstance(first_content, str):
# If the second chunk is also a string, then merge them naively
if isinstance(second_content, str):
return first_content + second_content
# If the second chunk is a list, add the first chunk to the start of the list
else:
return_list: List[Union[str, Dict]] = [first_content]
return return_list + second_content
# If both are lists, merge them naively
elif isinstance(second_content, List):
return first_content + second_content
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(first_content[-1], str):
return first_content[:-1] + [first_content[-1] + second_content]
else:
# Otherwise, add the second content as a new element of the list
return first_content + [second_content]
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
def message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [message_to_dict(m) for m in messages]

View File

@@ -0,0 +1,53 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
return super().__add__(other)

View File

@@ -0,0 +1,45 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
type: Literal["function"] = "function"
FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)

View File

@@ -0,0 +1,26 @@
from typing import Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
type: Literal["human"] = "human"
HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501

View File

@@ -0,0 +1,23 @@
from typing import Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
type: Literal["system"] = "system"
SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501

View File

@@ -0,0 +1,45 @@
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
class ToolMessage(BaseMessage):
"""A Message for passing the result of executing a tool back to a model."""
tool_call_id: str
"""Tool call that this message is responding to."""
type: Literal["tool"] = "tool"
ToolMessage.update_forward_refs()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
"""A Tool Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id:
raise ValueError(
"Cannot concatenate ToolMessageChunks with different names."
)
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)

View File

@@ -0,0 +1,29 @@
from langchain_core.output_parsers.base import (
BaseGenerationOutputParser,
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.transform import (
BaseCumulativeTransformOutputParser,
BaseTransformOutputParser,
)
__all__ = [
"BaseLLMOutputParser",
"BaseGenerationOutputParser",
"BaseOutputParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"NumberedListOutputParser",
"MarkdownListOutputParser",
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
]

View File

@@ -0,0 +1,301 @@
from __future__ import annotations
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
from typing_extensions import get_args
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import RunnableConfig, RunnableSerializable
if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue
T = TypeVar("T")
class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
class BaseGenerationOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
async def ainvoke(
self,
input: str | BaseMessage,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.
Example:
.. code-block:: python
class BooleanOutputParser(BaseOutputParser[bool]):
true_val: str = "YES"
false_val: str = "NO"
def parse(self, text: str) -> bool:
cleaned_text = text.strip().upper()
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
raise OutputParserException(
f"BooleanOutputParser expected output value to either be "
f"{self.true_val} or {self.false_val} (case-insensitive). "
f"Received {cleaned_text}."
)
return cleaned_text == self.true_val.upper()
@property
def _type(self) -> str:
return "boolean_output_parser"
""" # noqa: E501
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
lambda inner_input: self.parse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
async def ainvoke(
self,
input: str | BaseMessage,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(
[ChatGeneration(message=inner_input)]
),
input,
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
@abstractmethod
def parse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of a language model.
Returns:
Structured output.
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, functools.partial(self.parse_result, partial=partial), result
)
async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of a language model.
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: String output of a language model.
prompt: Input PromptValue.
Returns:
Structured output
"""
return self.parse(completion)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
try:
output_parser_dict["_type"] = self._type
except NotImplementedError:
pass
return output_parser_dict

View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import re
from abc import abstractmethod
from typing import List
from langchain_core.output_parsers.base import BaseOutputParser
class ListOutputParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a list."""
@property
def _type(self) -> str:
return "list"
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "output_parsers", "list"]
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
@property
def _type(self) -> str:
return "comma-separated-list"
class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list."""
def get_format_instructions(self) -> str:
return (
"Your response should be a numbered list with each item on a new line. "
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"\d+\.\s([^\n]+)"
# Extract the text of each item
matches = re.findall(pattern, text)
return matches
@property
def _type(self) -> str:
return "numbered-list"
class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list."""
def get_format_instructions(self) -> str:
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"-\s([^\n]+)"
return re.findall(pattern, text)
@property
def _type(self) -> str:
return "markdown-list"

View File

@@ -0,0 +1,26 @@
from typing import List
from langchain_core.output_parsers.transform import BaseTransformOutputParser
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "output_parser"]

View File

@@ -0,0 +1,131 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
Optional,
Union,
)
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
"""In streaming mode, whether to yield diffs between the previous and current
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed

View File

@@ -0,0 +1,15 @@
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.chat_result import ChatResult
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.outputs.run_info import RunInfo
__all__ = [
"ChatGeneration",
"ChatGenerationChunk",
"ChatResult",
"Generation",
"GenerationChunk",
"LLMResult",
"RunInfo",
]

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)

View File

@@ -0,0 +1,15 @@
from typing import List, Optional
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from langchain_core.load import Serializable
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "schema", "output"]
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from uuid import UUID
from langchain_core.pydantic_v1 import BaseModel
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "prompts", "base"]
class ChatPromptValue(PromptValue):
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: Sequence[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return list(self.messages)
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain", "prompts", "chat"]
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"

View File

@@ -0,0 +1,74 @@
"""**Prompt** is the input to the model.
Prompt is often constructed
from multiple components. Prompt classes and functions make constructing
and working with prompts easy.
**Class hierarchy:**
.. code-block::
BasePromptTemplate --> PipelinePromptTemplate
StringPromptTemplate --> PromptTemplate
FewShotPromptTemplate
FewShotPromptWithTemplates
BaseChatPromptTemplate --> AutoGPTPrompt
ChatPromptTemplate --> AgentScratchPadChatPromptTemplate
BaseMessagePromptTemplate --> MessagesPlaceholder
BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate
HumanMessagePromptTemplate
AIMessagePromptTemplate
SystemMessagePromptTemplate
""" # noqa: E501
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
StringPromptTemplate,
check_valid_template,
get_template_variables,
jinja2_formatter,
validate_jinja2,
)
__all__ = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"PromptTemplate",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"format_document",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",
"validate_jinja2",
]

View File

@@ -0,0 +1,246 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Type,
Union,
)
import yaml
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
if TYPE_CHECKING:
from langchain_core.documents import Document
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
input_types: Dict[str, Any] = Field(default_factory=dict)
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
return Union[StringPromptValue, ChatPromptValueConcrete]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(
**{key: inner_input[key] for key in self.input_variables}
),
input,
config,
run_type="prompt",
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain"] + cls.__module__.split(".")[1:]
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
1. `page_content`:
This takes the information from the `document.page_content`
and assigns it to a variable named `page_content`.
2. metadata:
This takes information from `document.metadata` and assigns
it to variables of the same name.
Those variables are then passed into the `prompt` to produce a formatted string.
Args:
doc: Document, the page_content and metadata will be used to create
the final string.
prompt: BasePromptTemplate, will be used to format the page_content
and metadata into the final string.
Returns:
string of the document formatted.
Example:
.. code-block:: python
from langchain_core import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
format_document(doc, prompt)
>>> "Page 1: This is a joke"
"""
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)

View File

@@ -0,0 +1,733 @@
"""Chat prompt template."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
@property
@abstractmethod
def input_variables(self) -> List[str]:
"""Input variables for this prompt template.
Returns:
List of input variables.
"""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
@classmethod
def get_lc_namespace(cls) -> List[str]:
# For backwards compatibility.
return ["langchain"] + cls.__module__.split(".")[1:]
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""
variable_name: str
"""Name of variable to use as messages."""
def __init__(self, variable_name: str, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessage.
"""
value = kwargs[self.variable_name]
if not isinstance(value, list):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in value:
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"
f" got {value}"
)
return value
@property
def input_variables(self) -> List[str]:
"""Input variables for this prompt template.
Returns:
List of input variable names.
"""
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
"""Type variable for message prompt templates."""
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""Base class for message prompt templates that use a string prompt template."""
prompt: StringPromptTemplate
"""String prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
`"{variable1} {variable2}"`, and `partial_variables` is
`{"variable1": "foo"}`, then the final prompt will be
`"foo {variable2}"`.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_template(
template,
template_format=template_format,
partial_variables=partial_variables,
)
return cls(prompt=prompt, **kwargs)
@classmethod
def from_template_file(
cls: Type[MessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_file(template_file, input_variables)
return cls(prompt=prompt, **kwargs)
@abstractmethod
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
return self.prompt.input_variables
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
role: str
"""Role of the message."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return ChatMessage(
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
@property
def lc_attributes(self) -> Dict:
"""
Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the
constructor.
"""
return {"input_variables": self.input_variables}
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
PromptValue.
"""
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
str,
]
class ChatPromptTemplate(BaseChatPromptTemplate):
"""A prompt template for chat models.
Use to create flexible templated prompts for chat models.
Examples:
.. code-block:: python
from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI bot. Your name is {name}."),
("human", "Hello, how are you doing?"),
("ai", "I'm doing well, thanks!"),
("human", "{user_input}"),
])
messages = template.format_messages(
name="Bob",
user_input="What is your name?"
)
"""
input_variables: List[str]
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
elif isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
"""Validate input variables.
If input_variables is not set, it will be set to the union of
all input variables in the messages.
Args:
values: values to validate.
Returns:
Validated values.
"""
messages = values["messages"]
input_vars = set()
input_types: Dict[str, Any] = values.get("input_types", {})
for message in messages:
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
input_vars.update(message.input_variables)
if isinstance(message, MessagesPlaceholder):
if message.variable_name not in input_types:
input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "
f"Expected: {input_vars}. "
f"Got: {values['input_variables']}"
)
else:
values["input_variables"] = sorted(input_vars)
values["input_types"] = input_types
return values
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
"""Create a chat prompt template from a template string.
Creates a chat template consisting of a single message assumed to be from
the human.
Args:
template: template string
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message])
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
a chat prompt template
"""
return cls(
messages=[
ChatMessagePromptTemplate.from_template(template, role=role)
for role, template in string_messages
]
)
@classmethod
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role class, template) tuples.
Args:
string_messages: list of (role class, template) tuples.
Returns:
a chat prompt template
"""
return cls.from_messages(string_messages)
@classmethod
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
Examples:
Instantiation from a list of message templates:
.. code-block:: python
template = ChatPromptTemplate.from_messages([
("human", "Hello, how are you?"),
("ai", "I'm doing well, thanks!"),
("human", "That's good to hear."),
])
Instantiation from mixed message formats:
.. code-block:: python
template = ChatPromptTemplate.from_messages([
SystemMessage(content="hello"),
("human", "Hello, how are you?"),
])
Args:
messages: sequence of message representations.
A message can be represented using the following formats:
(1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
(message type, template); e.g., ("human", "{user_input}"),
(4) 2-tuple of (message class, template), (4) a string which is
shorthand for ("human", template); e.g., "{user_input}"
Returns:
a chat prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
for _message in _messages:
if isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a list of finalized messages.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
list of formatted messages
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
"""Get a new ChatPromptTemplate with some input variables already filled in.
Args:
**kwargs: keyword arguments to use for filling in template variables. Ought
to be a subset of the input variables.
Returns:
A new ChatPromptTemplate.
Example:
.. code-block:: python
from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
("human", "Hi I'm {user}"),
("ai", "Hi there, {user}, I'm {name}."),
("human", "{input}"),
]
)
template2 = template.partial(user="Lucy", name="R2D2")
template2.format_messages(input="hello")
"""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def append(self, message: MessageLikeRepresentation) -> None:
"""Append message to the end of the chat template.
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
"""Extend the chat template with a sequence of messages."""
self.messages.extend([_convert_to_message(message) for message in messages])
@overload
def __getitem__(self, index: int) -> MessageLike:
...
@overload
def __getitem__(self, index: slice) -> ChatPromptTemplate:
...
def __getitem__(
self, index: Union[int, slice]
) -> Union[MessageLike, ChatPromptTemplate]:
"""Use to index into the chat template."""
if isinstance(index, slice):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return ChatPromptTemplate.from_messages(messages)
else:
return self.messages[index]
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
@property
def _prompt_type(self) -> str:
"""Name of prompt type."""
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
"""Save prompt to file.
Args:
file_path: path to file.
"""
raise NotImplementedError()
def _create_template_from_message_type(
message_type: str, template: str
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
_message: Union[
BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
] = message
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
_message = message_type_str(prompt=PromptTemplate.from_template(template))
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message

View File

@@ -0,0 +1,342 @@
"""Prompt template that contains few shot examples."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
class _FewShotPromptTemplateMixin(BaseModel):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Any = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
if examples and example_selector:
raise ValueError(
"Only one of 'examples' and 'example_selector' should be provided"
)
if examples is None and example_selector is None:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
return values
def _get_examples(self, **kwargs: Any) -> List[dict]:
"""Get the examples to use for formatting the prompt.
Args:
**kwargs: Keyword arguments to be passed to the example selector.
Returns:
List of examples.
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
validate_template: bool = False
"""Whether or not to try validating the template."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
suffix: str
"""A prompt template string to put after the examples."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
**kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
# Get the examples to use.
examples = self._get_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
example_strings = [
self.example_prompt.format(**example) for example in examples
]
# Create the overall template.
pieces = [self.prefix, *example_strings, self.suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "few_shot"
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().save(file_path)
class FewShotChatMessagePromptTemplate(
BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
"""Chat prompt template that supports few-shot examples.
The high level structure of produced by this prompt template is a list of messages
consisting of prefix message(s), example message(s), and suffix message(s).
This structure enables creating a conversation with intermediate examples like:
System: You are a helpful AI Assistant
Human: What is 2+2?
AI: 4
Human: What is 2+3?
AI: 5
Human: What is 4+4?
This prompt template can be used to generate a fixed list of examples or else
to dynamically select examples based on the input.
Examples:
Prompt template with a fixed list of examples (matching the sample
conversation above):
.. code-block:: python
from langchain_core.prompts import (
FewShotChatMessagePromptTemplate,
ChatPromptTemplate
)
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_prompt = ChatPromptTemplate.from_messages(
[('human', '{input}'), ('ai', '{output}')]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
examples=examples,
# This is a prompt template used to format each individual example.
example_prompt=example_prompt,
)
final_prompt = ChatPromptTemplate.from_messages(
[
('system', 'You are a helpful AI Assistant'),
few_shot_prompt,
('human', '{input}'),
]
)
final_prompt.format(input="What is 4+4?")
Prompt template with dynamically selected examples:
.. code-block:: python
from langchain_core.prompts import SemanticSimilarityExampleSelector
from langchain_core.embeddings import OpenAIEmbeddings
from langchain_core.vectorstores import Chroma
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
{"input": "2+4", "output": "6"},
# ...
]
to_vectorize = [
" ".join(example.values())
for example in examples
]
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(
to_vectorize, embeddings, metadatas=examples
)
example_selector = SemanticSimilarityExampleSelector(
vectorstore=vectorstore
)
from langchain_core import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
few_shot_prompt = FewShotChatMessagePromptTemplate(
# Which variable(s) will be passed to the example selector.
input_variables=["input"],
example_selector=example_selector,
# Define how each example will be formatted.
# In this case, each example will become 2 messages:
# 1 human, and 1 AI
example_prompt=(
HumanMessagePromptTemplate.from_template("{input}")
+ AIMessagePromptTemplate.from_template("{output}")
),
)
# Define the overall prompt.
final_prompt = (
SystemMessagePromptTemplate.from_template(
"You are a helpful AI Assistant"
)
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
# Show the prompt
print(final_prompt.format_messages(input="What's 3+3?"))
# Use within an LLM
from langchain_core.chat_models import ChatAnthropic
chain = final_prompt | ChatAnthropic()
chain.invoke({"input": "What's 3+3?"})
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
input_variables: List[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template will use
to pass to the example_selector, if provided."""
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
"""The class to format each example."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for filling in templates in messages.
Returns:
A list of formatted messages with all template variables filled in.
"""
# Get the examples to use.
examples = self._get_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
message
for example in examples
for message in self.example_prompt.format_messages(**example)
]
return messages
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
Use this method to generate a string representation of a prompt consisting
of chat messages.
Useful for feeding into a string based completion language model or debugging.
Args:
**kwargs: keyword arguments to use for formatting.
Returns:
A string representation of the prompt
"""
messages = self.format_messages(**kwargs)
return get_buffer_string(messages)

View File

@@ -0,0 +1,155 @@
"""Prompt template that contains few shot examples."""
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
)
from langchain_core.pydantic_v1 import Extra, root_validator
class FewShotPromptWithTemplates(StringPromptTemplate):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Any = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
suffix: StringPromptTemplate
"""A PromptTemplate to put after the examples."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: Optional[StringPromptTemplate] = None
"""A PromptTemplate to put before the examples."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
if examples and example_selector:
raise ValueError(
"Only one of 'examples' and 'example_selector' should be provided"
)
if examples is None and example_selector is None:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
return values
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)
expected_input_variables |= set(values["partial_variables"])
if values["prefix"] is not None:
expected_input_variables |= set(values["prefix"].input_variables)
missing_vars = expected_input_variables.difference(input_variables)
if missing_vars:
raise ValueError(
f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
# Get the examples to use.
examples = self._get_examples(**kwargs)
# Format the examples.
example_strings = [
self.example_prompt.format(**example) for example in examples
]
# Create the overall prefix.
if self.prefix is None:
prefix = ""
else:
prefix_kwargs = {
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
}
for k in prefix_kwargs.keys():
kwargs.pop(k)
prefix = self.prefix.format(**prefix_kwargs)
# Create the overall suffix
suffix_kwargs = {
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
}
for k in suffix_kwargs.keys():
kwargs.pop(k)
suffix = self.suffix.format(
**suffix_kwargs,
)
pieces = [prefix, *example_strings, suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "few_shot_with_templates"
def save(self, file_path: Union[Path, str]) -> None:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().save(file_path)

View File

@@ -0,0 +1,160 @@
"""Load prompts."""
import json
import logging
from pathlib import Path
from typing import Callable, Dict, Union
import yaml
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__)
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
"""Load prompt from Config Dict."""
if "_type" not in config:
logger.warning("No `_type` key found, defaulting to `prompt`.")
config_type = config.pop("_type", "prompt")
if config_type not in type_to_loader_dict:
raise ValueError(f"Loading {config_type} prompt not supported")
prompt_loader = type_to_loader_dict[config_type]
return prompt_loader(config)
def _load_template(var_name: str, config: dict) -> dict:
"""Load template from the path if applicable."""
# Check if template_path exists in config.
if f"{var_name}_path" in config:
# If it does, make sure template variable doesn't also exist.
if var_name in config:
raise ValueError(
f"Both `{var_name}_path` and `{var_name}` cannot be provided."
)
# Pop the template path from the config.
template_path = Path(config.pop(f"{var_name}_path"))
# Load the template.
if template_path.suffix == ".txt":
with open(template_path) as f:
template = f.read()
else:
raise ValueError
# Set the template variable to the extracted variable.
config[var_name] = template
return config
def _load_examples(config: dict) -> dict:
"""Load examples if necessary."""
if isinstance(config["examples"], list):
pass
elif isinstance(config["examples"], str):
with open(config["examples"]) as f:
if config["examples"].endswith(".json"):
examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")):
examples = yaml.safe_load(f)
else:
raise ValueError(
"Invalid file format. Only json or yaml formats are supported."
)
config["examples"] = examples
else:
raise ValueError("Invalid examples format. Only list or string are supported.")
return config
def _load_output_parser(config: dict) -> dict:
"""Load output parser."""
if "output_parser" in config and config["output_parser"]:
_config = config.pop("output_parser")
output_parser_type = _config.pop("_type")
if output_parser_type == "default":
output_parser = StrOutputParser(**_config)
else:
raise ValueError(f"Unsupported output parser {output_parser_type}")
config["output_parser"] = output_parser
return config
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
"""Load the "few shot" prompt from the config."""
# Load the suffix and prefix templates.
config = _load_template("suffix", config)
config = _load_template("prefix", config)
# Load the example prompt.
if "example_prompt_path" in config:
if "example_prompt" in config:
raise ValueError(
"Only one of example_prompt and example_prompt_path should "
"be specified."
)
config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
else:
config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
# Load the examples.
config = _load_examples(config)
config = _load_output_parser(config)
return FewShotPromptTemplate(**config)
def _load_prompt(config: dict) -> PromptTemplate:
"""Load the prompt template from config."""
# Load the template from disk if necessary.
config = _load_template("template", config)
config = _load_output_parser(config)
template_format = config.get("template_format", "f-string")
if template_format == "jinja2":
# Disabled due to:
# https://github.com/langchain-ai/langchain/issues/4394
raise ValueError(
f"Loading templates with '{template_format}' format is no longer supported "
f"since it can lead to arbitrary code execution. Please migrate to using "
f"the 'f-string' template format, which does not suffer from this issue."
)
return PromptTemplate(**config)
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from LangChainHub or local fs."""
if hub_result := try_load_from_hub(
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
):
return hub_result
else:
return _load_prompt_from_file(path)
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
"""Load prompt from file."""
# Convert file to a Path object.
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix == ".yaml":
with open(file_path, "r") as f:
config = yaml.safe_load(f)
else:
raise ValueError(f"Got unsupported file type {file_path.suffix}")
# Load the prompt from the config now.
return load_prompt_from_config(config)
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
"prompt": _load_prompt,
"few_shot": _load_few_shot_prompt,
}

View File

@@ -0,0 +1,57 @@
from typing import Any, Dict, List, Tuple
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
return {k: inputs[k] for k in input_variables}
class PipelinePromptTemplate(BasePromptTemplate):
"""A prompt template for composing multiple prompt templates together.
This can be useful when you want to reuse parts of prompts.
A PipelinePrompt consists of two main parts:
- final_prompt: This is the final prompt that is returned
- pipeline_prompts: This is a list of tuples, consisting
of a string (`name`) and a Prompt Template.
Each PromptTemplate will be formatted and then passed
to future prompt templates as a variable with
the same name as `name`
"""
final_prompt: BasePromptTemplate
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict:
"""Get input variables."""
created_variables = set()
all_variables = set()
for k, prompt in values["pipeline_prompts"]:
created_variables.add(k)
all_variables.update(prompt.input_variables)
values["input_variables"] = list(all_variables.difference(created_variables))
return values
def format_prompt(self, **kwargs: Any) -> PromptValue:
for k, prompt in self.pipeline_prompts:
_inputs = _get_inputs(kwargs, prompt.input_variables)
if isinstance(prompt, BaseChatPromptTemplate):
kwargs[k] = prompt.format_messages(**_inputs)
else:
kwargs[k] = prompt.format(**_inputs)
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return self.final_prompt.format_prompt(**_inputs)
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
@property
def _prompt_type(self) -> str:
raise ValueError

View File

@@ -0,0 +1,246 @@
"""Prompt schema definition."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import root_validator
class PromptTemplate(StringPromptTemplate):
"""A prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
*Security warning*: Prefer using `template_format="f-string"` instead of
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
from untrusted sources as they may lead to arbitrary Python code execution.
As of LangChain 0.0.329, Jinja2 templates will be rendered using
Jinja2's SandboxedEnvironment by default. This sand-boxing should
be treated as a best-effort approach rather than a guarantee of security,
as it is an opt-out rather than opt-in approach.
Despite the sand-boxing, we recommend to never use jinja2 templates
from untrusted sources.
Example:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
# Instantiation using from_template (recommended)
prompt = PromptTemplate.from_template("Say {foo}")
prompt.format(foo="bar")
# Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"template_format": self.template_format,
}
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: str
"""The prompt template."""
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining
if isinstance(other, PromptTemplate):
if self.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
if other.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
input_variables = list(
set(self.input_variables) | set(other.input_variables)
)
template = self.template + other.template
# If any do not want to validate, then don't
validate_template = self.validate_template and other.validate_template
partial_variables = {k: v for k, v in self.partial_variables.items()}
for k, v in other.partial_variables.items():
if k in partial_variables:
raise ValueError("Cannot have same variable partialed twice.")
else:
partial_variables[k] = v
return PromptTemplate(
template=template,
input_variables=input_variables,
partial_variables=partial_variables,
template_format="f-string",
validate_template=validate_template,
)
elif isinstance(other, str):
prompt = PromptTemplate.from_template(other)
return self + prompt
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod
def from_examples(
cls,
examples: List[str],
suffix: str,
input_variables: List[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt.
Intended to be used as a way to dynamically create a prompt from examples.
Args:
examples: List of examples to use in the prompt.
suffix: String to go after the list of examples. Should generally
set up the user's input.
input_variables: A list of variable names the final prompt template
will expect.
example_separator: The separator to use in between examples. Defaults
to two new line characters.
prefix: String that should go before any examples. Generally includes
examples. Default to an empty string.
Returns:
The final prompt generated.
"""
template = example_separator.join([prefix, *examples, suffix])
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_file(
cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
) -> PromptTemplate:
"""Load a prompt from a file.
Args:
template_file: The path to the file containing the prompt template.
input_variables: A list of variable names the final prompt template
will expect.
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_template(
cls,
template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt template from a template.
*Security warning*: Prefer using `template_format="f-string"` instead of
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
from untrusted sources as they may lead to arbitrary Python code execution.
As of LangChain 0.0.329, Jinja2 templates will be rendered using
Jinja2's SandboxedEnvironment by default. This sand-boxing should
be treated as a best-effort approach rather than a guarantee of security,
as it is an opt-out rather than opt-in approach.
Despite the sand-boxing, we recommend to never use jinja2 templates
from untrusted sources.
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
`"{variable1} {variable2}"`, and `partial_variables` is
`{"variable1": "foo"}`, then the final prompt will be
`"foo {variable2}"`.
Returns:
The prompt template loaded from the template.
"""
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}
if _partial_variables:
input_variables = [
var for var in input_variables if var not in _partial_variables
]
return cls(
input_variables=input_variables,
template=template,
template_format=template_format,
partial_variables=_partial_variables,
**kwargs,
)

View File

@@ -0,0 +1,156 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2.
*Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
SandboxedEnvironment by default. However, this sand-boxing should
be treated as a best-effort approach rather than a guarantee of security.
Do not accept jinja2 templates from untrusted sources as they may lead
to arbitrary Python code execution.
https://jinja.palletsprojects.com/en/3.1.x/sandbox/
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
"Please be cautious when using jinja2 templates. "
"Do not expand jinja2 templates using unverified or user-controlled "
"inputs as that can result in arbitrary Python code execution."
)
# This uses a sandboxed environment to prevent arbitrary code execution.
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
# Please treat this sand-boxing as a best-effort approach rather than
# a guarantee of security.
# We recommend to never use jinja2 templates with untrusted inputs.
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
# approach not a guarantee of security.
return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
Args:
template: The template string.
input_variables: The input variables.
"""
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
warning_message = ""
if missing_variables:
warning_message += f"Missing variables: {missing_variables} "
if extra_variables:
warning_message += f"Extra variables: {extra_variables}"
if warning_message:
warnings.warn(warning_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
input_variables: The input variables.
Raises:
ValueError: If the template format is not supported.
"""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
def get_template_variables(template: str, template_format: str) -> List[str]:
"""Get the variables from the template.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
Returns:
The variables from the template.
Raises:
ValueError: If the template format is not supported.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(input_variables)
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))

View File

@@ -0,0 +1,23 @@
from importlib import metadata
## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
try:
from pydantic.v1 import * # noqa: F403 # type: ignore
except ImportError:
from pydantic import * # noqa: F403 # type: ignore
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0

View File

@@ -0,0 +1,4 @@
try:
from pydantic.v1.dataclasses import * # noqa: F403
except ImportError:
from pydantic.dataclasses import * # noqa: F403

View File

@@ -0,0 +1,4 @@
try:
from pydantic.v1.main import * # noqa: F403
except ImportError:
from pydantic.main import * # noqa: F403

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import RunnableConfig, RunnableSerializable
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Example:
.. code-block:: python
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
query_vec = self.vectorizer.transform([query])
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
_new_arg_supported: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
config = config or {}
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
async def ainvoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
config = config or {}
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
)
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
)
def get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain_core.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
async def aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = await self._aget_relevant_documents(query, **_kwargs)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result

Some files were not shown because too many files have changed in this diff Show More