Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-04 16:20:16 +00:00)

Compare commits: v0.0.337 ... bagatur/fi (80 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 63f380d349 |  |
|  | 9c55d65606 |  |
|  | fabebb2042 |  |
|  | b2b8122249 |  |
|  | 5c6973a1b5 |  |
|  | 4449481cd3 |  |
|  | 70f6be32e0 |  |
|  | 1548f32f3d |  |
|  | 6c59db482b |  |
|  | 32d087fcb8 |  |
|  | 14d4fb98fc |  |
|  | 5b90fe5b1c |  |
|  | 16af282429 |  |
|  | 78da34153e |  |
|  | e327bb4ba4 |  |
|  | d47ee1ae79 |  |
|  | a21e84faf7 |  |
|  | ace9e64d62 |  |
|  | 5064890fcf |  |
|  | 143049c90f |  |
|  | c5ae9f832d |  |
|  | 131db4ba68 |  |
|  | 04bddbaba4 |  |
|  | aec8715073 |  |
|  | bb18b0266e |  |
|  | dc53523837 |  |
|  | a208abe6b7 |  |
|  | 083afba697 |  |
|  | c61e30632e |  |
|  | 59df16ab92 |  |
|  | bfb980b968 |  |
|  | d65c36d60a |  |
|  | 249c796785 |  |
|  | c6937a2eb4 |  |
|  | 11614700a4 |  |
|  | d32e511826 |  |
|  | 17c6551c18 |  |
|  | 8329f81072 |  |
|  | 611e1e0ca4 |  |
|  | 99b4f46cbe |  |
|  | d82cbf5e76 |  |
|  | 4eec47b191 |  |
|  | e620347a83 |  |
|  | 52e23e50b1 |  |
|  | 1c08dbfb33 |  |
|  | f3fcdea574 |  |
|  | b6f70d776b |  |
|  | fe7b40cb2a |  |
|  | 10418ab0c1 |  |
|  | 190952fe76 |  |
|  | 674bd90a47 |  |
|  | df03267edf |  |
|  | ef7802b325 |  |
|  | a93616e972 |  |
|  | 6bf9b2cb51 |  |
|  | e53f59f01a |  |
|  | 16f7912e1b |  |
|  | 43972be632 |  |
|  | 8362bd729b |  |
|  | 7100d586ef |  |
|  | ad0c3b9479 |  |
|  | 69d39e2173 |  |
|  | 6bc08266e0 |  |
|  | 325bdac673 |  |
|  | 47451764a7 |  |
|  | 420a17542d |  |
|  | cc50e023d1 |  |
|  | 02a13030c0 |  |
|  | 78a1f4b264 |  |
|  | 790ed8be69 |  |
|  | f4c0e3cc15 |  |
|  | 43dad6cb91 |  |
|  | ff382b7b1b |  |
|  | cda1b33270 |  |
|  | cac849ae86 |  |
|  | 79ed66f870 |  |
|  | c56faa6ef1 |  |
|  | 0fb5f857f9 |  |
|  | d2335d0114 |  |
|  | 5a28dc3210 |  |
.github/workflows/_compile_integration_test.yml (vendored, 12 changed lines)
@@ -7,6 +7,10 @@ on:
        required: true
        type: string
        description: "From which folder this pipeline executes"
      langchain-core-location:
        required: false
        type: string
        description: "Relative path to the langchain core library folder"

env:
  POETRY_VERSION: "1.6.1"
@@ -40,6 +44,14 @@ jobs:
        shell: bash
        run: poetry install --with=test_integration

      - name: Install langchain core editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-core-location }}
        env:
          LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"

      - name: Check integration tests compile
        shell: bash
        run: poetry run pytest -m compile tests/integration_tests
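
The compile check above keys off a pytest marker. As a hedged illustration (not taken from the repository), a test carrying that marker only needs to exist and import its module; `pytest -m compile tests/integration_tests` then verifies that the integration test files import and collect without running anything that needs credentials or network access:

```python
# Hedged sketch: a placeholder test carrying the "compile" marker, so that
# `poetry run pytest -m compile tests/integration_tests` only has to import
# and collect the integration test modules, not execute real integrations.
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running them."""
```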

.github/workflows/_lint.yml (vendored, 14 changed lines)
@@ -11,6 +11,10 @@ on:
        required: false
        type: string
        description: "Relative path to the langchain library folder"
      langchain-core-location:
        required: false
        type: string
        description: "Relative path to the langchain core library folder"

env:
  POETRY_VERSION: "1.6.1"
@@ -76,7 +80,15 @@ jobs:
        env:
          LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
        run: |
          pip install -e "$LANGCHAIN_LOCATION"
          poetry run pip install -e "$LANGCHAIN_LOCATION"

      - name: Install langchain core editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-core-location }}
        env:
          LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"

      - name: Get .mypy_cache to speed up mypy
        uses: actions/cache@v3

.github/workflows/_pydantic_compatibility.yml (vendored, 24 changed lines)
@@ -7,6 +7,14 @@ on:
        required: true
        type: string
        description: "From which folder this pipeline executes"
      langchain-location:
        required: false
        type: string
        description: "Relative path to the langchain library folder"
      langchain-core-location:
        required: false
        type: string
        description: "Relative path to the langchain core library folder"

env:
  POETRY_VERSION: "1.6.1"
@@ -40,6 +48,22 @@ jobs:
        shell: bash
        run: poetry install

      - name: Install langchain editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-location }}
        env:
          LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_LOCATION"

      - name: Install langchain core editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-core-location }}
        env:
          LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"

      - name: Install the opposite major version of pydantic
        # If normal tests use pydantic v1, here we'll use v2, and vice versa.
        shell: bash
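
The "opposite major version of pydantic" step exercises code that has to work under both pydantic 1 and pydantic 2. A minimal sketch (not part of the workflow) of the usual pattern such code uses to branch on the installed major version:

```python
# Hedged sketch: detect the installed pydantic major version at import time
# and pick the matching BaseModel import.
import pydantic

PYDANTIC_MAJOR_VERSION = int(pydantic.VERSION.split(".")[0])

if PYDANTIC_MAJOR_VERSION >= 2:
    # pydantic 2 ships a v1 compatibility namespace
    from pydantic.v1 import BaseModel
else:
    from pydantic import BaseModel
```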

.github/workflows/_test.yml (vendored, 27 changed lines)
@@ -7,6 +7,14 @@ on:
        required: true
        type: string
        description: "From which folder this pipeline executes"
      langchain-location:
        required: false
        type: string
        description: "Relative path to the langchain library folder"
      langchain-core-location:
        required: false
        type: string
        description: "Relative path to the langchain core library folder"

env:
  POETRY_VERSION: "1.6.1"
@@ -40,9 +48,26 @@ jobs:
        shell: bash
        run: poetry install

      - name: Install langchain editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-location }}
        env:
          LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_LOCATION"

      - name: Install langchain core editable
        working-directory: ${{ inputs.working-directory }}
        if: ${{ inputs.langchain-core-location }}
        env:
          LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }}
        run: |
          poetry run pip install -e "$LANGCHAIN_CORE_LOCATION"

      - name: Run core tests
        shell: bash
        run: make test
        run: |
          make test

      - name: Ensure the tests did not create any additional files
        shell: bash

.github/workflows/langchain_ci.yml (vendored, 49 changed lines)
@@ -36,6 +36,7 @@ jobs:
      ./.github/workflows/_lint.yml
    with:
      working-directory: libs/langchain
      langchain-core-location: ../core
    secrets: inherit

  test:
@@ -43,6 +44,7 @@
      ./.github/workflows/_test.yml
    with:
      working-directory: libs/langchain
      langchain-core-location: ../core
    secrets: inherit

  compile-integration-tests:
@@ -50,6 +52,7 @@
      ./.github/workflows/_compile_integration_test.yml
    with:
      working-directory: libs/langchain
      langchain-core-location: ../core
    secrets: inherit

  pydantic-compatibility:
@@ -57,8 +60,49 @@
      ./.github/workflows/_pydantic_compatibility.yml
    with:
      working-directory: libs/langchain
      langchain-core-location: ../core
    secrets: inherit

  # It's possible that langchain works fine with the latest *published* langchain-core,
  # but is broken with the langchain-core on `master`.
  #
  # We want to catch situations like that *before* releasing a new langchain-core, hence this test.
  test-with-latest-langchain-core:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ${{ env.WORKDIR }}
    strategy:
      matrix:
        python-version:
          - "3.8"
          - "3.9"
          - "3.10"
          - "3.11"
    name: test with unpublished langchain-core - Python ${{ matrix.python-version }}
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ env.WORKDIR }}
          cache-key: unpublished-langchain-core

      - name: Install dependencies
        shell: bash
        run: |
          echo "Running tests with unpublished langchain, installing dependencies with poetry..."
          poetry install

          echo "Editably installing langchain-core outside of poetry, to avoid messing up lockfile..."
          poetry run pip install -e ../core

      - name: Run tests
        run: make test

  extended-tests:
    runs-on: ubuntu-latest
    defaults:
@@ -89,6 +133,11 @@
          echo "Running extended tests, installing dependencies with poetry..."
          poetry install -E extended_testing

      - name: Install langchain core editable
        shell: bash
        run: |
          poetry run pip install -e ../core

      - name: Run extended tests
        run: make extended_tests
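
The comment block explains why the new `test-with-latest-langchain-core` job installs `../core` editably on top of the locked dependency set. A quick, hedged way (not part of the workflow) to confirm which `langchain_core` such an environment actually resolves:

```python
# Hedged sketch: after `poetry run pip install -e ../core`, the import should
# resolve to the local checkout rather than the published wheel.
import importlib.metadata

import langchain_core

print(importlib.metadata.version("langchain-core"))  # version of the installed distribution
print(langchain_core.__file__)  # expected to point into libs/core/langchain_core/
```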

.github/workflows/langchain_core_ci.yml (vendored, new file, 52 lines)
@@ -0,0 +1,52 @@
---
name: libs/langchain core CI

on:
  push:
    branches: [ master ]
  pull_request:
    paths:
      - '.github/actions/poetry_setup/action.yml'
      - '.github/tools/**'
      - '.github/workflows/_lint.yml'
      - '.github/workflows/_test.yml'
      - '.github/workflows/_pydantic_compatibility.yml'
      - '.github/workflows/langchain_core_ci.yml'
      - 'libs/core/**'
  workflow_dispatch:  # Allows to trigger the workflow manually in GitHub UI

# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  POETRY_VERSION: "1.6.1"
  WORKDIR: "libs/core"

jobs:
  lint:
    uses:
      ./.github/workflows/_lint.yml
    with:
      working-directory: libs/core
    secrets: inherit

  test:
    uses:
      ./.github/workflows/_test.yml
    with:
      working-directory: libs/core
    secrets: inherit

  pydantic-compatibility:
    uses:
      ./.github/workflows/_pydantic_compatibility.yml
    with:
      working-directory: libs/core
    secrets: inherit

.github/workflows/langchain_core_release.yml (vendored, new file, 13 lines)
@@ -0,0 +1,13 @@
---
name: libs/core Release

on:
  workflow_dispatch:  # Allows to trigger the workflow manually in GitHub UI

jobs:
  release:
    uses:
      ./.github/workflows/_release.yml
    with:
      working-directory: libs/core
    secrets: inherit
@@ -36,6 +36,7 @@ jobs:
    with:
      working-directory: libs/experimental
      langchain-location: ../langchain
      langchain-core-location: ../core
    secrets: inherit

  test:
@@ -43,6 +44,8 @@
      ./.github/workflows/_test.yml
    with:
      working-directory: libs/experimental
      langchain-location: ../langchain
      langchain-core-location: ../core
    secrets: inherit

  compile-integration-tests:
@@ -88,6 +91,7 @@

          echo "Editably installing langchain outside of poetry, to avoid messing up lockfile..."
          poetry run pip install -e ../langchain
          poetry run pip install -e ../core

      - name: Run tests
        run: make test
@@ -648,7 +648,7 @@
   {
    "data": {
     "text/plain": [
      "OpenAIEmbeddings(client=<class 'openai.api_resources.embedding.Embedding'>, model='text-embedding-ada-002', deployment='text-embedding-ada-002', openai_api_version='', openai_api_base='', openai_api_type='', openai_proxy='', embedding_ctx_length=8191, openai_api_key='sk-[redacted]', openai_organization='', allowed_special=set(), disallowed_special='all', chunk_size=1000, max_retries=6, request_timeout=None, headers=None, tiktoken_model_name=None, show_progress_bar=False, model_kwargs={})"
      "OpenAIEmbeddings(client=<class 'openai.api_resources.embedding.Embedding'>, model='text-embedding-ada-002', deployment='text-embedding-ada-002', openai_api_version='', openai_api_base='', openai_api_type='', openai_proxy='', embedding_ctx_length=8191, openai_api_key='', openai_organization='', allowed_special=set(), disallowed_special='all', chunk_size=1000, max_retries=6, request_timeout=None, headers=None, tiktoken_model_name=None, show_progress_bar=False, model_kwargs={})"
     ]
    },
    "execution_count": 13,

cookbook/docugami_xml_kg_rag.ipynb (new file, 993 lines)
File diff suppressed because one or more lines are too long
@@ -69,8 +69,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.llm_bash.prompt import BashOutputParser\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "from langchain_experimental.llm_bash.prompt import BashOutputParser\n",
    "\n",
    "_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n",
    "Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n",
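
The cookbook change above only moves `BashOutputParser` to its new home in `langchain_experimental`. As a hedged usage sketch (not part of the cookbook, and with made-up sample text), the parser pulls the commands out of a fenced bash block in the model's answer:

```python
# Hedged sketch: parse a fenced bash block out of LLM output.
from langchain_experimental.llm_bash.prompt import BashOutputParser

parser = BashOutputParser()
llm_output = (
    "I will create the directory and copy the files.\n"
    "```bash\n"
    "mkdir myNewDirectory\n"
    "cp -r target/* myNewDirectory\n"
    "```"
)
print(parser.parse(llm_output))  # expected: a list containing the two shell commands
```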
@@ -13,8 +13,10 @@ HERE = Path(__file__).parent

PKG_DIR = ROOT_DIR / "libs" / "langchain" / "langchain"
EXP_DIR = ROOT_DIR / "libs" / "experimental" / "langchain_experimental"
CORE_DIR = ROOT_DIR / "libs" / "core" / "langchain_core"
WRITE_FILE = HERE / "api_reference.rst"
EXP_WRITE_FILE = HERE / "experimental_api_reference.rst"
CORE_WRITE_FILE = HERE / "core_api_reference.rst"


ClassKind = Literal["TypedDict", "Regular", "Pydantic", "enum"]
@@ -292,6 +294,17 @@ def _document_langchain_experimental() -> None:


def _document_langchain_core() -> None:
    """Document the langchain_core package."""
    # Generate core_api_reference.rst
    core_members = _load_package_modules(EXP_DIR)
    core_doc = ".. _core_api_reference:\n\n" + _construct_doc(
        "langchain_core", core_members
    )
    with open(CORE_WRITE_FILE, "w") as f:
        f.write(core_doc)


def _document_langchain() -> None:
    """Document the main langchain package."""
    # load top level module members
    lc_members = _load_package_modules(PKG_DIR)
@@ -306,7 +319,6 @@ def _document_langchain_core() -> None:
            "agents.output_parsers": agents["output_parsers"],
            "agents.format_scratchpad": agents["format_scratchpad"],
            "tools.render": tools["render"],
            "schema.runnable": schema["runnable"],
        }
    )

@@ -318,8 +330,9 @@ def _document_langchain_core() -> None:

def main() -> None:
    """Generate the reference.rst file for each package."""
    _document_langchain_core()
    _document_langchain()
    _document_langchain_experimental()
    _document_langchain_core()


if __name__ == "__main__":
@@ -34,6 +34,9 @@
          <li class="nav-item">
            <a class="sk-nav-link nav-link" href="{{ pathto('api_reference') }}">API</a>
          </li>
          <li class="nav-item">
            <a class="sk-nav-link nav-link" href="{{ pathto('core_api_reference') }}">Core</a>
          </li>
          <li class="nav-item">
            <a class="sk-nav-link nav-link" href="{{ pathto('experimental_api_reference') }}">Experimental</a>
          </li>
@@ -14,7 +14,7 @@ This framework consists of several parts.
- **[LangServe](/docs/langserve)**: A library for deploying LangChain chains as a REST API.
- **[LangSmith](/docs/langsmith)**: A developer platform that lets you debug, test, evaluate, and monitor chains built on any LLM framework and seamlessly integrates with LangChain.

![LangChain Diagram](/img/langchain_stack.png)
![LangChain Stack](/img/langchain_stack.png)

Together, these products simplify the entire application lifecycle:
- **Develop**: Write your applications in LangChain/LangChain.js. Hit the ground running using Templates for reference.
@@ -7,7 +7,9 @@
    "source": [
     "# Azure OpenAI\n",
     "\n",
     "This notebook goes over how to connect to an Azure hosted OpenAI endpoint. We recommend having version `openai>=1` installed."
     ">[Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview) provides REST API access to OpenAI's powerful language models including the GPT-4, GPT-3.5-Turbo, and Embeddings model series. These models can be easily adapted to your specific task including but not limited to content generation, summarization, semantic search, and natural language to code translation. Users can access the service through REST APIs, Python SDK, or a web-based interface in the Azure OpenAI Studio.\n",
     "\n",
     "This notebook goes over how to connect to an Azure-hosted OpenAI endpoint. We recommend having version `openai>=1` installed."
    ]
   },
   {
@@ -162,7 +164,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
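
The rewritten introduction above is prose only. What follows is a minimal, hedged sketch of such a connection; constructor and environment-variable names differ between langchain/openai releases, and the deployment name, API version, and resource URL are placeholders, so treat this as an assumption-laden illustration rather than the notebook's own code:

```python
# Hedged sketch: connecting to an Azure-hosted OpenAI deployment with langchain.
# The resource URL, API version, key, and deployment name are placeholders.
from langchain.llms import AzureOpenAI

llm = AzureOpenAI(
    deployment_name="my-davinci-deployment",  # hypothetical Azure deployment name
    openai_api_version="2023-05-15",
    openai_api_key="...",
    openai_api_base="https://<resource>.openai.azure.com/",
)
print(llm("Tell me a joke"))
```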
@@ -4,11 +4,13 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AzureML Chat Online Endpoint\n",
    "# Azure ML Endpoint\n",
    "\n",
    "[AzureML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
    ">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. `Azure Foundation Models` include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
    ">\n",
    ">[Azure Machine Learning Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints). After you train machine learning models or pipelines, you need to deploy them to production so that others can use them for inference. Inference is the process of applying new input data to the machine learning model or pipeline to generate outputs. While these outputs are typically referred to as \"predictions,\" inferencing can be used to generate outputs for other machine learning tasks, such as classification and clustering. In `Azure Machine Learning`, you perform inferencing by using endpoints and deployments. `Endpoints` and `Deployments` allow you to decouple the interface of your production workload from the implementation that serves it.\n",
    "\n",
    "This notebook goes over how to use a chat model hosted on an `AzureML online endpoint`"
    "This notebook goes over how to use a chat model hosted on an `Azure Machine Learning Endpoint`."
   ]
  },
  {
@@ -91,7 +93,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
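
To ground the description above, here is a hedged sketch of wiring an Azure ML online endpoint into a chat model. The endpoint URL and key are placeholders, and the content-formatter class is an assumption based on the integration as it existed around this release:

```python
# Hedged sketch: a chat model backed by an Azure ML online endpoint.
from langchain.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaContentFormatter,
)
from langchain.schema import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<endpoint>.<region>.inference.ml.azure.com/score",  # placeholder
    endpoint_api_key="<api-key>",                                             # placeholder
    content_formatter=LlamaContentFormatter(),
)
print(chat(messages=[HumanMessage(content="Say hello from Azure ML.")]))
```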
@@ -36,7 +36,7 @@
   "outputs": [],
   "source": [
    "chat = ChatHunyuan(\n",
    "    hunyuan_app_id=\"YOUR_APP_ID\",\n",
    "    hunyuan_app_id=111111111,\n",
    "    hunyuan_secret_id=\"YOUR_SECRET_ID\",\n",
    "    hunyuan_secret_key=\"YOUR_SECRET_KEY\",\n",
    ")"
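
The cell above only changes the sample `hunyuan_app_id` from a string to an integer. As a hedged follow-up sketch (not from the notebook, credentials are placeholders), the configured model is invoked like any other langchain chat model:

```python
# Hedged sketch: calling Tencent Hunyuan through langchain's chat model interface.
from langchain.chat_models import ChatHunyuan
from langchain.schema import HumanMessage

chat = ChatHunyuan(
    hunyuan_app_id=111111111,            # placeholder app id, as in the notebook
    hunyuan_secret_id="YOUR_SECRET_ID",
    hunyuan_secret_key="YOUR_SECRET_KEY",
)
print(chat([HumanMessage(content="You are a helpful assistant, say hi.")]))
```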

docs/docs/integrations/chat/llama2_chat.ipynb (new file, 729 lines)
@@ -0,0 +1,729 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "90a1faf2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Llama-2 Chat\n",
|
||||
"\n",
|
||||
"This notebook shows how to augment Llama-2 `LLM`s with the `Llama2Chat` wrapper to support the [Llama-2 chat prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2). Several `LLM` implementations in LangChain can be used as interface to Llama-2 chat models. These include [HuggingFaceTextGenInference](https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference), [LlamaCpp](https://python.langchain.com/docs/use_cases/question_answering/how_to/local_retrieval_qa), [GPT4All](https://python.langchain.com/docs/integrations/llms/gpt4all), ..., to mention a few examples. \n",
|
||||
"\n",
|
||||
"`Llama2Chat` is a generic wrapper that implements `BaseChatModel` and can therefore be used in applications as [chat model](https://python.langchain.com/docs/modules/model_io/models/chat/). `Llama2Chat` converts a list of [chat messages](https://python.langchain.com/docs/modules/model_io/models/chat/#messages) into the [required chat prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2) and forwards the formatted prompt as `str` to the wrapped `LLM`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "36c03540",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from langchain_experimental.chat_models import Llama2Chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5c76910f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For the chat application examples below, we'll use the following chat `prompt_template`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "9bbfaf3a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts.chat import (\n",
|
||||
" ChatPromptTemplate,\n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
" MessagesPlaceholder,\n",
|
||||
")\n",
|
||||
"from langchain.schema import SystemMessage\n",
|
||||
"\n",
|
||||
"template_messages = [\n",
|
||||
" SystemMessage(content=\"You are a helpful assistant.\"),\n",
|
||||
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
|
||||
" HumanMessagePromptTemplate.from_template(\"{text}\"),\n",
|
||||
"]\n",
|
||||
"prompt_template = ChatPromptTemplate.from_messages(template_messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f3343b7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with Llama-2 via `HuggingFaceTextGenInference` LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2ff99380",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A [HuggingFaceTextGenInference](https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference) LLM encapsulates access to a [text-generation-inference](https://github.com/huggingface/text-generation-inference) server. In the following example, the inference server serves a [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) model. It can be started locally with:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"docker run \\\n",
|
||||
" --rm \\\n",
|
||||
" --gpus all \\\n",
|
||||
" --ipc=host \\\n",
|
||||
" -p 8080:80 \\\n",
|
||||
" -v ~/.cache/huggingface/hub:/data \\\n",
|
||||
" -e HF_API_TOKEN=${HF_API_TOKEN} \\\n",
|
||||
" ghcr.io/huggingface/text-generation-inference:0.9 \\\n",
|
||||
" --hostname 0.0.0.0 \\\n",
|
||||
" --model-id meta-llama/Llama-2-13b-chat-hf \\\n",
|
||||
" --quantize bitsandbytes \\\n",
|
||||
" --num-shard 4\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This works on a machine with 4 x RTX 3080ti cards, for example. Adjust the `--num_shard` value to the number of GPUs available. The `HF_API_TOKEN` environment variable holds the Hugging Face API token."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "238095fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip3 install text-generation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "79c4ace9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a `HuggingFaceTextGenInference` instance that connects to the local inference server and wrap it into `Llama2Chat`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "7a9f6de2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import HuggingFaceTextGenInference\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceTextGenInference(\n",
|
||||
" inference_server_url=\"http://127.0.0.1:8080/\",\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=50,\n",
|
||||
" temperature=0.1,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model = Llama2Chat(llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4f646a2b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Then you are ready to use the chat `model` together with `prompt_template` and conversation `memory` in an `LLMChain`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "54b5d1d1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
|
||||
"chain = LLMChain(llm=model, prompt=prompt_template, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "e6717947",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Sure, I'd be happy to help! Here are a few popular locations to consider visiting in Vienna:\n",
|
||||
"\n",
|
||||
"1. Schönbrunn Palace\n",
|
||||
"2. St. Stephen's Cathedral\n",
|
||||
"3. Hofburg Palace\n",
|
||||
"4. Belvedere Palace\n",
|
||||
"5. Prater Park\n",
|
||||
"6. Vienna State Opera\n",
|
||||
"7. Albertina Museum\n",
|
||||
"8. Museum of Natural History\n",
|
||||
"9. Kunsthistorisches Museum\n",
|
||||
"10. Ringstrasse\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\n",
|
||||
" chain.run(\n",
|
||||
" text=\"What can I see in Vienna? Propose a few locations. Names only, no details.\"\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "17bf10d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Certainly! St. Stephen's Cathedral (Stephansdom) is one of the most recognizable landmarks in Vienna and a must-see attraction for visitors. This stunning Gothic cathedral is located in the heart of the city and is known for its intricate stone carvings, colorful stained glass windows, and impressive dome.\n",
|
||||
"\n",
|
||||
"The cathedral was built in the 12th century and has been the site of many important events throughout history, including the coronation of Holy Roman emperors and the funeral of Mozart. Today, it is still an active place of worship and offers guided tours, concerts, and special events. Visitors can climb up the south tower for panoramic views of the city or attend a service to experience the beautiful music and chanting.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(chain.run(text=\"Tell me more about #2.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2a297e09",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with Llama-2 via `LlamaCPP` LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "52c1a0b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For using a Llama-2 chat model with a [LlamaCPP](https://python.langchain.com/docs/integrations/llms/llamacpp) `LMM`, install the `llama-cpp-python` library using [these installation instructions](https://python.langchain.com/docs/integrations/llms/llamacpp#installation). The following example uses a quantized [llama-2-7b-chat.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf) model stored locally at `~/Models/llama-2-7b-chat.Q4_0.gguf`. \n",
|
||||
"\n",
|
||||
"After creating a `LlamaCpp` instance, the `llm` is again wrapped into `Llama2Chat`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "07c0d04e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from /home/martin/Models/llama-2-7b-chat.Q4_0.gguf (version GGUF V2)\n",
|
||||
"llama_model_loader: - tensor 0: token_embd.weight q4_0 [ 4096, 32000, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 1: blk.0.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 2: blk.0.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 3: blk.0.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 4: blk.0.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 5: blk.0.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 6: blk.0.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 7: blk.0.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 8: blk.0.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 9: blk.0.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 10: blk.1.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 11: blk.1.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 12: blk.1.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 13: blk.1.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 14: blk.1.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 15: blk.1.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 16: blk.1.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 17: blk.1.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 18: blk.1.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 19: blk.10.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 20: blk.10.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 21: blk.10.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 22: blk.10.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 23: blk.10.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 24: blk.10.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 25: blk.10.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 26: blk.10.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 27: blk.10.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 28: blk.11.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 29: blk.11.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 30: blk.11.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 31: blk.11.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 32: blk.11.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 33: blk.11.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 34: blk.11.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 35: blk.11.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 36: blk.11.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 37: blk.12.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 38: blk.12.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 39: blk.12.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 40: blk.12.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 41: blk.12.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 42: blk.12.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 43: blk.12.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 44: blk.12.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 45: blk.12.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 46: blk.13.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 47: blk.13.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 48: blk.13.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 49: blk.13.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 50: blk.13.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 51: blk.13.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 52: blk.13.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 53: blk.13.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 54: blk.13.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 55: blk.14.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 56: blk.14.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 57: blk.14.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 58: blk.14.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 59: blk.14.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 60: blk.14.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 61: blk.14.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 62: blk.14.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 63: blk.14.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 64: blk.15.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 65: blk.15.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 66: blk.15.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 67: blk.15.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 68: blk.15.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 69: blk.15.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 70: blk.15.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 71: blk.15.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 72: blk.15.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 73: blk.16.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 74: blk.16.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 75: blk.16.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 76: blk.16.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 77: blk.16.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 78: blk.16.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 79: blk.16.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 80: blk.16.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 81: blk.16.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 82: blk.17.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 83: blk.17.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 84: blk.17.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 85: blk.17.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 86: blk.17.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 87: blk.17.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 88: blk.17.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 89: blk.17.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 90: blk.17.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 91: blk.18.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 92: blk.18.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 93: blk.18.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 94: blk.18.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 95: blk.18.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 96: blk.18.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 97: blk.18.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 98: blk.18.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 99: blk.18.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 100: blk.19.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 101: blk.19.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 102: blk.19.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 103: blk.19.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 104: blk.19.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 105: blk.19.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 106: blk.19.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 107: blk.19.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 108: blk.19.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 109: blk.2.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 110: blk.2.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 111: blk.2.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 112: blk.2.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 113: blk.2.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 114: blk.2.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 115: blk.2.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 116: blk.2.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 117: blk.2.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 118: blk.20.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 119: blk.20.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 120: blk.20.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 121: blk.20.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 122: blk.20.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 123: blk.20.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 124: blk.20.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 125: blk.20.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 126: blk.20.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 127: blk.21.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 128: blk.21.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 129: blk.21.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 130: blk.21.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 131: blk.21.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 132: blk.21.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 133: blk.21.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 134: blk.21.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 135: blk.21.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 136: blk.22.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 137: blk.22.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 138: blk.22.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 139: blk.22.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 140: blk.22.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 141: blk.22.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 142: blk.22.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 143: blk.22.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 144: blk.22.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 145: blk.23.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 146: blk.23.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 147: blk.23.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 148: blk.23.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 149: blk.23.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 150: blk.23.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 151: blk.23.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 152: blk.23.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 153: blk.23.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 154: blk.3.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 155: blk.3.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 156: blk.3.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 157: blk.3.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 158: blk.3.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 159: blk.3.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 160: blk.3.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 161: blk.3.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 162: blk.3.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 163: blk.4.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 164: blk.4.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 165: blk.4.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 166: blk.4.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 167: blk.4.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 168: blk.4.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 169: blk.4.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 170: blk.4.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 171: blk.4.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 172: blk.5.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 173: blk.5.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 174: blk.5.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 175: blk.5.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 176: blk.5.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 177: blk.5.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 178: blk.5.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 179: blk.5.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 180: blk.5.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 181: blk.6.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 182: blk.6.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 183: blk.6.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 184: blk.6.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 185: blk.6.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 186: blk.6.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 187: blk.6.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 188: blk.6.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 189: blk.6.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 190: blk.7.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 191: blk.7.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 192: blk.7.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 193: blk.7.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 194: blk.7.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 195: blk.7.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 196: blk.7.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 197: blk.7.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 198: blk.7.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 199: blk.8.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 200: blk.8.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 201: blk.8.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 202: blk.8.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 203: blk.8.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 204: blk.8.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 205: blk.8.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 206: blk.8.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 207: blk.8.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 208: blk.9.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 209: blk.9.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 210: blk.9.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 211: blk.9.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 212: blk.9.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 213: blk.9.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 214: blk.9.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 215: blk.9.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 216: blk.9.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 217: output.weight q6_K [ 4096, 32000, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 218: blk.24.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 219: blk.24.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 220: blk.24.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 221: blk.24.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 222: blk.24.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 223: blk.24.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 224: blk.24.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 225: blk.24.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 226: blk.24.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 227: blk.25.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 228: blk.25.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 229: blk.25.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 230: blk.25.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 231: blk.25.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 232: blk.25.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 233: blk.25.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 234: blk.25.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 235: blk.25.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 236: blk.26.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 237: blk.26.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 238: blk.26.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 239: blk.26.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 240: blk.26.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 241: blk.26.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 242: blk.26.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 243: blk.26.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 244: blk.26.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 245: blk.27.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 246: blk.27.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 247: blk.27.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 248: blk.27.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 249: blk.27.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 250: blk.27.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 251: blk.27.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 252: blk.27.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 253: blk.27.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 254: blk.28.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 255: blk.28.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 256: blk.28.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 257: blk.28.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 258: blk.28.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 259: blk.28.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 260: blk.28.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 261: blk.28.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 262: blk.28.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 263: blk.29.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 264: blk.29.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 265: blk.29.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 266: blk.29.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 267: blk.29.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 268: blk.29.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 269: blk.29.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 270: blk.29.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 271: blk.29.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 272: blk.30.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 273: blk.30.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 274: blk.30.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 275: blk.30.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 276: blk.30.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 277: blk.30.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 278: blk.30.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 279: blk.30.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 280: blk.30.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 281: blk.31.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 282: blk.31.ffn_down.weight q4_0 [ 11008, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 283: blk.31.ffn_gate.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 284: blk.31.ffn_up.weight q4_0 [ 4096, 11008, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 285: blk.31.ffn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 286: blk.31.attn_k.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 287: blk.31.attn_output.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 288: blk.31.attn_q.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 289: blk.31.attn_v.weight q4_0 [ 4096, 4096, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 290: output_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - kv 0: general.architecture str \n",
|
||||
"llama_model_loader: - kv 1: general.name str \n",
|
||||
"llama_model_loader: - kv 2: llama.context_length u32 \n",
|
||||
"llama_model_loader: - kv 3: llama.embedding_length u32 \n",
|
||||
"llama_model_loader: - kv 4: llama.block_count u32 \n",
|
||||
"llama_model_loader: - kv 5: llama.feed_forward_length u32 \n",
|
||||
"llama_model_loader: - kv 6: llama.rope.dimension_count u32 \n",
|
||||
"llama_model_loader: - kv 7: llama.attention.head_count u32 \n",
|
||||
"llama_model_loader: - kv 8: llama.attention.head_count_kv u32 \n",
|
||||
"llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 \n",
|
||||
"llama_model_loader: - kv 10: general.file_type u32 \n",
|
||||
"llama_model_loader: - kv 11: tokenizer.ggml.model str \n",
|
||||
"llama_model_loader: - kv 12: tokenizer.ggml.tokens arr \n",
|
||||
"llama_model_loader: - kv 13: tokenizer.ggml.scores arr \n",
|
||||
"llama_model_loader: - kv 14: tokenizer.ggml.token_type arr \n",
|
||||
"llama_model_loader: - kv 15: tokenizer.ggml.bos_token_id u32 \n",
|
||||
"llama_model_loader: - kv 16: tokenizer.ggml.eos_token_id u32 \n",
|
||||
"llama_model_loader: - kv 17: tokenizer.ggml.unknown_token_id u32 \n",
|
||||
"llama_model_loader: - kv 18: general.quantization_version u32 \n",
|
||||
"llama_model_loader: - type f32: 65 tensors\n",
|
||||
"llama_model_loader: - type q4_0: 225 tensors\n",
|
||||
"llama_model_loader: - type q6_K: 1 tensors\n",
|
||||
"llm_load_vocab: special tokens definition check successful ( 259/32000 ).\n",
|
||||
"llm_load_print_meta: format = GGUF V2\n",
|
||||
"llm_load_print_meta: arch = llama\n",
|
||||
"llm_load_print_meta: vocab type = SPM\n",
|
||||
"llm_load_print_meta: n_vocab = 32000\n",
|
||||
"llm_load_print_meta: n_merges = 0\n",
|
||||
"llm_load_print_meta: n_ctx_train = 4096\n",
|
||||
"llm_load_print_meta: n_embd = 4096\n",
|
||||
"llm_load_print_meta: n_head = 32\n",
|
||||
"llm_load_print_meta: n_head_kv = 32\n",
|
||||
"llm_load_print_meta: n_layer = 32\n",
|
||||
"llm_load_print_meta: n_rot = 128\n",
|
||||
"llm_load_print_meta: n_gqa = 1\n",
|
||||
"llm_load_print_meta: f_norm_eps = 0.0e+00\n",
|
||||
"llm_load_print_meta: f_norm_rms_eps = 1.0e-06\n",
|
||||
"llm_load_print_meta: f_clamp_kqv = 0.0e+00\n",
|
||||
"llm_load_print_meta: f_max_alibi_bias = 0.0e+00\n",
|
||||
"llm_load_print_meta: n_ff = 11008\n",
|
||||
"llm_load_print_meta: rope scaling = linear\n",
|
||||
"llm_load_print_meta: freq_base_train = 10000.0\n",
|
||||
"llm_load_print_meta: freq_scale_train = 1\n",
|
||||
"llm_load_print_meta: n_yarn_orig_ctx = 4096\n",
|
||||
"llm_load_print_meta: rope_finetuned = unknown\n",
|
||||
"llm_load_print_meta: model type = 7B\n",
|
||||
"llm_load_print_meta: model ftype = mostly Q4_0\n",
|
||||
"llm_load_print_meta: model params = 6.74 B\n",
|
||||
"llm_load_print_meta: model size = 3.56 GiB (4.54 BPW) \n",
|
||||
"llm_load_print_meta: general.name = LLaMA v2\n",
|
||||
"llm_load_print_meta: BOS token = 1 '<s>'\n",
|
||||
"llm_load_print_meta: EOS token = 2 '</s>'\n",
|
||||
"llm_load_print_meta: UNK token = 0 '<unk>'\n",
|
||||
"llm_load_print_meta: LF token = 13 '<0x0A>'\n",
|
||||
"llm_load_tensors: ggml ctx size = 0.11 MB\n",
|
||||
"llm_load_tensors: mem required = 3647.97 MB\n",
|
||||
"..................................................................................................\n",
|
||||
"llama_new_context_with_model: n_ctx = 512\n",
|
||||
"llama_new_context_with_model: freq_base = 10000.0\n",
|
||||
"llama_new_context_with_model: freq_scale = 1\n",
|
||||
"llama_new_context_with_model: kv self size = 256.00 MB\n",
|
||||
"llama_build_graph: non-view tensors processed: 740/740\n",
|
||||
"llama_new_context_with_model: compute buffer total size = 2.66 MB\n",
|
||||
"AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from os.path import expanduser\n",
|
||||
"\n",
|
||||
"from langchain.llms import LlamaCpp\n",
|
||||
"\n",
|
||||
"model_path = expanduser(\"~/Models/llama-2-7b-chat.Q4_0.gguf\")\n",
|
||||
"\n",
|
||||
"llm = LlamaCpp(\n",
|
||||
" model_path=model_path,\n",
|
||||
" streaming=False,\n",
|
||||
")\n",
|
||||
"model = Llama2Chat(llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50498d96",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"and used in the same way as in the previous example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "90782b96",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
|
||||
"chain = LLMChain(llm=model, prompt=prompt_template, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "2160b26d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Of course! Vienna is a beautiful city with a rich history and culture. Here are some of the top tourist attractions you might want to consider visiting:\n",
|
||||
"1. Schönbrunn Palace\n",
|
||||
"2. St. Stephen's Cathedral\n",
|
||||
"3. Hofburg Palace\n",
|
||||
"4. Belvedere Palace\n",
|
||||
"5. Prater Park\n",
|
||||
"6. MuseumsQuartier\n",
|
||||
"7. Ringstrasse\n",
|
||||
"8. Vienna State Opera\n",
|
||||
"9. Kunsthistorisches Museum\n",
|
||||
"10. Imperial Palace\n",
|
||||
"\n",
|
||||
"These are just a few of the many amazing places to see in Vienna. Each one has its own unique history and charm, so I hope you enjoy exploring this beautiful city!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"llama_print_timings: load time = 250.46 ms\n",
|
||||
"llama_print_timings: sample time = 56.40 ms / 144 runs ( 0.39 ms per token, 2553.37 tokens per second)\n",
|
||||
"llama_print_timings: prompt eval time = 1444.25 ms / 47 tokens ( 30.73 ms per token, 32.54 tokens per second)\n",
|
||||
"llama_print_timings: eval time = 8832.02 ms / 143 runs ( 61.76 ms per token, 16.19 tokens per second)\n",
|
||||
"llama_print_timings: total time = 10645.94 ms\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\n",
|
||||
" chain.run(\n",
|
||||
" text=\"What can I see in Vienna? Propose a few locations. Names only, no details.\"\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "d9ce06e3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Llama.generate: prefix-match hit\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Of course! St. Stephen's Cathedral (also known as Stephansdom) is a stunning Gothic-style cathedral located in the heart of Vienna, Austria. It is one of the most recognizable landmarks in the city and is considered a symbol of Vienna.\n",
|
||||
"Here are some interesting facts about St. Stephen's Cathedral:\n",
|
||||
"1. History: The construction of St. Stephen's Cathedral began in the 12th century on the site of a former Romanesque church, and it took over 600 years to complete. The cathedral has been renovated and expanded several times throughout its history, with the most significant renovation taking place in the 19th century.\n",
|
||||
"2. Architecture: St. Stephen's Cathedral is built in the Gothic style, characterized by its tall spires, pointed arches, and intricate stone carvings. The cathedral features a mix of Romanesque, Gothic, and Baroque elements, making it a unique blend of styles.\n",
|
||||
"3. Design: The cathedral's design is based on the plan of a cross with a long nave and two shorter arms extending from it. The main altar is\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"llama_print_timings: load time = 250.46 ms\n",
|
||||
"llama_print_timings: sample time = 100.60 ms / 256 runs ( 0.39 ms per token, 2544.73 tokens per second)\n",
|
||||
"llama_print_timings: prompt eval time = 5128.71 ms / 160 tokens ( 32.05 ms per token, 31.20 tokens per second)\n",
|
||||
"llama_print_timings: eval time = 16193.02 ms / 255 runs ( 63.50 ms per token, 15.75 tokens per second)\n",
|
||||
"llama_print_timings: total time = 21988.57 ms\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(chain.run(text=\"Tell me more about #2.\"))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
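The notebook above swaps the chat backend for a local GGUF model via `LlamaCpp` and wraps it in `Llama2Chat`, leaving the rest of the chain untouched. As a condensed sketch of that flow (the `langchain_experimental` import path and the model file location are assumptions, and the prompt mirrors the one built earlier in the notebook):

```python
from os.path import expanduser

from langchain.chains import LLMChain
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.schema import SystemMessage
from langchain_experimental.chat_models import Llama2Chat  # assumed import path

# Chat prompt with history, mirroring the notebook's earlier prompt_template.
prompt_template = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content="You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{text}"),
    ]
)

# Local Llama 2 chat model in GGUF format (the path is an assumption).
llm = LlamaCpp(
    model_path=expanduser("~/Models/llama-2-7b-chat.Q4_0.gguf"),
    streaming=False,
)
model = Llama2Chat(llm=llm)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
chain = LLMChain(llm=model, prompt=prompt_template, memory=memory)

print(chain.run(text="What can I see in Vienna? Propose a few locations."))
print(chain.run(text="Tell me more about #2."))
```

Because `Llama2Chat` only reformats messages into the Llama-2 chat prompt layout, the memory and `LLMChain` wiring stay identical to the previous example.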
@@ -99,7 +99,7 @@
|
||||
"\n",
|
||||
"Language param : It's a list of language codes in a descending priority, `en` by default.\n",
|
||||
"\n",
|
||||
"translation param : It's a translate preference when the youtube does'nt have your select language, `en` by default."
|
||||
"translation param : It's a translate preference, you can translate available transcript to your preferred language."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
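For context, a loader call that exercises both of the YouTube transcript parameters described above might look like the sketch below (the video URL and language codes are placeholders):

```python
from langchain.document_loaders import YoutubeLoader

# Prefer an English transcript, fall back to Indonesian, and translate the
# transcript to English when only the fallback language is available.
loader = YoutubeLoader.from_youtube_url(
    "https://www.youtube.com/watch?v=QsYGlZkevEg",
    add_video_info=True,
    language=["en", "id"],
    translation="en",
)
docs = loader.load()
```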
@@ -5,7 +5,7 @@
|
||||
"id": "91c6a7ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MongodDB\n",
|
||||
"# MongoDB\n",
|
||||
"\n",
|
||||
">`MongoDB` is a source-available cross-platform document-oriented database program. Classified as a NoSQL database program, `MongoDB` uses `JSON`-like documents with optional schemas.\n",
|
||||
">\n",
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# LangChain Decorators ✨
|
||||
|
||||
lanchchain decorators is a layer on the top of LangChain that provides syntactic sugar 🍭 for writing custom langchain prompts and chains
|
||||
|
||||
For Feedback, Issues, Contributions - please raise an issue here:
|
||||
[ju-bezdek/langchain-decorators](https://github.com/ju-bezdek/langchain-decorators)
|
||||
~~~
|
||||
Disclaimer: `LangChain decorators` is not created by the LangChain team and is not supported by it.
|
||||
~~~
|
||||
|
||||
>`LangChain decorators` is a layer on top of LangChain that provides syntactic sugar 🍭 for writing custom langchain prompts and chains
|
||||
>
|
||||
>For Feedback, Issues, Contributions - please raise an issue here:
|
||||
>[ju-bezdek/langchain-decorators](https://github.com/ju-bezdek/langchain-decorators)
|
||||
|
||||
|
||||
Main principles and benefits:
|
||||
@@ -17,7 +20,6 @@ Main principles and benefits:
|
||||
- easily share parameters between the prompts by binding them to one class
|
||||
|
||||
|
||||
|
||||
Here is a simple example of a code written with **LangChain Decorators ✨**
|
||||
|
||||
``` python
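# Illustrative sketch only: the original README example is truncated in this view.
# It assumes the `langchain_decorators` package and its `llm_prompt` decorator.
from langchain_decorators import llm_prompt


@llm_prompt
def write_me_short_post(topic: str, platform: str = "twitter", audience: str = "developers") -> str:
    """
    Write me a short header for my post about {topic} for {platform} platform.
    It should be for {audience} audience.
    (Max 15 words)
    """
    return  # the function body is never executed; the docstring becomes the prompt


# Call it like a normal function; the decorator renders the prompt and calls the LLM.
write_me_short_post(topic="starting with LangChain")
```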
|
||||
|
||||
255 docs/docs/integrations/retrievers/embedchain.ipynb (new file)
@@ -0,0 +1,255 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f0f85ac-9c49-4111-a320-e53bccc99b13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Embedchain\n",
|
||||
"\n",
|
||||
"Embedchain is a RAG framework to create data pipelines. It loads, indexes, retrieves and syncs all the data.\n",
|
||||
"\n",
|
||||
"It is available as an [open source package](https://github.com/embedchain/embedchain) and as a [hosted platform solution](https://app.embedchain.ai/).\n",
|
||||
"\n",
|
||||
"This notebook shows how to use a retriever that uses Embedchain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e48de822-307b-4284-96e7-c91f11ce005b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Installation\n",
|
||||
"\n",
|
||||
"First you will need to install the [`embedchain` package](https://pypi.org/project/embedchain/). \n",
|
||||
"\n",
|
||||
"You can install the package by running "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c690a78c-5999-4072-b4e1-2712ff73f950",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install --upgrade embedchain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bc89ba12-6ebd-4cd6-8c85-7410531579ff",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Create New Retriever\n",
|
||||
"\n",
|
||||
"`EmbedchainRetriever` has a static `.create()` factory method that takes the following arguments:\n",
|
||||
"\n",
|
||||
"* `yaml_path: string` optional -- Path to the YAML configuration file. If not provided, a default configuration is used. You can browse the [docs](https://docs.embedchain.ai/) to explore various customization options."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8e639bd4-2e60-487b-b7aa-f7e6b921b069",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Setup API Key\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "223fbc76-91ad-4504-87e9-980fb0e027fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import EmbedchainRetriever\n",
|
||||
"\n",
|
||||
"# create retriever with default options\n",
|
||||
"retriever = EmbedchainRetriever.create()\n",
|
||||
"\n",
|
||||
"# or if you want to customize, pass the yaml config path\n",
|
||||
"# retriever = EmbedchainRetiever.create(yaml_path=\"config.yaml\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "536f3a1d-3491-45b5-9f25-869bd6fb6d6a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Add Data\n",
|
||||
"\n",
|
||||
"In embedchain, you can as many supported data types as possible. You can browse our [docs](https://docs.embedchain.ai/) to see the data types supported.\n",
|
||||
"\n",
|
||||
"Embedchain automatically deduces the types of the data. So you can add a string, URL or local file path."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "31262be3-7d0d-42e8-9253-052160576dc7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00, 2.22s/it]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Successfully saved https://en.wikipedia.org/wiki/Elon_Musk (DataType.WEB_PAGE). New chunks count: 378\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.17s/it]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Successfully saved https://www.forbes.com/profile/elon-musk (DataType.WEB_PAGE). New chunks count: 13\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inserting batches in chromadb: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.25s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Successfully saved https://www.youtube.com/watch?v=RcYjXbSJBN8 (DataType.YOUTUBE_VIDEO). New chunks count: 53\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['1eab8dd1ffa92906f7fc839862871ca5',\n",
|
||||
" '8cf46026cabf9b05394a2658bd1fe890',\n",
|
||||
" 'da3227cdbcedb018e05c47b774d625f6']"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever.add_texts(\n",
|
||||
" [\n",
|
||||
" \"https://en.wikipedia.org/wiki/Elon_Musk\",\n",
|
||||
" \"https://www.forbes.com/profile/elon-musk\",\n",
|
||||
" \"https://www.youtube.com/watch?v=RcYjXbSJBN8\",\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e1f34a62-7f8e-4c03-8e10-c317ed3296aa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Use Retriever\n",
|
||||
"\n",
|
||||
"You can now use the retrieve to find relevant documents given a query"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "6106baf9-652a-4a94-b2d7-d6a5d2917975",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = retriever.get_relevant_documents(\n",
|
||||
" \"How many companies does Elon Musk run and name those?\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "1deae5d0-e0fa-431d-b164-e9680ef3e69b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Views Filmography Companies Zip2 X.com PayPal SpaceX Starlink Tesla, Inc. Energycriticismlitigation OpenAI Neuralink The Boring Company Thud X Corp. Twitteracquisitiontenure as CEO xAI In popular culture Elon Musk (Isaacson) Elon Musk (Vance) Ludicrous Power Play \"Members Only\" \"The Platonic Permutation\" \"The Musk Who Fell to Earth\" \"One Crew over the Crewcoo\\'s Morty\" Elon Musk\\'s Crash Course Related Boring Test Tunnel Hyperloop Musk family Musk vs. Zuckerberg SolarCity Tesla Roadster in space', metadata={'source': 'https://en.wikipedia.org/wiki/Elon_Musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3342161a0fbc19e91f6bf387204aa30fbb2cea05abc81882502476bde37b9392'}),\n",
|
||||
" Document(page_content='Elon Musk PROFILEElon MuskCEO, Tesla$241.2B$508M (0.21%)Real Time Net Worthas of 11/18/23Reflects change since 5 pm ET of prior trading day. 1 in the world todayPhoto by Martin Schoeller for ForbesAbout Elon MuskElon Musk cofounded six companies, including electric car maker Tesla, rocket producer SpaceX and tunneling startup Boring Company.He owns about 21% of Tesla between stock and options, but has pledged more than half his shares as collateral for personal loans of up to $3.5', metadata={'source': 'https://www.forbes.com/profile/elon-musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3c8573134c575fafc025e9211413723e1f7a725b5936e8ee297fb7fb63bdd01a'}),\n",
|
||||
" Document(page_content='to form PayPal. In October 2002, eBay acquired PayPal for $1.5 billion, and that same year, with $100 million of the money he made, Musk founded SpaceX, a spaceflight services company. In 2004, he became an early investor in electric vehicle manufacturer Tesla Motors, Inc. (now Tesla, Inc.). He became its chairman and product architect, assuming the position of CEO in 2008. In 2006, Musk helped create SolarCity, a solar-energy company that was acquired by Tesla in 2016 and became Tesla Energy.', metadata={'source': 'https://en.wikipedia.org/wiki/Elon_Musk', 'document_id': 'c33c05d0-5028-498b-b5e3-c43a4f9e8bf8--3342161a0fbc19e91f6bf387204aa30fbb2cea05abc81882502476bde37b9392'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b3f26c2b-048d-4588-90a0-50f5c9c35837",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
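The retriever built in this notebook drops straight into LangChain's retrieval chains. A minimal sketch, assuming the OpenAI key is already set as above and using `RetrievalQA` purely for illustration:

```python
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.retrievers import EmbedchainRetriever

# Build the retriever and index a couple of sources (same pattern as above).
retriever = EmbedchainRetriever.create()
retriever.add_texts(
    [
        "https://en.wikipedia.org/wiki/Elon_Musk",
        "https://www.forbes.com/profile/elon-musk",
    ]
)

# Wire the retriever into a simple question-answering chain.
qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(temperature=0),
    retriever=retriever,
)
print(qa.run("Which companies does Elon Musk run?"))
```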
@@ -12,7 +12,8 @@
|
||||
"- AzureCogsImageAnalysisTool: used to extract caption, objects, tags, and text from images. (Note: this tool is not available on Mac OS yet, due to the dependency on `azure-ai-vision` package, which is only supported on Windows and Linux currently.)\n",
|
||||
"- AzureCogsFormRecognizerTool: used to extract text, tables, and key-value pairs from documents.\n",
|
||||
"- AzureCogsSpeech2TextTool: used to transcribe speech to text.\n",
|
||||
"- AzureCogsText2SpeechTool: used to synthesize text to speech."
|
||||
"- AzureCogsText2SpeechTool: used to synthesize text to speech.\n",
|
||||
"- AzureCogsTextAnalyticsHealthTool: used to extract healthcare entities."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -32,6 +33,7 @@
|
||||
"source": [
|
||||
"# !pip install --upgrade azure-ai-formrecognizer > /dev/null\n",
|
||||
"# !pip install --upgrade azure-cognitiveservices-speech > /dev/null\n",
|
||||
"# !pip install --upgrade azure-ai-textanalytics > /dev/null\n",
|
||||
"\n",
|
||||
"# For Windows/Linux\n",
|
||||
"# !pip install --upgrade azure-ai-vision > /dev/null"
|
||||
@@ -60,7 +62,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -101,7 +103,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -111,7 +113,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -240,6 +242,65 @@
|
||||
"display.display(audio)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mAction:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"azure_cognitive_services_text_analyics_health\",\n",
|
||||
" \"action_input\": \"The patient is a 54-year-old gentleman with a history of progressive angina over the past several months. The patient had a cardiac catheterization in July of this year revealing total occlusion of the RCA and 50% left main disease, with a strong family history of coronary artery disease with a brother dying at the age of 52 from a myocardial infarction and another brother who is status post coronary artery bypass grafting. The patient had a stress echocardiogram done on July, 2001, which showed no wall motion abnormalities, but this was a difficult study due to body habitus. The patient went for six minutes with minimal ST depressions in the anterior lateral leads, thought due to fatigue and wrist pain, his anginal equivalent. Due to the patient's increased symptoms and family history and history left main disease with total occasional of his RCA was referred for revascularization with open heart surgery.\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mThe text conatins the following healthcare entities: 54-year-old is a healthcare entity of type Age, gentleman is a healthcare entity of type Gender, progressive angina is a healthcare entity of type Diagnosis, past several months is a healthcare entity of type Time, cardiac catheterization is a healthcare entity of type ExaminationName, July of this year is a healthcare entity of type Time, total is a healthcare entity of type ConditionQualifier, occlusion is a healthcare entity of type SymptomOrSign, RCA is a healthcare entity of type BodyStructure, 50 is a healthcare entity of type MeasurementValue, % is a healthcare entity of type MeasurementUnit, left main is a healthcare entity of type BodyStructure, disease is a healthcare entity of type Diagnosis, family is a healthcare entity of type FamilyRelation, coronary artery disease is a healthcare entity of type Diagnosis, brother is a healthcare entity of type FamilyRelation, dying is a healthcare entity of type Diagnosis, 52 is a healthcare entity of type Age, myocardial infarction is a healthcare entity of type Diagnosis, brother is a healthcare entity of type FamilyRelation, coronary artery bypass grafting is a healthcare entity of type TreatmentName, stress echocardiogram is a healthcare entity of type ExaminationName, July, 2001 is a healthcare entity of type Time, wall motion abnormalities is a healthcare entity of type SymptomOrSign, body habitus is a healthcare entity of type SymptomOrSign, six minutes is a healthcare entity of type Time, minimal is a healthcare entity of type ConditionQualifier, ST depressions in the anterior lateral leads is a healthcare entity of type SymptomOrSign, fatigue is a healthcare entity of type SymptomOrSign, wrist pain is a healthcare entity of type SymptomOrSign, anginal equivalent is a healthcare entity of type SymptomOrSign, increased is a healthcare entity of type Course, symptoms is a healthcare entity of type SymptomOrSign, family is a healthcare entity of type FamilyRelation, left is a healthcare entity of type Direction, main is a healthcare entity of type BodyStructure, disease is a healthcare entity of type Diagnosis, occasional is a healthcare entity of type Course, RCA is a healthcare entity of type BodyStructure, revascularization is a healthcare entity of type TreatmentName, open heart surgery is a healthcare entity of type TreatmentName\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I know what to respond\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"The text contains the following diagnoses: progressive angina, coronary artery disease, myocardial infarction, and coronary artery bypass grafting.\"\n",
|
||||
"}\n",
|
||||
"```\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The text contains the following diagnoses: progressive angina, coronary artery disease, myocardial infarction, and coronary artery bypass grafting.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\n",
|
||||
" \"\"\"The patient is a 54-year-old gentleman with a history of progressive angina over the past several months.\n",
|
||||
"The patient had a cardiac catheterization in July of this year revealing total occlusion of the RCA and 50% left main disease ,\n",
|
||||
"with a strong family history of coronary artery disease with a brother dying at the age of 52 from a myocardial infarction and\n",
|
||||
"another brother who is status post coronary artery bypass grafting. The patient had a stress echocardiogram done on July , 2001 ,\n",
|
||||
"which showed no wall motion abnormalities , but this was a difficult study due to body habitus. The patient went for six minutes with\n",
|
||||
"minimal ST depressions in the anterior lateral leads , thought due to fatigue and wrist pain , his anginal equivalent. Due to the patient's\n",
|
||||
"increased symptoms and family history and history left main disease with total occasional of his RCA was referred for revascularization with open heart surgery.\n",
|
||||
"\n",
|
||||
"List all the diagnoses.\n",
|
||||
"\"\"\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -264,7 +325,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
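The cells elided from the Azure Cognitive Services hunks above set the service credentials and build the agent over the toolkit. A hedged sketch of that wiring (the environment variable names follow the existing notebook and are assumptions here):

```python
import os

from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent_toolkits import AzureCognitiveServicesToolkit
from langchain.llms import OpenAI

# Credentials for the Azure Cognitive Services resource (variable names assumed).
os.environ["AZURE_COGS_KEY"] = "<your-key>"
os.environ["AZURE_COGS_ENDPOINT"] = "<your-endpoint>"
os.environ["AZURE_COGS_REGION"] = "<your-region>"

# Per this change, the toolkit also exposes AzureCogsTextAnalyticsHealthTool
# alongside the image, form, and speech tools.
toolkit = AzureCognitiveServicesToolkit()
agent = initialize_agent(
    tools=toolkit.get_tools(),
    llm=OpenAI(temperature=0),
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

# The health tool extracts clinical entities, as in the agent transcript above.
agent.run("List all the diagnoses in: The patient has progressive angina ...")
```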
@@ -57,11 +57,14 @@
|
||||
"1. **Load**: First we need to load our data. We'll use [DocumentLoaders](/docs/modules/data_connection/document_loaders/) for this.\n",
|
||||
"2. **Split**: [Text splitters](/docs/modules/data_connection/document_transformers/) break large `Documents` into smaller chunks. This is useful both for indexing data and for passing it in to a model, since large chunks are harder to search over and won't in a model's finite context window.\n",
|
||||
"3. **Store**: We need somewhere to store and index our splits, so that they can later be searched over. This is often done using a [VectorStore](/docs/modules/data_connection/vectorstores/) and [Embeddings](/docs/modules/data_connection/text_embedding/) model.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"#### Retrieval and generation\n",
|
||||
"4. **Retrieve**: Given a user input, relevant splits are retrieved from storage using a [Retriever](/docs/modules/data_connection/retrievers/).\n",
|
||||
"5. **Generate**: A [ChatModel](/docs/modules/model_io/chat_models) / [LLM](/docs/modules/model_io/llms/) produces an answer using a prompt that includes the question and the retrieved data\n",
|
||||
"\n",
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
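The indexing steps (load, split, store) and the retrieval/generation steps listed in the hunk above correspond to a short pipeline. A minimal sketch, assuming OpenAI embeddings, an OpenAI chat model, and the Chroma vector store:

```python
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# 1. Load
docs = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/").load()

# 2. Split
splits = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200
).split_documents(docs)

# 3. Store
vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())

# 4. Retrieve + 5. Generate
qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    retriever=vectorstore.as_retriever(),
)
print(qa.run("What is task decomposition?"))
```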
BIN docs/static/img/qa_flow.jpeg (vendored; binary file not shown; before: 173 KiB)
BIN docs/static/img/rag_indexing.png (vendored, new file; binary file not shown; 129 KiB)
BIN docs/static/img/rag_retrieval_generation.png (vendored, new file; binary file not shown; 58 KiB)
16095 docs/static/svg/langchain_stack.svg (vendored, new file; file diff suppressed because it is too large; 956 KiB)
@@ -7,7 +7,7 @@ from langchain_cli.namespaces import app as app_namespace
|
||||
from langchain_cli.namespaces import template as template_namespace
|
||||
from langchain_cli.utils.packages import get_langserve_export, get_package_root
|
||||
|
||||
__version__ = "0.0.18"
|
||||
__version__ = "0.0.19"
|
||||
|
||||
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||
app.add_typer(
|
||||
|
||||
@@ -6,7 +6,7 @@ RUN poetry config virtualenvs.create false
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
COPY ./pyproject.toml ./poetry.lock* ./
|
||||
COPY ./pyproject.toml ./README.md ./poetry.lock* ./
|
||||
|
||||
COPY ./packages ./packages
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ packages = [
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
uvicorn = "^0.23.2"
|
||||
langserve = {extras = ["server"], version = ">=0.0.22"}
|
||||
langserve = {extras = ["server"], version = ">=0.0.30"}
|
||||
pydantic = "<2"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-cli"
|
||||
version = "0.0.18"
|
||||
version = "0.0.19"
|
||||
description = "CLI for interacting with LangChain"
|
||||
authors = ["Erick Friis <erick@langchain.dev>"]
|
||||
readme = "README.md"
|
||||
|
||||
56 libs/core/Makefile (new file)
@@ -0,0 +1,56 @@
|
||||
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/experimental --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
|
||||
lint lint_diff:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/check_imports.sh
|
||||
poetry run ruff .
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo 'test_watch - run unit tests in watch mode'
|
||||
1 libs/core/README.md (new file)
@@ -0,0 +1 @@
|
||||
# langchain-core
|
||||
7 libs/core/langchain_core/__init__.py (new file)
@@ -0,0 +1,7 @@
|
||||
from importlib import metadata
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
||||
29 libs/core/langchain_core/_api/__init__.py (new file)
@@ -0,0 +1,29 @@
|
||||
"""Helper functions for managing the LangChain API.
|
||||
|
||||
This module is only relevant for LangChain developers, not for users.
|
||||
|
||||
.. warning::
|
||||
|
||||
This module and its submodules are for internal use only. Do not use them
|
||||
in your own code. We may change the API at any time with no warning.
|
||||
|
||||
"""
|
||||
|
||||
from .deprecation import (
|
||||
LangChainDeprecationWarning,
|
||||
deprecated,
|
||||
suppress_langchain_deprecation_warning,
|
||||
surface_langchain_deprecation_warnings,
|
||||
warn_deprecated,
|
||||
)
|
||||
from .path import as_import_path, get_relative_path
|
||||
|
||||
__all__ = [
|
||||
"as_import_path",
|
||||
"deprecated",
|
||||
"get_relative_path",
|
||||
"LangChainDeprecationWarning",
|
||||
"suppress_langchain_deprecation_warning",
|
||||
"surface_langchain_deprecation_warnings",
|
||||
"warn_deprecated",
|
||||
]
|
||||
341 libs/core/langchain_core/_api/deprecation.py (new file)
@@ -0,0 +1,341 @@
|
||||
"""Helper functions for deprecating parts of the LangChain API.
|
||||
|
||||
This module was adapted from matplotlibs _api/deprecation.py module:
|
||||
|
||||
https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py
|
||||
|
||||
.. warning::
|
||||
|
||||
This module is for internal use only. Do not use it in your own code.
|
||||
We may change the API at any time with no warning.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Type, TypeVar
|
||||
|
||||
|
||||
class LangChainDeprecationWarning(DeprecationWarning):
|
||||
"""A class for issuing deprecation warnings for LangChain users."""
|
||||
|
||||
|
||||
class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
"""A class for issuing deprecation warnings for LangChain users."""
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
T = TypeVar("T", Type, Callable)
|
||||
|
||||
|
||||
def deprecated(
|
||||
since: str,
|
||||
*,
|
||||
message: str = "",
|
||||
name: str = "",
|
||||
alternative: str = "",
|
||||
pending: bool = False,
|
||||
obj_type: str = "",
|
||||
addendum: str = "",
|
||||
removal: str = "",
|
||||
) -> Callable[[T], T]:
|
||||
"""Decorator to mark a function, a class, or a property as deprecated.
|
||||
|
||||
When deprecating a classmethod, a staticmethod, or a property, the
|
||||
``@deprecated`` decorator should go *under* ``@classmethod`` and
|
||||
``@staticmethod`` (i.e., `deprecated` should directly decorate the
|
||||
underlying callable), but *over* ``@property``.
|
||||
|
||||
When deprecating a class ``C`` intended to be used as a base class in a
|
||||
multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method
|
||||
(if ``C`` instead inherited its ``__init__`` from its own base class, then
|
||||
``@deprecated`` would mess up ``__init__`` inheritance when installing its
|
||||
own (deprecation-emitting) ``C.__init__``).
|
||||
|
||||
Parameters are the same as for `warn_deprecated`, except that *obj_type*
|
||||
defaults to 'class' if decorating a class, 'attribute' if decorating a
|
||||
property, and 'function' otherwise.
|
||||
|
||||
Arguments:
|
||||
since : str
|
||||
The release at which this API became deprecated.
|
||||
message : str, optional
|
||||
Override the default deprecation message. The %(since)s,
|
||||
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
|
||||
and %(removal)s format specifiers will be replaced by the
|
||||
values of the respective arguments passed to this function.
|
||||
name : str, optional
|
||||
The name of the deprecated object.
|
||||
alternative : str, optional
|
||||
An alternative API that the user may use in place of the
|
||||
deprecated API. The deprecation warning will tell the user
|
||||
about this alternative if provided.
|
||||
pending : bool, optional
|
||||
If True, uses a PendingDeprecationWarning instead of a
|
||||
DeprecationWarning. Cannot be used together with removal.
|
||||
obj_type : str, optional
|
||||
The object type being deprecated.
|
||||
addendum : str, optional
|
||||
Additional text appended directly to the final message.
|
||||
removal : str, optional
|
||||
The expected removal version. With the default (an empty
|
||||
string), a removal version is automatically computed from
|
||||
since. Set to other Falsy values to not schedule a removal
|
||||
date. Cannot be used together with pending.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@deprecated('1.4.0')
|
||||
def the_function_to_deprecate():
|
||||
pass
|
||||
"""
|
||||
|
||||
def deprecate(
|
||||
obj: T,
|
||||
*,
|
||||
_obj_type: str = obj_type,
|
||||
_name: str = name,
|
||||
_message: str = message,
|
||||
_alternative: str = alternative,
|
||||
_pending: bool = pending,
|
||||
_addendum: str = addendum,
|
||||
) -> T:
|
||||
"""Implementation of the decorator returned by `deprecated`."""
|
||||
if isinstance(obj, type):
|
||||
if not _obj_type:
|
||||
_obj_type = "class"
|
||||
wrapped = obj.__init__ # type: ignore
|
||||
_name = _name or obj.__name__
|
||||
old_doc = obj.__doc__
|
||||
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
||||
"""Finalize the deprecation of a class."""
|
||||
try:
|
||||
obj.__doc__ = new_doc
|
||||
except AttributeError: # Can't set on some extension objects.
|
||||
pass
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
wrapper
|
||||
)
|
||||
return obj
|
||||
|
||||
elif isinstance(obj, property):
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
wrapped = None
|
||||
_name = _name or obj.fget.__name__
|
||||
old_doc = obj.__doc__
|
||||
|
||||
class _deprecated_property(type(obj)): # type: ignore
|
||||
"""A deprecated property."""
|
||||
|
||||
def __get__(self, instance, owner=None): # type: ignore
|
||||
if instance is not None or owner is not None:
|
||||
emit_warning()
|
||||
return super().__get__(instance, owner)
|
||||
|
||||
def __set__(self, instance, value): # type: ignore
|
||||
if instance is not None:
|
||||
emit_warning()
|
||||
return super().__set__(instance, value)
|
||||
|
||||
def __delete__(self, instance): # type: ignore
|
||||
if instance is not None:
|
||||
emit_warning()
|
||||
return super().__delete__(instance)
|
||||
|
||||
def __set_name__(self, owner, set_name): # type: ignore
|
||||
nonlocal _name
|
||||
if _name == "<lambda>":
|
||||
_name = set_name
|
||||
|
||||
def finalize(_: Any, new_doc: str) -> Any: # type: ignore
|
||||
"""Finalize the property."""
|
||||
return _deprecated_property(
|
||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||
)
|
||||
|
||||
else:
|
||||
if not _obj_type:
|
||||
_obj_type = "function"
|
||||
wrapped = obj
|
||||
_name = _name or obj.__name__ # type: ignore
|
||||
old_doc = wrapped.__doc__
|
||||
|
||||
def finalize( # type: ignore
|
||||
wrapper: Callable[..., Any], new_doc: str
|
||||
) -> T:
|
||||
"""Wrap the wrapped function using the wrapper and update the docstring.
|
||||
|
||||
Args:
|
||||
wrapper: The wrapper function.
|
||||
new_doc: The new docstring.
|
||||
|
||||
Returns:
|
||||
The wrapped function.
|
||||
"""
|
||||
wrapper = functools.wraps(wrapped)(wrapper)
|
||||
wrapper.__doc__ = new_doc
|
||||
return wrapper
|
||||
|
||||
def emit_warning() -> None:
|
||||
"""Emit the warning."""
|
||||
warn_deprecated(
|
||||
since,
|
||||
message=_message,
|
||||
name=_name,
|
||||
alternative=_alternative,
|
||||
pending=_pending,
|
||||
obj_type=_obj_type,
|
||||
addendum=_addendum,
|
||||
removal=removal,
|
||||
)
|
||||
|
||||
def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for the original wrapped callable that emits a warning.
|
||||
|
||||
Args:
|
||||
*args: The positional arguments to the function.
|
||||
**kwargs: The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
The return value of the function being wrapped.
|
||||
"""
|
||||
emit_warning()
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
|
||||
|
||||
if not old_doc:
|
||||
new_doc = "[*Deprecated*]"
|
||||
else:
|
||||
new_doc = f"[*Deprecated*] {old_doc}"
|
||||
|
||||
# Modify the docstring to include a deprecation notice.
|
||||
notes_header = "\nNotes\n-----"
|
||||
components = [
|
||||
message,
|
||||
f"Use {alternative} instead." if alternative else "",
|
||||
addendum,
|
||||
]
|
||||
details = " ".join([component.strip() for component in components if component])
|
||||
new_doc += (
|
||||
f"[*Deprecated*] {old_doc}\n"
|
||||
f"{notes_header if notes_header not in old_doc else ''}\n"
|
||||
f".. deprecated:: {since}\n"
|
||||
f" {details}"
|
||||
)
|
||||
|
||||
return finalize(warning_emitting_wrapper, new_doc)
|
||||
|
||||
return deprecate
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_langchain_deprecation_warning() -> Generator[None, None, None]:
|
||||
"""Context manager to suppress LangChainDeprecationWarning."""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", LangChainDeprecationWarning)
|
||||
warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
|
||||
yield
|
||||
|
||||
|
||||
def warn_deprecated(
|
||||
since: str,
|
||||
*,
|
||||
message: str = "",
|
||||
name: str = "",
|
||||
alternative: str = "",
|
||||
pending: bool = False,
|
||||
obj_type: str = "",
|
||||
addendum: str = "",
|
||||
removal: str = "",
|
||||
) -> None:
|
||||
"""Display a standardized deprecation.
|
||||
|
||||
Arguments:
|
||||
since : str
|
||||
The release at which this API became deprecated.
|
||||
message : str, optional
|
||||
Override the default deprecation message. The %(since)s,
|
||||
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
|
||||
and %(removal)s format specifiers will be replaced by the
|
||||
values of the respective arguments passed to this function.
|
||||
name : str, optional
|
||||
The name of the deprecated object.
|
||||
alternative : str, optional
|
||||
An alternative API that the user may use in place of the
|
||||
deprecated API. The deprecation warning will tell the user
|
||||
about this alternative if provided.
|
||||
pending : bool, optional
|
||||
If True, uses a PendingDeprecationWarning instead of a
|
||||
DeprecationWarning. Cannot be used together with removal.
|
||||
obj_type : str, optional
|
||||
The object type being deprecated.
|
||||
addendum : str, optional
|
||||
Additional text appended directly to the final message.
|
||||
removal : str, optional
|
||||
The expected removal version. With the default (an empty
|
||||
string), a removal version is automatically computed from
|
||||
since. Set to other Falsy values to not schedule a removal
|
||||
date. Cannot be used together with pending.
|
||||
"""
|
||||
if pending and removal:
|
||||
raise ValueError("A pending deprecation cannot have a scheduled removal")
|
||||
|
||||
if not pending:
|
||||
if not removal:
|
||||
removal = f"in {removal}" if removal else "within ?? minor releases"
|
||||
raise NotImplementedError(
|
||||
f"Need to determine which default deprecation schedule to use. "
|
||||
f"{removal}"
|
||||
)
|
||||
else:
|
||||
removal = f"in {removal}"
|
||||
|
||||
if not message:
|
||||
message = ""
|
||||
|
||||
if obj_type:
|
||||
message += f"The {obj_type} `{name}`"
|
||||
else:
|
||||
message += f"`{name}`"
|
||||
|
||||
if pending:
|
||||
message += " will be deprecated in a future version"
|
||||
else:
|
||||
message += f" was deprecated in LangChain {since}"
|
||||
|
||||
if removal:
|
||||
message += f" and will be removed {removal}"
|
||||
|
||||
if alternative:
|
||||
message += f". Use {alternative} instead."
|
||||
|
||||
if addendum:
|
||||
message += f" {addendum}"
|
||||
|
||||
warning_cls = (
|
||||
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
|
||||
)
|
||||
warning = warning_cls(message)
|
||||
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
|
||||
|
||||
|
||||
def surface_langchain_deprecation_warnings() -> None:
|
||||
"""Unmute LangChain deprecation warnings."""
|
||||
warnings.filterwarnings(
|
||||
"default",
|
||||
category=LangChainPendingDeprecationWarning,
|
||||
)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"default",
|
||||
category=LangChainDeprecationWarning,
|
||||
)
|
||||
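To see how the pieces of this module fit together, here is a small hedged sketch (the function names are invented) of deprecating an API and surfacing the resulting warning:

```python
import warnings

from langchain_core._api import deprecated, surface_langchain_deprecation_warnings


@deprecated("0.0.1", alternative="new_search()", removal="0.2.0")
def old_search(query: str) -> str:
    """Search for a query string."""
    return f"results for {query}"


# Deprecation warnings are often hidden by default filters; this re-enables
# LangChain's warning classes.
surface_langchain_deprecation_warnings()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_search("hello")

# The message combines the name, since-version, removal version and alternative,
# e.g. "The function `old_search` was deprecated in LangChain 0.0.1 and will be
# removed in 0.2.0. Use new_search() instead."
print(caught[0].message)
```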
36 libs/core/langchain_core/_api/path.py (new file)
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
|
||||
# Get directory of langchain package
|
||||
PACKAGE_DIR = HERE.parent
|
||||
SEPARATOR = os.sep
|
||||
|
||||
|
||||
def get_relative_path(
|
||||
file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
|
||||
) -> str:
|
||||
"""Get the path of the file as a relative path to the package directory."""
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
return str(file.relative_to(relative_to))
|
||||
|
||||
|
||||
def as_import_path(
|
||||
file: Union[Path, str],
|
||||
*,
|
||||
suffix: Optional[str] = None,
|
||||
relative_to: Path = PACKAGE_DIR,
|
||||
) -> str:
|
||||
"""Path of the file as a LangChain import exclude langchain top namespace."""
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
path = get_relative_path(file, relative_to=relative_to)
|
||||
if file.is_file():
|
||||
path = path[: -len(file.suffix)]
|
||||
import_path = path.replace(SEPARATOR, ".")
|
||||
if suffix:
|
||||
import_path += "." + suffix
|
||||
return import_path
|
||||
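A quick illustration of the two path helpers above (the module path is just an example):

```python
from langchain_core._api.path import PACKAGE_DIR, as_import_path, get_relative_path

# A module file inside the installed langchain_core package.
module_file = PACKAGE_DIR / "_api" / "deprecation.py"

print(get_relative_path(module_file))  # "_api/deprecation.py" (OS-specific separators)
print(as_import_path(module_file))  # "_api.deprecation"
print(as_import_path(module_file, suffix="deprecated"))  # "_api.deprecation.deprecated"
```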
89 libs/core/langchain_core/agents.py (new file)
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Literal, Sequence, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class AgentAction(Serializable):
|
||||
"""A full description of an action for an ActionAgent to execute."""
|
||||
|
||||
tool: str
|
||||
"""The name of the Tool to execute."""
|
||||
tool_input: Union[str, dict]
|
||||
"""The input to pass in to the Tool."""
|
||||
log: str
|
||||
"""Additional information to log about the action.
|
||||
This log can be used in a few ways. First, it can be used to audit
|
||||
what exactly the LLM predicted to lead to this (tool, tool_input).
|
||||
Second, it can be used in future iterations to show the LLMs prior
|
||||
thoughts. This is useful when (tool, tool_input) does not contain
|
||||
full information about the LLM prediction (for example, any `thought`
|
||||
before the tool/tool_input)."""
|
||||
type: Literal["AgentAction"] = "AgentAction"
|
||||
|
||||
def __init__(
|
||||
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
|
||||
):
|
||||
"""Override init to support instantiation by position for backward compat."""
|
||||
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
|
||||
class AgentActionMessageLog(AgentAction):
|
||||
message_log: Sequence[BaseMessage]
|
||||
"""Similar to log, this can be used to pass along extra
|
||||
information about what exact messages were predicted by the LLM
|
||||
before parsing out the (tool, tool_input). This is again useful
|
||||
if (tool, tool_input) cannot be used to fully recreate the LLM
|
||||
prediction, and you need that LLM prediction (for future agent iteration).
|
||||
Compared to `log`, this is useful when the underlying LLM is a
|
||||
ChatModel (and therefore returns messages rather than a string)."""
|
||||
# Ignoring type because we're overriding the type from AgentAction.
|
||||
# And this is the correct thing to do in this case.
|
||||
# The type literal is used for serialization purposes.
|
||||
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
|
||||
class AgentFinish(Serializable):
|
||||
"""The final return value of an ActionAgent."""
|
||||
|
||||
return_values: dict
|
||||
"""Dictionary of return values."""
|
||||
log: str
|
||||
"""Additional information to log about the return value.
|
||||
This is used to pass along the full LLM prediction, not just the parsed out
|
||||
return value. For example, if the full LLM prediction was
|
||||
`Final Answer: 2` you may want to just return `2` as a return value, but pass
|
||||
along the full string as a `log` (for debugging or observability purposes).
|
||||
"""
|
||||
type: Literal["AgentFinish"] = "AgentFinish"
|
||||
|
||||
def __init__(self, return_values: dict, log: str, **kwargs: Any):
|
||||
"""Override init to support instantiation by position for backward compat."""
|
||||
super().__init__(return_values=return_values, log=log, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "agent"]
|
||||
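These schema classes are plain serializable containers; a brief sketch of constructing them positionally, which the overridden `__init__` methods allow for backward compatibility:

```python
from langchain_core.agents import AgentAction, AgentFinish

# Positional construction is supported for backward compatibility:
# (tool, tool_input, log).
action = AgentAction("search", "weather in Vienna", "I should look up the weather.")
print(action.tool, action.tool_input)

# AgentFinish carries the final return values plus the raw LLM text as `log`.
finish = AgentFinish({"output": "It is sunny."}, "Final Answer: It is sunny.")
print(finish.return_values["output"])
```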
24 libs/core/langchain_core/caches.py (new file)
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from langchain_core.outputs import Generation
|
||||
|
||||
RETURN_VAL_TYPE = Sequence[Generation]
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
"""Base interface for cache."""
|
||||
|
||||
@abstractmethod
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
|
||||
@abstractmethod
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache that can take additional keyword arguments."""
|
||||
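As a concrete illustration of this interface, a minimal in-memory cache that satisfies the three abstract methods could look like the sketch below (not an official implementation):

```python
from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class SimpleInMemoryCache(BaseCache):
    """Toy cache keyed on (prompt, llm_string); for illustration only."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return cached generations for this prompt/LLM combination, if any."""
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store the generations produced for this prompt/LLM combination."""
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        """Drop all cached entries."""
        self._store.clear()
```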
65 libs/core/langchain_core/callbacks/__init__.py (new file)
@@ -0,0 +1,65 @@
|
||||
from langchain_core.callbacks.base import (
|
||||
AsyncCallbackHandler,
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManagerMixin,
|
||||
Callbacks,
|
||||
ChainManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
)
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainGroup,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
AsyncCallbackManagerForToolRun,
|
||||
AsyncParentRunManager,
|
||||
AsyncRunManager,
|
||||
BaseRunManager,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainGroup,
|
||||
CallbackManagerForChainRun,
|
||||
CallbackManagerForLLMRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
CallbackManagerForToolRun,
|
||||
ParentRunManager,
|
||||
RunManager,
|
||||
)
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
__all__ = [
|
||||
"RetrieverManagerMixin",
|
||||
"LLMManagerMixin",
|
||||
"ChainManagerMixin",
|
||||
"ToolManagerMixin",
|
||||
"Callbacks",
|
||||
"CallbackManagerMixin",
|
||||
"RunManagerMixin",
|
||||
"BaseCallbackHandler",
|
||||
"AsyncCallbackHandler",
|
||||
"BaseCallbackManager",
|
||||
"BaseRunManager",
|
||||
"RunManager",
|
||||
"ParentRunManager",
|
||||
"AsyncRunManager",
|
||||
"AsyncParentRunManager",
|
||||
"CallbackManagerForLLMRun",
|
||||
"AsyncCallbackManagerForLLMRun",
|
||||
"CallbackManagerForChainRun",
|
||||
"AsyncCallbackManagerForChainRun",
|
||||
"CallbackManagerForToolRun",
|
||||
"AsyncCallbackManagerForToolRun",
|
||||
"CallbackManagerForRetrieverRun",
|
||||
"AsyncCallbackManagerForRetrieverRun",
|
||||
"CallbackManager",
|
||||
"CallbackManagerForChainGroup",
|
||||
"AsyncCallbackManager",
|
||||
"AsyncCallbackManagerForChainGroup",
|
||||
"StdOutCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
]
|
||||
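The exports above are what user code typically touches. A minimal custom handler, sketched under the assumption of the standard hook names defined by the base classes that follow:

```python
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler


class PrintingHandler(BaseCallbackHandler):
    """Toy handler that logs LLM starts and streamed tokens."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print(f"LLM started with {len(prompts)} prompt(s)")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Only fires when streaming is enabled on the model.
        print(token, end="", flush=True)
```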
599 libs/core/langchain_core/callbacks/base.py (new file)
@@ -0,0 +1,599 @@
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
|
||||
|
||||
class RetrieverManagerMixin:
|
||||
"""Mixin for Retriever callbacks."""
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever errors."""
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever ends running."""
|
||||
|
||||
|
||||
class LLMManagerMixin:
|
||||
"""Mixin for LLM callbacks."""
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled.
|
||||
|
||||
Args:
|
||||
token (str): The new token.
|
||||
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
|
||||
containing content and other information.
|
||||
"""
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class ToolManagerMixin:
|
||||
"""Mixin for tool callbacks."""
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
|
||||
class CallbackManagerMixin:
|
||||
"""Mixin for callback manager."""
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
|
||||
class RunManagerMixin:
|
||||
"""Mixin for run manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
def on_retry(
|
||||
self,
|
||||
retry_state: RetryCallState,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on a retry event."""
|
||||
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
"""Base callback handler that handles callbacks from LangChain."""
|
||||
|
||||
raise_error: bool = False
|
||||
|
||||
run_inline: bool = False
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_retry(self) -> bool:
|
||||
"""Whether to ignore retry callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that handles callbacks from LangChain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_retry(
|
||||
self,
|
||||
retry_state: RetryCallState,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on a retry event."""
|
||||
|
||||
async def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever start."""
|
||||
|
||||
async def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever end."""
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error."""
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseCallbackManager")
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager that handles callbacks from LangChain."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||
self.tags = tags or []
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
def copy(self: T) -> T:
|
||||
"""Copy the callback manager."""
|
||||
return self.__class__(
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Whether the callback manager is async."""
|
||||
return False
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
if handler not in self.handlers:
|
||||
self.handlers.append(handler)
|
||||
if inherit and handler not in self.inheritable_handlers:
|
||||
self.inheritable_handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
self.add_handler(handler, inherit=inherit)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
|
||||
for tag in tags:
|
||||
if tag in self.tags:
|
||||
self.remove_tags([tag])
|
||||
self.tags.extend(tags)
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
|
||||
def remove_tags(self, tags: List[str]) -> None:
|
||||
for tag in tags:
|
||||
self.tags.remove(tag)
|
||||
self.inheritable_tags.remove(tag)
|
||||
|
||||
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
|
||||
self.metadata.update(metadata)
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
|
||||
def remove_metadata(self, keys: List[str]) -> None:
|
||||
for key in keys:
|
||||
self.metadata.pop(key)
|
||||
self.inheritable_metadata.pop(key)
|
||||
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
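
Every hook on BaseCallbackHandler is a no-op by default, so a custom handler only overrides the events it cares about. A small sketch (hypothetical class, not part of this changeset) that counts streamed tokens:

    from typing import Any

    from langchain_core.callbacks.base import BaseCallbackHandler


    class TokenCountingHandler(BaseCallbackHandler):
        """Illustrative handler that counts tokens emitted while streaming."""

        def __init__(self) -> None:
            self.token_count = 0

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            # Called once per token when streaming is enabled; extra keyword
            # arguments such as run_id and chunk are absorbed by **kwargs.
            self.token_count += 1

A handler like this would be attached through the callback manager machinery defined in manager.py below.
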
libs/core/langchain_core/callbacks/manager.py (new file, 1865 lines)
Diff suppressed because the file is too large.
libs/core/langchain_core/callbacks/stdout.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Callback Handler that prints to std out."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.utils import print_text

if TYPE_CHECKING:
    from langchain_core.agents import AgentAction, AgentFinish
    from langchain_core.outputs import LLMResult


class StdOutCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""

    def __init__(self, color: Optional[str] = None) -> None:
        """Initialize callback handler."""
        self.color = color

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Run on agent action."""
        print_text(action.log, color=color or self.color)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        if observation_prefix is not None:
            print_text(f"\n{observation_prefix}")
        print_text(output, color=color or self.color)
        if llm_prefix is not None:
            print_text(f"\n{llm_prefix}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when agent ends."""
        print_text(text, color=color or self.color, end=end)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text(finish.log, color=color or self.color, end="\n")
libs/core/langchain_core/callbacks/streaming_stdout.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""Callback Handler streams to stdout on new llm token."""
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Dict, List

from langchain_core.callbacks.base import BaseCallbackHandler

if TYPE_CHECKING:
    from langchain_core.agents import AgentAction, AgentFinish
    from langchain_core.messages import BaseMessage
    from langchain_core.outputs import LLMResult


class StreamingStdOutCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
libs/core/langchain_core/chat_history.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


class BaseChatMessageHistory(ABC):
    """Abstract base class for storing chat message history.

    See `ChatMessageHistory` for default implementation.

    Example:
        .. code-block:: python

            class FileChatMessageHistory(BaseChatMessageHistory):
                storage_path: str
                session_id: str

                @property
                def messages(self):
                    with open(os.path.join(storage_path, session_id), 'r') as f:
                        messages = json.loads(f.read())
                    return messages_from_dict(messages)

                def add_message(self, message: BaseMessage) -> None:
                    messages = self.messages
                    messages.append(message)
                    with open(os.path.join(storage_path, session_id), 'w') as f:
                        json.dump([_message_to_dict(m) for m in messages], f)

                def clear(self):
                    with open(os.path.join(storage_path, session_id), 'w') as f:
                        f.write("[]")
    """

    messages: List[BaseMessage]
    """A list of Messages stored in-memory."""

    def add_user_message(self, message: str) -> None:
        """Convenience method for adding a human message string to the store.

        Args:
            message: The string contents of a human message.
        """
        self.add_message(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Convenience method for adding an AI message string to the store.

        Args:
            message: The string contents of an AI message.
        """
        self.add_message(AIMessage(content=message))

    @abstractmethod
    def add_message(self, message: BaseMessage) -> None:
        """Add a Message object to the store.

        Args:
            message: A BaseMessage object to store.
        """
        raise NotImplementedError()

    @abstractmethod
    def clear(self) -> None:
        """Remove all messages from the store."""
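
The simplest possible concrete subclass of BaseChatMessageHistory just keeps messages in a list; a sketch for orientation (the class name is hypothetical, not part of this changeset):

    from typing import List

    from langchain_core.chat_history import BaseChatMessageHistory
    from langchain_core.messages import BaseMessage


    class InMemoryChatMessageHistory(BaseChatMessageHistory):
        """Illustrative history that stores messages in a Python list."""

        def __init__(self) -> None:
            self.messages: List[BaseMessage] = []

        def add_message(self, message: BaseMessage) -> None:
            self.messages.append(message)

        def clear(self) -> None:
            self.messages = []
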
libs/core/langchain_core/chat_sessions.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from typing import Sequence, TypedDict

from langchain_core.messages import BaseMessage


class ChatSession(TypedDict, total=False):
    """Chat Session represents a single
    conversation, channel, or other group of messages."""

    messages: Sequence[BaseMessage]
    """The LangChain chat messages loaded from the source."""
    functions: Sequence[dict]
    """The function calling specs for the messages."""
libs/core/langchain_core/documents/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from langchain_core.documents.base import Document
from langchain_core.documents.transformers import BaseDocumentTransformer

__all__ = ["Document", "BaseDocumentTransformer"]
libs/core/langchain_core/documents/base.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import List, Literal

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field


class Document(Serializable):
    """Class for storing a piece of text and associated metadata."""

    page_content: str
    """String text."""
    metadata: dict = Field(default_factory=dict)
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
    type: Literal["Document"] = "Document"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        # For backwards compatibility.
        return ["langchain", "schema", "document"]
libs/core/langchain_core/documents/transformers.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence

if TYPE_CHECKING:
    from langchain_core.documents import Document


class BaseDocumentTransformer(ABC):
    """Abstract base class for document transformation systems.

    A document transformation system takes a sequence of Documents and returns a
    sequence of transformed Documents.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError

    """  # noqa: E501

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.transform_documents, **kwargs), documents
        )
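
A concrete transformer only has to implement transform_documents; atransform_documents falls back to running it in an executor. A toy sketch (hypothetical class, not part of this changeset) that upper-cases page content:

    from typing import Any, Sequence

    from langchain_core.documents import BaseDocumentTransformer, Document


    class UpperCaseTransformer(BaseDocumentTransformer):
        """Illustrative transformer that upper-cases every document's text."""

        def transform_documents(
            self, documents: Sequence[Document], **kwargs: Any
        ) -> Sequence[Document]:
            # Build new Document objects rather than mutating the inputs.
            return [
                Document(page_content=doc.page_content.upper(), metadata=doc.metadata)
                for doc in documents
            ]
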
libs/core/langchain_core/embeddings.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_documents, texts
        )

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_query, text
        )
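
Because only the two synchronous methods are abstract (the async variants default to executor calls), a deterministic fake is enough for tests. A sketch with hypothetical names, not part of this changeset:

    import hashlib
    from typing import List

    from langchain_core.embeddings import Embeddings


    class FakeEmbeddings(Embeddings):
        """Illustrative deterministic embeddings, useful in unit tests."""

        def __init__(self, size: int = 8) -> None:
            self.size = size

        def _embed(self, text: str) -> List[float]:
            # Hash the text and expand the digest into `size` floats in [0, 1].
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            return [digest[i % len(digest)] / 255.0 for i in range(self.size)]

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self._embed(text) for text in texts]

        def embed_query(self, text: str) -> List[float]:
            return self._embed(text)
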
libs/core/langchain_core/env.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import platform
from functools import lru_cache


@lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
    """Get information about the LangChain runtime environment."""
    # Lazy import to avoid circular imports
    from langchain_core import __version__

    return {
        "library_version": __version__,
        "library": "langchain",
        "platform": platform.platform(),
        "runtime": "python",
        "runtime_version": platform.python_version(),
    }
libs/core/langchain_core/example_selectors/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Logic for selecting examples to include in prompts."""
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.example_selectors.length_based import (
    LengthBasedExampleSelector,
)
from langchain_core.example_selectors.semantic_similarity import (
    MaxMarginalRelevanceExampleSelector,
    SemanticSimilarityExampleSelector,
    sorted_values,
)

__all__ = [
    "BaseExampleSelector",
    "LengthBasedExampleSelector",
    "MaxMarginalRelevanceExampleSelector",
    "SemanticSimilarityExampleSelector",
    "sorted_values",
]
libs/core/langchain_core/example_selectors/base.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseExampleSelector(ABC):
    """Interface for selecting examples to include in prompts."""

    @abstractmethod
    def add_example(self, example: Dict[str, str]) -> Any:
        """Add new example to store for a key."""

    @abstractmethod
    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the inputs."""
libs/core/langchain_core/example_selectors/length_based.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""Select examples based on length."""
import re
from typing import Callable, Dict, List

from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator


def _get_length_based(text: str) -> int:
    return len(re.split("\n| ", text))


class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
    """Select examples based on length."""

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    get_text_length: Callable[[str], int] = _get_length_based
    """Function to measure prompt length. Defaults to word count."""

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    example_text_lengths: List[int] = []  #: :meta private:

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)
        string_example = self.example_prompt.format(**example)
        self.example_text_lengths.append(self.get_text_length(string_example))

    @validator("example_text_lengths", always=True)
    def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
        """Calculate text lengths if they don't exist."""
        # Check if text lengths were passed in
        if v:
            return v
        # If they were not, calculate them
        example_prompt = values["example_prompt"]
        get_text_length = values["get_text_length"]
        string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
        return [get_text_length(eg) for eg in string_examples]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the input lengths."""
        inputs = " ".join(input_variables.values())
        remaining_length = self.max_length - self.get_text_length(inputs)
        i = 0
        examples = []
        while remaining_length > 0 and i < len(self.examples):
            new_length = remaining_length - self.example_text_lengths[i]
            if new_length < 0:
                break
            else:
                examples.append(self.examples[i])
                remaining_length = new_length
            i += 1
        return examples
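
Selection is driven by the example_prompt used to format each example and by max_length, which is measured with get_text_length (word count by default). A usage sketch (the values and the PromptTemplate construction are illustrative assumptions, not part of this changeset):

    from langchain_core.example_selectors import LengthBasedExampleSelector
    from langchain_core.prompts.prompt import PromptTemplate

    example_prompt = PromptTemplate(
        input_variables=["input", "output"],
        template="Input: {input}\nOutput: {output}",
    )
    selector = LengthBasedExampleSelector(
        examples=[
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
        ],
        example_prompt=example_prompt,
        max_length=10,
    )
    # Examples are appended in order until the word budget is exhausted.
    print(selector.select_examples({"input": "big"}))
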
libs/core/langchain_core/example_selectors/semantic_similarity.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""Example selector that selects examples based on SemanticSimilarity."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
    from langchain_core.embeddings import Embeddings


def sorted_values(values: Dict[str, str]) -> List[Any]:
    """Return a list of values in dict sorted by key."""
    return [values[val] for val in sorted(values)]


class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
    """Example selector that selects examples based on SemanticSimilarity."""

    vectorstore: VectorStore
    """VectorStore that contains information about examples."""
    k: int = 4
    """Number of examples to select."""
    example_keys: Optional[List[str]] = None
    """Optional keys to filter examples to."""
    input_keys: Optional[List[str]] = None
    """Optional keys to filter input to. If provided, the search is based on
    the input variables instead of all variables."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_example(self, example: Dict[str, str]) -> str:
        """Add new example to vectorstore."""
        if self.input_keys:
            string_example = " ".join(
                sorted_values({key: example[key] for key in self.input_keys})
            )
        else:
            string_example = " ".join(sorted_values(example))
        ids = self.vectorstore.add_texts([string_example], metadatas=[example])
        return ids[0]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on semantic similarity."""
        # Get the docs with the highest similarity.
        if self.input_keys:
            input_variables = {key: input_variables[key] for key in self.input_keys}
        query = " ".join(sorted_values(input_variables))
        example_docs = self.vectorstore.similarity_search(query, k=self.k)
        # Get the examples from the metadata.
        # This assumes that examples are stored in metadata.
        examples = [dict(e.metadata) for e in example_docs]
        # If example keys are provided, filter examples to those keys.
        if self.example_keys:
            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
        return examples

    @classmethod
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
        **vectorstore_cls_kwargs: Any,
    ) -> SemanticSimilarityExampleSelector:
        """Create k-shot example selector using example list and embeddings.

        Reshuffles examples dynamically based on query similarity.

        Args:
            examples: List of examples to use in the prompt.
            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
            k: Number of examples to select
            input_keys: If provided, the search is based on the input variables
                instead of all variables.
            vectorstore_cls_kwargs: optional kwargs containing url for vector store

        Returns:
            The ExampleSelector instantiated, backed by a vector store.
        """
        if input_keys:
            string_examples = [
                " ".join(sorted_values({k: eg[k] for k in input_keys}))
                for eg in examples
            ]
        else:
            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
        vectorstore = vectorstore_cls.from_texts(
            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
        )
        return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)


class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
    """ExampleSelector that selects examples based on Max Marginal Relevance.

    This was shown to improve performance in this paper:
    https://arxiv.org/pdf/2211.13892.pdf
    """

    fetch_k: int = 20
    """Number of examples to fetch to rerank."""

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on semantic similarity."""
        # Get the docs with the highest similarity.
        if self.input_keys:
            input_variables = {key: input_variables[key] for key in self.input_keys}
        query = " ".join(sorted_values(input_variables))
        example_docs = self.vectorstore.max_marginal_relevance_search(
            query, k=self.k, fetch_k=self.fetch_k
        )
        # Get the examples from the metadata.
        # This assumes that examples are stored in metadata.
        examples = [dict(e.metadata) for e in example_docs]
        # If example keys are provided, filter examples to those keys.
        if self.example_keys:
            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
        return examples

    @classmethod
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
        fetch_k: int = 20,
        **vectorstore_cls_kwargs: Any,
    ) -> MaxMarginalRelevanceExampleSelector:
        """Create k-shot example selector using example list and embeddings.

        Reshuffles examples dynamically based on query similarity.

        Args:
            examples: List of examples to use in the prompt.
            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
            k: Number of examples to select
            input_keys: If provided, the search is based on the input variables
                instead of all variables.
            vectorstore_cls_kwargs: optional kwargs containing url for vector store

        Returns:
            The ExampleSelector instantiated, backed by a vector store.
        """
        if input_keys:
            string_examples = [
                " ".join(sorted_values({k: eg[k] for k in input_keys}))
                for eg in examples
            ]
        else:
            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
        vectorstore = vectorstore_cls.from_texts(
            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
        )
        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
libs/core/langchain_core/exceptions.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import Any, Optional


class LangChainException(Exception):
    """General LangChain exception."""


class TracerException(LangChainException):
    """Base class for exceptions in tracers module."""


class OutputParserException(ValueError, LangChainException):
    """Exception that output parsers should raise to signify a parsing error.

    This exists to differentiate parsing errors from other code or execution errors
    that also may arise inside the output parser. OutputParserExceptions will be
    available to catch and handle in ways to fix the parsing error, while other
    errors will be raised.

    Args:
        error: The error that's being re-raised or an error message.
        observation: String explanation of the error which can be passed to a
            model to try and remediate the issue.
        llm_output: String model output that caused the error.
        send_to_llm: Whether to send the observation and llm_output back to an Agent
            after an OutputParserException has been raised. This gives the underlying
            model driving the agent the context that the previous output was improperly
            structured, in the hopes that it will update the output to the correct
            format.
    """

    def __init__(
        self,
        error: Any,
        observation: Optional[str] = None,
        llm_output: Optional[str] = None,
        send_to_llm: bool = False,
    ):
        super(OutputParserException, self).__init__(error)
        if send_to_llm:
            if observation is None or llm_output is None:
                raise ValueError(
                    "Arguments 'observation' & 'llm_output'"
                    " are required if 'send_to_llm' is True"
                )
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm
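
OutputParserException is intended to be raised from inside an output parser so that callers (for example agents) can catch it and retry. A sketch of the intended pattern (the parser function itself is hypothetical, not part of this changeset):

    import json
    from typing import Any

    from langchain_core.exceptions import OutputParserException


    def parse_json_output(text: str) -> Any:
        """Illustrative parser that converts model output to JSON or raises."""
        try:
            return json.loads(text)
        except json.JSONDecodeError as exc:
            # send_to_llm=True requires both observation and llm_output.
            raise OutputParserException(
                exc,
                observation="Output was not valid JSON; respond with JSON only.",
                llm_output=text,
                send_to_llm=True,
            )
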
libs/core/langchain_core/globals/__init__.py (new file, 197 lines)
@@ -0,0 +1,197 @@
# flake8: noqa
"""Global values and configuration that apply to all of LangChain."""
import warnings
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from langchain_core.caches import BaseCache


# DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below,
# or else your code may behave unexpectedly with other uses of these global settings:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
_verbose: bool = False
_debug: bool = False
_llm_cache: Optional["BaseCache"] = None


def set_verbose(value: bool) -> None:
    """Set a new value for the `verbose` global setting."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing verbose from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.verbose` is no longer supported, and once all users
            # have migrated to using `set_verbose()` here.
            langchain.verbose = value
    except ImportError:
        pass

    global _verbose
    _verbose = value


def get_verbose() -> bool:
    """Get the value of the `verbose` global setting."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing verbose from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.verbose` is no longer supported, and once all users
            # have migrated to using `set_verbose()` here.
            #
            # In the meantime, the `verbose` setting is considered True if either the old
            # or the new value are True. This accommodates users who haven't migrated
            # to using `set_verbose()` yet. Those users are getting deprecation warnings
            # directing them to use `set_verbose()` when they import `langchain.verbose`.
            old_verbose = langchain.verbose
    except ImportError:
        old_verbose = False

    global _verbose
    return _verbose or old_verbose


def set_debug(value: bool) -> None:
    """Set a new value for the `debug` global setting."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Importing debug from langchain_core root module is no longer supported",
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.debug` is no longer supported, and once all users
            # have migrated to using `set_debug()` here.
            langchain.debug = value
    except ImportError:
        pass

    global _debug
    _debug = value


def get_debug() -> bool:
    """Get the value of the `debug` global setting."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Importing debug from langchain_core root module is no longer supported",
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.debug` is no longer supported, and once all users
            # have migrated to using `set_debug()` here.
            #
            # In the meantime, the `debug` setting is considered True if either the old
            # or the new value are True. This accommodates users who haven't migrated
            # to using `set_debug()` yet. Those users are getting deprecation warnings
            # directing them to use `set_debug()` when they import `langchain.debug`.
            old_debug = langchain.debug
    except ImportError:
        old_debug = False

    global _debug
    return _debug or old_debug


def set_llm_cache(value: Optional["BaseCache"]) -> None:
    """Set a new LLM cache, overwriting the previous value, if any."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing llm_cache from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.llm_cache` is no longer supported, and
            # once all users have migrated to using `set_llm_cache()` here.
            langchain.llm_cache = value
    except ImportError:
        pass

    global _llm_cache
    _llm_cache = value


def get_llm_cache() -> "BaseCache":
    """Get the value of the `llm_cache` global setting."""
    try:
        import langchain  # type: ignore[import]

        # We're about to run some deprecated code, don't report warnings from it.
        # The user called the correct (non-deprecated) code path and shouldn't get warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=(
                    "Importing llm_cache from langchain_core root module is no longer supported"
                ),
            )
            # N.B.: This is a workaround for an unfortunate quirk of Python's
            # module-level `__getattr__()` implementation:
            # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
            #
            # Remove it once `langchain.llm_cache` is no longer supported, and
            # once all users have migrated to using `set_llm_cache()` here.
            #
            # In the meantime, the `llm_cache` setting returns whichever of
            # its two backing sources is truthy (not `None` and non-empty),
            # or the old value if both are falsy. This accommodates users
            # who haven't migrated to using `set_llm_cache()` yet.
            # Those users are getting deprecation warnings directing them
            # to use `set_llm_cache()` when they import `langchain.llm_cache`.
            old_llm_cache = langchain.llm_cache
    except ImportError:
        old_llm_cache = None

    global _llm_cache
    return _llm_cache or old_llm_cache
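
In practice these setters are called once at application startup, and the getters merge the new globals with the deprecated langchain module-level attributes. A short usage sketch (not part of this changeset):

    from langchain_core.globals import get_debug, get_verbose, set_debug, set_verbose

    # Turn the global flags on for all LangChain components...
    set_debug(True)
    set_verbose(True)
    assert get_debug() and get_verbose()

    # ...and switch them back off when finished.
    set_debug(False)
    set_verbose(False)
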
libs/core/langchain_core/language_models/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from langchain_core.language_models.base import (
    BaseLanguageModel,
    LanguageModelInput,
    get_tokenizer,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.llms import LLM, BaseLLM

__all__ = [
    "BaseLanguageModel",
    "BaseChatModel",
    "SimpleChatModel",
    "BaseLLM",
    "LLM",
    "LanguageModelInput",
    "get_tokenizer",
]
293
libs/core/langchain_core/language_models/base.py
Normal file
293
libs/core/langchain_core/language_models/base.py
Normal file
@@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
@lru_cache(maxsize=None) # Cache the tokenizer
|
||||
def get_tokenizer() -> Any:
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast # type: ignore[import]
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_token_ids. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-2 tokenizer instance
|
||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
"""Encode the text into token IDs."""
|
||||
# get the cached tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# tokenize the text using the GPT-2 tokenizer
|
||||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
|
||||
LanguageModelOutput = TypeVar("LanguageModelOutput")
|
||||
|
||||
|
||||
class BaseLanguageModel(
|
||||
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
|
||||
):
|
||||
"""Abstract base class for interfacing with language models.
|
||||
|
||||
All language model wrappers inherit from BaseLanguageModel.
|
||||
|
||||
Exposes three main methods:
|
||||
- generate_prompt: generate language model outputs for a sequence of prompt
|
||||
values. A prompt value is a model input that can be converted to any language
|
||||
model input format (string or messages).
|
||||
- predict: pass in a single string to a language model and return a string
|
||||
prediction.
|
||||
- predict_messages: pass in a sequence of BaseMessages (corresponding to a single
|
||||
model call) to a language model and return a BaseMessage prediction.
|
||||
|
||||
Each of these has an equivalent asynchronous method.
|
||||
"""
|
||||
|
||||
@property
|
||||
def InputType(self) -> TypeAlias:
|
||||
"""Get the input type for this runnable."""
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
StringPromptValue,
|
||||
)
|
||||
|
||||
# This is a version of LanguageModelInput which replaces the abstract
|
||||
# base class BaseMessage with a union of its subclasses, which makes
|
||||
# for a much better schema.
|
||||
return Union[
|
||||
str,
|
||||
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||
List[AnyMessage],
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Pass a sequence of prompts to the model and return model generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of PromptValues. A PromptValue is an object that can be
|
||||
converted to match the format of any language model (string for pure
|
||||
text generation models and BaseMessages for chat models).
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Asynchronously pass a sequence of prompts and return model generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of PromptValues. A PromptValue is an object that can be
|
||||
converted to match the format of any language model (string for pure
|
||||
text generation models and BaseMessages for chat models).
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
"""Pass a single string input to the model and return a string prediction.
|
||||
|
||||
Use this method when passing in raw text. If you want to pass in specific
|
||||
types of chat messages, use predict_messages.
|
||||
|
||||
Args:
|
||||
text: String input to pass to the model.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
Top model prediction as a string.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Pass a message sequence to the model and return a message prediction.
|
||||
|
||||
Use this method when passing in chat messages. If you want to pass in raw text,
|
||||
use predict.
|
||||
|
||||
Args:
|
||||
messages: A sequence of chat messages corresponding to a single model input.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
Top model prediction as a message.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
"""Asynchronously pass a string to the model and return a string prediction.
|
||||
|
||||
Use this method when calling pure text generation models and only the top
|
||||
candidate generation is needed.
|
||||
|
||||
Args:
|
||||
text: String input to pass to the model.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
Top model prediction as a string.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Asynchronously pass messages to the model and return a message prediction.
|
||||
|
||||
Use this method when calling chat models and only the top
|
||||
candidate generation is needed.
|
||||
|
||||
Args:
|
||||
messages: A sequence of chat messages corresponding to a single model input.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
Top model prediction as a message.
|
||||
"""
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
"""Return the ordered ids of the tokens in a text.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
A list of ids corresponding to the tokens in the text, in order they occur
|
||||
in the text.
|
||||
"""
|
||||
return _get_token_ids_default_method(text)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
return len(self.get_token_ids(text))
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
||||
|
||||
@classmethod
|
||||
def _all_required_field_names(cls) -> Set:
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
||||
Use get_pydantic_field_names.
|
||||
"""
|
||||
return get_pydantic_field_names(cls)
|
||||
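A minimal sketch (not part of the diff above) of what the default token-counting path in base.py does. It assumes the optional `transformers` package is installed, since `get_tokenizer` downloads and caches a GPT-2 tokenizer on first use.

from langchain_core.language_models.base import get_tokenizer

tokenizer = get_tokenizer()            # cached GPT2TokenizerFast; requires `transformers`
ids = tokenizer.encode("hello world")  # same ids _get_token_ids_default_method returns
print(len(ids))                        # what get_num_tokens("hello world") reports by default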
738
libs/core/langchain_core/language_models/chat_models.py
Normal file
738
libs/core/langchain_core/language_models/chat_models.py
Normal file
@@ -0,0 +1,738 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
from langchain_core.globals import get_verbose
|
||||
|
||||
return get_verbose()
|
||||
|
||||
|
||||
def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
for chunk in stream:
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
|
||||
async def _agenerate_from_stream(
|
||||
stream: AsyncIterator[ChatGenerationChunk],
|
||||
) -> ChatResult:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
async for chunk in stream:
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Base class for Chat models."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
"""Whether to cache the response."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Callback manager to add to the run trace."""
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
"""Get the output type for this runnable."""
|
||||
return AnyMessage
|
||||
|
||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(input, PromptValue):
|
||||
return input
|
||||
elif isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
elif isinstance(input, list):
|
||||
return ChatPromptValue(messages=input)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
return cast(
|
||||
ChatGeneration,
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
).generations[0][0],
|
||||
).message
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if type(self)._stream == BaseChatModel._stream:
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
messages = self._convert_input(input).to_messages()
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
config.get("metadata"),
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
[messages],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
for chunk in self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
except BaseException as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(generations=[[generation]]),
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[BaseMessageChunk]:
|
||||
if type(self)._astream == BaseChatModel._astream:
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
messages = self._convert_input(input).to_messages()
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
config.get("metadata"),
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
[messages],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
except BaseException as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_llm_end(
|
||||
LLMResult(generations=[[generation]]),
|
||||
)
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
||||
if self.is_lc_serializable():
|
||||
params = {**kwargs, **{"stop": stop}}
|
||||
param_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
llm_string = dumps(self)
|
||||
return llm_string + "---" + param_string
|
||||
else:
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
params = {**params, **kwargs}
|
||||
return str(sorted([(k, v) for k, v in params.items()]))
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
run_managers = callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
messages,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=run_name,
|
||||
)
|
||||
results = []
|
||||
for i, m in enumerate(messages):
|
||||
try:
|
||||
results.append(
|
||||
self._generate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except BaseException as e:
|
||||
if run_managers:
|
||||
run_managers[i].on_llm_error(e)
|
||||
raise e
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
if run_managers:
|
||||
run_infos = []
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
manager.on_llm_end(flattened_output)
|
||||
run_infos.append(RunInfo(run_id=manager.run_id))
|
||||
output.run = run_infos
|
||||
return output
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
run_managers = await callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
messages,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=run_name,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate_with_cache(
|
||||
m,
|
||||
stop=stop,
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
for i, m in enumerate(messages)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
exceptions = []
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, BaseException):
|
||||
if run_managers:
|
||||
await run_managers[i].on_llm_error(res)
|
||||
exceptions.append(res)
|
||||
if exceptions:
|
||||
if run_managers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[res.generations], llm_output=res.llm_output
|
||||
)
|
||||
)
|
||||
for run_manager, res in zip(run_managers, results)
|
||||
if not isinstance(res, Exception)
|
||||
]
|
||||
)
|
||||
raise exceptions[0]
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
for run_manager, flattened_output in zip(
|
||||
run_managers, flattened_outputs
|
||||
)
|
||||
]
|
||||
)
|
||||
if run_managers:
|
||||
output.run = [
|
||||
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
|
||||
]
|
||||
return output
|
||||
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
|
||||
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
return await self.agenerate(
|
||||
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
def _generate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
llm_cache = get_llm_cache()
|
||||
if llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
if new_arg_supported:
|
||||
return self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
return self._generate(messages, stop=stop, **kwargs)
|
||||
else:
|
||||
llm_string = self._get_llm_string(stop=stop, **kwargs)
|
||||
prompt = dumps(messages)
|
||||
cache_val = llm_cache.lookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
return ChatResult(generations=cache_val)
|
||||
else:
|
||||
if new_arg_supported:
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
result = self._generate(messages, stop=stop, **kwargs)
|
||||
llm_cache.update(prompt, llm_string, result.generations)
|
||||
return result
|
||||
|
||||
async def _agenerate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
llm_cache = get_llm_cache()
|
||||
if llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
if new_arg_supported:
|
||||
return await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
return await self._agenerate(messages, stop=stop, **kwargs)
|
||||
else:
|
||||
llm_string = self._get_llm_string(stop=stop, **kwargs)
|
||||
prompt = dumps(messages)
|
||||
cache_val = llm_cache.lookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
return ChatResult(generations=cache_val)
|
||||
else:
|
||||
if new_arg_supported:
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
else:
|
||||
result = await self._agenerate(messages, stop=stop, **kwargs)
|
||||
llm_cache.update(prompt, llm_string, result.generations)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, **kwargs), messages, stop, run_manager
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
generation = self.generate(
|
||||
[messages], stop=stop, callbacks=callbacks, **kwargs
|
||||
).generations[0][0]
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
async def _call_async(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
result = await self.agenerate(
|
||||
[messages], stop=stop, callbacks=callbacks, **kwargs
|
||||
)
|
||||
generation = result.generations[0][0]
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
def call_as_llm(
|
||||
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
return self.predict(message, stop=stop, **kwargs)
|
||||
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
return self(messages, stop=_stop, **kwargs)
|
||||
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
result = await self._call_async(
|
||||
[HumanMessage(content=text)], stop=_stop, **kwargs
|
||||
)
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
return await self._call_async(messages, stop=_stop, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
"""Simple Chat Model."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Simpler interface."""
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
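A usage sketch (not part of the diff): `EchoChatModel` is a hypothetical subclass used only to show the minimal surface `SimpleChatModel` requires, namely `_call` and `_llm_type`; everything else (callbacks, caching, invoke/stream plumbing) comes from `BaseChatModel` above.

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage, HumanMessage


class EchoChatModel(SimpleChatModel):
    """Toy model that echoes the last message back."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return str(messages[-1].content)

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"


model = EchoChatModel()
print(model.invoke([HumanMessage(content="hi")]))  # -> AIMessage(content="hi")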
1074
libs/core/langchain_core/language_models/llms.py
Normal file
1074
libs/core/langchain_core/language_models/llms.py
Normal file
File diff suppressed because it is too large
6
libs/core/langchain_core/load/__init__.py
Normal file
6
libs/core/langchain_core/load/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""Serialization and deserialization."""
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.load import load, loads
from langchain_core.load.serializable import Serializable

__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"]
26
libs/core/langchain_core/load/dump.py
Normal file
26
libs/core/langchain_core/load/dump.py
Normal file
@@ -0,0 +1,26 @@
import json
from typing import Any, Dict

from langchain_core.load.serializable import Serializable, to_json_not_implemented


def default(obj: Any) -> Any:
    """Return a default value for a Serializable object or
    a SerializedNotImplemented object."""
    if isinstance(obj, Serializable):
        return obj.to_json()
    else:
        return to_json_not_implemented(obj)


def dumps(obj: Any, *, pretty: bool = False) -> str:
    """Return a json string representation of an object."""
    if pretty:
        return json.dumps(obj, default=default, indent=2)
    else:
        return json.dumps(obj, default=default)


def dumpd(obj: Any) -> Dict[str, Any]:
    """Return a json dict representation of an object."""
    return json.loads(dumps(obj))
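A short sketch (not part of the diff) of how `dumps` and `dumpd` behave on a serializable object such as a message:

from langchain_core.load.dump import dumpd, dumps
from langchain_core.messages import HumanMessage

msg = HumanMessage(content="hello")
print(dumps(msg, pretty=True))  # {"lc": 1, "type": "constructor", "id": [...], "kwargs": {"content": "hello"}}
as_dict = dumpd(msg)            # the same structure, but as a plain Python dict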
130
libs/core/langchain_core/load/load.py
Normal file
130
libs/core/langchain_core/load/load.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
|
||||
DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
|
||||
|
||||
|
||||
class Reviver:
|
||||
"""Reviver for JSON objects."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.secrets_map = secrets_map or dict()
|
||||
# By default only support langchain, but user can pass in additional namespaces
|
||||
self.valid_namespaces = (
|
||||
[*DEFAULT_NAMESPACES, *valid_namespaces]
|
||||
if valid_namespaces
|
||||
else DEFAULT_NAMESPACES
|
||||
)
|
||||
|
||||
def __call__(self, value: Dict[str, Any]) -> Any:
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "secret"
|
||||
and value.get("id", None) is not None
|
||||
):
|
||||
[key] = value["id"]
|
||||
if key in self.secrets_map:
|
||||
return self.secrets_map[key]
|
||||
else:
|
||||
if key in os.environ and os.environ[key]:
|
||||
return os.environ[key]
|
||||
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
|
||||
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "not_implemented"
|
||||
and value.get("id", None) is not None
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Trying to load an object that doesn't implement "
|
||||
f"serialization: {value}"
|
||||
)
|
||||
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "constructor"
|
||||
and value.get("id", None) is not None
|
||||
):
|
||||
[*namespace, name] = value["id"]
|
||||
|
||||
if namespace[0] not in self.valid_namespaces:
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
# The root namespace "langchain" is not a valid identifier.
|
||||
if len(namespace) == 1 and namespace[0] == "langchain":
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
mod = importlib.import_module(".".join(namespace))
|
||||
cls = getattr(mod, name)
|
||||
|
||||
# The class must be a subclass of Serializable.
|
||||
if not issubclass(cls, Serializable):
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
# We don't need to recurse on kwargs
|
||||
# as json.loads will do that for us.
|
||||
kwargs = value.get("kwargs", dict())
|
||||
return cls(**kwargs)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def loads(
|
||||
text: str,
|
||||
*,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON string.
|
||||
Equivalent to `load(json.loads(text))`.
|
||||
|
||||
Args:
|
||||
text: The string to load.
|
||||
secrets_map: A map of secrets to load.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
"""
|
||||
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
|
||||
|
||||
|
||||
def load(
|
||||
obj: Any,
|
||||
*,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON object. Use this if you already
|
||||
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
||||
|
||||
Args:
|
||||
obj: The object to load.
|
||||
secrets_map: A map of secrets to load.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
"""
|
||||
reviver = Reviver(secrets_map, valid_namespaces)
|
||||
|
||||
def _load(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
# Need to revive leaf nodes before reviving this node
|
||||
loaded_obj = {k: _load(v) for k, v in obj.items()}
|
||||
return reviver(loaded_obj)
|
||||
if isinstance(obj, list):
|
||||
return [_load(o) for o in obj]
|
||||
return obj
|
||||
|
||||
return _load(obj)
|
||||
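A round-trip sketch (not part of the diff). The secret value shown is a placeholder; only objects in the allowed namespaces are revived, and secrets serialized as {"lc": 1, "type": "secret", "id": [...]} are resolved from `secrets_map` first, then from the environment.

from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.messages import AIMessage

serialized = dumps(AIMessage(content="hi"))
revived = loads(serialized)  # -> AIMessage(content="hi")

# Supplying a secrets_map lets the Reviver fill in secret constructor arguments.
revived = loads(serialized, secrets_map={"OPENAI_API_KEY": "sk-placeholder"})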
207
libs/core/langchain_core/load/serializable.py
Normal file
207
libs/core/langchain_core/load/serializable.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
"""Base class for serialized objects."""
|
||||
|
||||
lc: int
|
||||
id: List[str]
|
||||
|
||||
|
||||
class SerializedConstructor(BaseSerialized):
|
||||
"""Serialized constructor."""
|
||||
|
||||
type: Literal["constructor"]
|
||||
kwargs: Dict[str, Any]
|
||||
|
||||
|
||||
class SerializedSecret(BaseSerialized):
|
||||
"""Serialized secret."""
|
||||
|
||||
type: Literal["secret"]
|
||||
|
||||
|
||||
class SerializedNotImplemented(BaseSerialized):
|
||||
"""Serialized not implemented."""
|
||||
|
||||
type: Literal["not_implemented"]
|
||||
repr: Optional[str]
|
||||
|
||||
|
||||
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
"""Serializable base class."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Is this class serializable?"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
||||
namespace is ["langchain", "llms", "openai"]
|
||||
"""
|
||||
return cls.__module__.split(".")
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
"""A map of constructor argument names to secret ids.
|
||||
|
||||
For example,
|
||||
{"openai_api_key": "OPENAI_API_KEY"}
|
||||
"""
|
||||
return dict()
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
"""List of attribute names that should be included in the serialized kwargs.
|
||||
|
||||
These attributes must be accepted by the constructor.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def lc_id(cls) -> List[str]:
|
||||
"""A unique identifier for this class for serialization purposes.
|
||||
|
||||
The unique identifier is a list of strings that describes the path
|
||||
to the object.
|
||||
"""
|
||||
return [*cls.get_lc_namespace(), cls.__name__]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
_lc_kwargs = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._lc_kwargs = kwargs
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
if not self.is_lc_serializable():
|
||||
return self.to_json_not_implemented()
|
||||
|
||||
secrets = dict()
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {
|
||||
k: getattr(self, k, v)
|
||||
for k, v in self._lc_kwargs.items()
|
||||
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
|
||||
}
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
for cls in [None, *self.__class__.mro()]:
|
||||
# Once we get to Serializable, we're done
|
||||
if cls is Serializable:
|
||||
break
|
||||
|
||||
if cls:
|
||||
deprecated_attributes = [
|
||||
"lc_namespace",
|
||||
"lc_serializable",
|
||||
]
|
||||
|
||||
for attr in deprecated_attributes:
|
||||
if hasattr(cls, attr):
|
||||
raise ValueError(
|
||||
f"Class {self.__class__} has a deprecated "
|
||||
f"attribute {attr}. Please use the corresponding "
|
||||
f"classmethod instead."
|
||||
)
|
||||
|
||||
# Get a reference to self bound to each class in the MRO
|
||||
this = cast(Serializable, self if cls is None else super(cls, self))
|
||||
|
||||
secrets.update(this.lc_secrets)
|
||||
lc_kwargs.update(this.lc_attributes)
|
||||
|
||||
# include all secrets, even if not specified in kwargs
|
||||
# as these secrets may be passed as an environment variable instead
|
||||
for key in secrets.keys():
|
||||
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
|
||||
if secret_value is not None:
|
||||
lc_kwargs.update({key: secret_value})
|
||||
|
||||
return {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": self.lc_id(),
|
||||
"kwargs": lc_kwargs
|
||||
if not secrets
|
||||
else _replace_secrets(lc_kwargs, secrets),
|
||||
}
|
||||
|
||||
def to_json_not_implemented(self) -> SerializedNotImplemented:
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
root: Dict[Any, Any], secrets_map: Dict[str, str]
|
||||
) -> Dict[Any, Any]:
|
||||
result = root.copy()
|
||||
for path, secret_id in secrets_map.items():
|
||||
[*parts, last] = path.split(".")
|
||||
current = result
|
||||
for part in parts:
|
||||
if part not in current:
|
||||
break
|
||||
current[part] = current[part].copy()
|
||||
current = current[part]
|
||||
if last in current:
|
||||
current[last] = {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [secret_id],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
"""Serialize a "not implemented" object.
|
||||
|
||||
Args:
|
||||
obj: object to serialize
|
||||
|
||||
Returns:
|
||||
SerializedNotImplemented
|
||||
"""
|
||||
_id: List[str] = []
|
||||
try:
|
||||
if hasattr(obj, "__name__"):
|
||||
_id = [*obj.__module__.split("."), obj.__name__]
|
||||
elif hasattr(obj, "__class__"):
|
||||
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result: SerializedNotImplemented = {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": _id,
|
||||
"repr": None,
|
||||
}
|
||||
try:
|
||||
result["repr"] = repr(obj)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
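A sketch (not part of the diff) of a small `Serializable` subclass; `MyComponent` and `MY_API_KEY` are hypothetical names used only to show how `lc_secrets` feeds `_replace_secrets` during `to_json`.

from typing import Dict

from langchain_core.load.serializable import Serializable


class MyComponent(Serializable):
    api_key: str
    temperature: float = 0.7

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True  # opt in; the base class defaults to False

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "MY_API_KEY"}


# api_key is emitted as {"lc": 1, "type": "secret", "id": ["MY_API_KEY"]}
print(MyComponent(api_key="not-a-real-key", temperature=0.1).to_json())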
59
libs/core/langchain_core/memory.py
Normal file
59
libs/core/langchain_core/memory.py
Normal file
@@ -0,0 +1,59 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List

from langchain_core.load.serializable import Serializable


class BaseMemory(Serializable, ABC):
    """Abstract base class for memory in Chains.

    Memory refers to state in Chains. Memory can be used to store information about
    past executions of a Chain and inject that information into the inputs of
    future executions of the Chain. For example, for conversational Chains Memory
    can be used to store conversations and automatically add them to future model
    prompts so that the model has the necessary context to respond coherently to
    the latest input.

    Example:
        .. code-block:: python

            class SimpleMemory(BaseMemory):
                memories: Dict[str, Any] = dict()

                @property
                def memory_variables(self) -> List[str]:
                    return list(self.memories.keys())

                def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
                    return self.memories

                def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
                    pass

                def clear(self) -> None:
                    pass
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """The string keys this memory class will add to chain inputs."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this chain run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Clear memory contents."""
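A usage sketch (not part of the diff), assuming the `SimpleMemory` class from the docstring above has been defined as shown:

memory = SimpleMemory(memories={"user_name": "Ada"})
print(memory.memory_variables)           # ["user_name"]
print(memory.load_memory_variables({}))  # {"user_name": "Ada"}
memory.save_context({"input": "hi"}, {"output": "hello"})  # no-op in this toy memory
memory.clear()                                             # also a no-op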
122
libs/core/langchain_core/messages/__init__.py
Normal file
122
libs/core/langchain_core/messages/__init__.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
merge_content,
|
||||
message_to_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
|
||||
|
||||
AnyMessage = Union[
|
||||
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
|
||||
]
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
) -> str:
|
||||
"""Convert sequence of Messages to strings and concatenate them into one string.
|
||||
|
||||
Args:
|
||||
messages: Messages to be converted to strings.
|
||||
human_prefix: The prefix to prepend to contents of HumanMessages.
|
||||
ai_prefix: The prefix to prepend to contents of AIMessages.
|
||||
|
||||
Returns:
|
||||
A single string concatenation of all input messages.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core import AIMessage, HumanMessage
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hi, how are you?"),
|
||||
AIMessage(content="Good, how are you?"),
|
||||
]
|
||||
get_buffer_string(messages)
|
||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||
"""
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
role = human_prefix
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
elif isinstance(m, FunctionMessage):
|
||||
role = "Function"
|
||||
elif isinstance(m, ToolMessage):
|
||||
role = "Tool"
|
||||
elif isinstance(m, ChatMessage):
|
||||
role = m.role
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {m}")
|
||||
message = f"{role}: {m.content}"
|
||||
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
||||
message += f"{m.additional_kwargs['function_call']}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
elif _type == "ai":
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "chat":
|
||||
return ChatMessage(**message["data"])
|
||||
elif _type == "function":
|
||||
return FunctionMessage(**message["data"])
|
||||
elif _type == "tool":
|
||||
return ToolMessage(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected message type: {_type}")
|
||||
|
||||
|
||||
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages from dicts to Message objects.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as dicts) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"AnyMessage",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"get_buffer_string",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
"merge_content",
|
||||
]
|
||||
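A quick sketch (not part of the diff) of the helpers above:

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    get_buffer_string,
    messages_from_dict,
    messages_to_dict,
)

msgs = [HumanMessage(content="Hi, how are you?"), AIMessage(content="Good, how are you?")]
print(get_buffer_string(msgs))
# Human: Hi, how are you?
# AI: Good, how are you?
assert messages_from_dict(messages_to_dict(msgs)) == msgs  # dict round-trip preserves messages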
47
libs/core/langchain_core/messages/ai.py
Normal file
47
libs/core/langchain_core/messages/ai.py
Normal file
@@ -0,0 +1,47 @@
from typing import Any, Literal

from langchain_core.messages.base import (
    BaseMessage,
    BaseMessageChunk,
    merge_content,
)


class AIMessage(BaseMessage):
    """A Message from an AI."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    type: Literal["ai"] = "ai"


AIMessage.update_forward_refs()


class AIMessageChunk(AIMessage, BaseMessageChunk):
    """A Message chunk from an AI."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore[assignment] # noqa: E501

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, AIMessageChunk):
            if self.example != other.example:
                raise ValueError(
                    "Cannot concatenate AIMessageChunks with different example values."
                )

            return self.__class__(
                example=self.example,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)
libs/core/langchain_core/messages/base.py
Normal file
131
libs/core/langchain_core/messages/base.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
class BaseMessage(Serializable):
|
||||
"""The base abstract Message class.
|
||||
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: Union[str, List[Union[str, Dict]]]
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Any additional information."""
|
||||
|
||||
type: str
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
second_content: Union[str, List[Union[str, Dict]]],
|
||||
) -> Union[str, List[Union[str, Dict]]]:
|
||||
# If first chunk is a string
|
||||
if isinstance(first_content, str):
|
||||
# If the second chunk is also a string, then merge them naively
|
||||
if isinstance(second_content, str):
|
||||
return first_content + second_content
|
||||
# If the second chunk is a list, add the first chunk to the start of the list
|
||||
else:
|
||||
return_list: List[Union[str, Dict]] = [first_content]
|
||||
return return_list + second_content
|
||||
# If both are lists, merge them naively
|
||||
elif isinstance(second_content, List):
|
||||
return first_content + second_content
|
||||
# If the first content is a list, and the second content is a string
|
||||
else:
|
||||
# If the last element of the first content is a string
|
||||
# Add the second content to the last element
|
||||
if isinstance(first_content[-1], str):
|
||||
return first_content[:-1] + [first_content[-1] + second_content]
|
||||
else:
|
||||
# Otherwise, add the second content as a new element of the list
|
||||
return first_content + [second_content]
|
||||
|
||||
|
||||
class BaseMessageChunk(BaseMessage):
|
||||
"""A Message chunk, which can be concatenated with other Message chunks."""
|
||||
|
||||
def _merge_kwargs_dict(
|
||||
self, left: Dict[str, Any], right: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
|
||||
merged = left.copy()
|
||||
for k, v in right.items():
|
||||
if k not in merged:
|
||||
merged[k] = v
|
||||
elif type(merged[k]) != type(v):
|
||||
raise ValueError(
|
||||
f'additional_kwargs["{k}"] already exists in this message,'
|
||||
" but with a different type."
|
||||
)
|
||||
elif isinstance(merged[k], str):
|
||||
merged[k] += v
|
||||
elif isinstance(merged[k], dict):
|
||||
merged[k] = self._merge_kwargs_dict(merged[k], v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Additional kwargs key {k} already exists in this message."
|
||||
)
|
||||
return merged
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, BaseMessageChunk):
|
||||
# If both are (subclasses of) BaseMessageChunk,
|
||||
# concat into a single BaseMessageChunk
|
||||
|
||||
return self.__class__(
|
||||
content=merge_content(self.content, other.content),
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
'unsupported operand type(s) for +: "'
|
||||
f"{self.__class__.__name__}"
|
||||
f'" and "{other.__class__.__name__}"'
|
||||
)
|
||||
|
||||
|
||||
def message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
"""Convert a sequence of Messages to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as BaseMessages) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages as dicts.
|
||||
"""
|
||||
return [message_to_dict(m) for m in messages]
|
||||
53
libs/core/langchain_core/messages/chat.py
Normal file
53
libs/core/langchain_core/messages/chat.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
merge_content,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
|
||||
|
||||
role: str
|
||||
"""The speaker / role of the Message."""
|
||||
|
||||
type: Literal["chat"] = "chat"
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
|
||||
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
"""A Chat Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ChatMessageChunk):
|
||||
if self.role != other.role:
|
||||
raise ValueError(
|
||||
"Cannot concatenate ChatMessageChunks with different roles."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
role=self.role,
|
||||
content=merge_content(self.content, other.content),
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
elif isinstance(other, BaseMessageChunk):
|
||||
return self.__class__(
|
||||
role=self.role,
|
||||
content=merge_content(self.content, other.content),
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
else:
|
||||
return super().__add__(other)
|
||||
45
libs/core/langchain_core/messages/function.py
Normal file
45
libs/core/langchain_core/messages/function.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
merge_content,
|
||||
)
|
||||
|
||||
|
||||
class FunctionMessage(BaseMessage):
|
||||
"""A Message for passing the result of executing a function back to a model."""
|
||||
|
||||
name: str
|
||||
"""The name of the function that was executed."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
|
||||
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
"""A Function Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, FunctionMessageChunk):
|
||||
if self.name != other.name:
|
||||
raise ValueError(
|
||||
"Cannot concatenate FunctionMessageChunks with different names."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
name=self.name,
|
||||
content=merge_content(self.content, other.content),
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
26
libs/core/langchain_core/messages/human.py
Normal file
26
libs/core/langchain_core/messages/human.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""A Message from a human."""
|
||||
|
||||
example: bool = False
|
||||
"""Whether this Message is being passed in to the model as part of an example
|
||||
conversation.
|
||||
"""
|
||||
|
||||
type: Literal["human"] = "human"
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
|
||||
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
"""A Human Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
|
||||
23
libs/core/langchain_core/messages/system.py
Normal file
23
libs/core/langchain_core/messages/system.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""A Message for priming AI behavior, usually passed in as the first of a sequence
|
||||
of input messages.
|
||||
"""
|
||||
|
||||
type: Literal["system"] = "system"
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
|
||||
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
"""A System Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
|
||||
45
libs/core/langchain_core/messages/tool.py
Normal file
45
libs/core/langchain_core/messages/tool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
merge_content,
|
||||
)
|
||||
|
||||
|
||||
class ToolMessage(BaseMessage):
|
||||
"""A Message for passing the result of executing a tool back to a model."""
|
||||
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
|
||||
|
||||
ToolMessage.update_forward_refs()
|
||||
|
||||
|
||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
"""A Tool Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ToolMessageChunk):
|
||||
if self.tool_call_id != other.tool_call_id:
|
||||
raise ValueError(
|
||||
"Cannot concatenate ToolMessageChunks with different names."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
tool_call_id=self.tool_call_id,
|
||||
content=merge_content(self.content, other.content),
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
29
libs/core/langchain_core/output_parsers/__init__.py
Normal file
29
libs/core/langchain_core/output_parsers/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from langchain_core.output_parsers.base import (
|
||||
BaseGenerationOutputParser,
|
||||
BaseLLMOutputParser,
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.transform import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseTransformOutputParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMOutputParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseOutputParser",
|
||||
"ListOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"StrOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
]
|
||||
301
libs/core/langchain_core/output_parsers/base.py
Normal file
301
libs/core/langchain_core/output_parsers/base.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseLLMOutputParser(Generic[T], ABC):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
|
||||
|
||||
class BaseGenerationOutputParser(
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call."""
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[T]:
|
||||
# even though mypy complains this isn't valid,
|
||||
# it is good enough for pydantic to build the schema from
|
||||
return T # type: ignore[misc]
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | BaseMessage,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
|
||||
class BaseOutputParser(
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class BooleanOutputParser(BaseOutputParser[bool]):
|
||||
true_val: str = "YES"
|
||||
false_val: str = "NO"
|
||||
|
||||
def parse(self, text: str) -> bool:
|
||||
cleaned_text = text.strip().upper()
|
||||
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
|
||||
raise OutputParserException(
|
||||
f"BooleanOutputParser expected output value to either be "
|
||||
f"{self.true_val} or {self.false_val} (case-insensitive). "
|
||||
f"Received {cleaned_text}."
|
||||
)
|
||||
return cleaned_text == self.true_val.upper()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "boolean_output_parser"
|
||||
""" # noqa: E501
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[T]:
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 1:
|
||||
return type_args[0]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||
"Override the OutputType property to specify the output type."
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | BaseMessage,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(
|
||||
[ChatGeneration(message=inner_input)]
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
else:
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of a language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, functools.partial(self.parse_result, partial=partial), result
|
||||
)
|
||||
|
||||
async def aparse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of a language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Parse the output of an LLM call with the input prompt for context.
|
||||
|
||||
The prompt is largely provided in the event the OutputParser wants
|
||||
to retry or fix the output in some way, and needs information from
|
||||
the prompt to do so.
|
||||
|
||||
Args:
|
||||
completion: String output of a language model.
|
||||
prompt: Input PromptValue.
|
||||
|
||||
Returns:
|
||||
Structured output
|
||||
"""
|
||||
return self.parse(completion)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
raise NotImplementedError(
|
||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
" This is required for serialization."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
try:
|
||||
output_parser_dict["_type"] = self._type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return output_parser_dict
|
||||
84
libs/core/langchain_core/output_parsers/list.py
Normal file
84
libs/core/langchain_core/output_parsers/list.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
|
||||
|
||||
class ListOutputParser(BaseOutputParser[List[str]]):
|
||||
"""Parse the output of an LLM call to a list."""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "list"
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "output_parsers", "list"]
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return (
|
||||
"Your response should be a list of comma separated values, "
|
||||
"eg: `foo, bar, baz`"
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
return text.strip().split(", ")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "comma-separated-list"
|
||||
|
||||
|
||||
class NumberedListOutputParser(ListOutputParser):
|
||||
"""Parse a numbered list."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return (
|
||||
"Your response should be a numbered list with each item on a new line. "
|
||||
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
pattern = r"\d+\.\s([^\n]+)"
|
||||
|
||||
# Extract the text of each item
|
||||
matches = re.findall(pattern, text)
|
||||
return matches
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "numbered-list"
|
||||
|
||||
|
||||
class MarkdownListOutputParser(ListOutputParser):
|
||||
"""Parse a markdown list."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
pattern = r"-\s([^\n]+)"
|
||||
return re.findall(pattern, text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "markdown-list"
|
||||
26
libs/core/langchain_core/output_parsers/string.py
Normal file
26
libs/core/langchain_core/output_parsers/string.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
|
||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
"""OutputParser that parses LLMResult into the top likely string."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
return "default"
|
||||
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "output_parser"]
|
||||
131
libs/core/langchain_core/output_parsers/transform.py
Normal file
131
libs/core/langchain_core/output_parsers/transform.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.output_parsers.base import BaseOutputParser, T
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
Generation,
|
||||
GenerationChunk,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
else:
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
else:
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[T]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, run_type="parser"
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, run_type="parser"
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
diff: bool = False
|
||||
"""In streaming mode, whether to yield diffs between the previous and current
|
||||
parsed output, or just the current parsed output.
|
||||
"""
|
||||
|
||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||
"""Convert parsed outputs into a diff format. The semantics of this are
|
||||
up to the output parser."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
if acc_gen is None:
|
||||
acc_gen = chunk_gen
|
||||
else:
|
||||
acc_gen += chunk_gen
|
||||
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
else:
|
||||
yield parsed
|
||||
prev_parsed = parsed
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
if acc_gen is None:
|
||||
acc_gen = chunk_gen
|
||||
else:
|
||||
acc_gen += chunk_gen
|
||||
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
else:
|
||||
yield parsed
|
||||
prev_parsed = parsed
|
||||
15
libs/core/langchain_core/outputs/__init__.py
Normal file
15
libs/core/langchain_core/outputs/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
|
||||
from langchain_core.outputs.chat_result import ChatResult
|
||||
from langchain_core.outputs.generation import Generation, GenerationChunk
|
||||
from langchain_core.outputs.llm_result import LLMResult
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
|
||||
__all__ = [
|
||||
"ChatGeneration",
|
||||
"ChatGenerationChunk",
|
||||
"ChatResult",
|
||||
"Generation",
|
||||
"GenerationChunk",
|
||||
"LLMResult",
|
||||
"RunInfo",
|
||||
]
|
||||
58
libs/core/langchain_core/outputs/chat_generation.py
Normal file
58
libs/core/langchain_core/outputs/chat_generation.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output."""
|
||||
|
||||
text: str = ""
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@root_validator
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
try:
|
||||
values["text"] = values["message"].content
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
|
||||
|
||||
class ChatGenerationChunk(ChatGeneration):
|
||||
"""A ChatGeneration chunk, which can be concatenated with other
|
||||
ChatGeneration chunks.
|
||||
|
||||
Attributes:
|
||||
message: The message chunk output by the chat model.
|
||||
"""
|
||||
|
||||
message: BaseMessageChunk
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
|
||||
if isinstance(other, ChatGenerationChunk):
|
||||
generation_info = (
|
||||
{**(self.generation_info or {}), **(other.generation_info or {})}
|
||||
if self.generation_info is not None or other.generation_info is not None
|
||||
else None
|
||||
)
|
||||
return ChatGenerationChunk(
|
||||
message=self.message + other.message,
|
||||
generation_info=generation_info,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
|
||||
)
|
||||
15
libs/core/langchain_core/outputs/chat_result.py
Normal file
15
libs/core/langchain_core/outputs/chat_result.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
"""Class that contains all results for a single chat model call."""
|
||||
|
||||
generations: List[ChatGeneration]
|
||||
"""List of the chat generations. This is a List because an input can have multiple
|
||||
candidate generations.
|
||||
"""
|
||||
llm_output: Optional[dict] = None
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
50
libs/core/langchain_core/outputs/generation.py
Normal file
50
libs/core/langchain_core/outputs/generation.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""A single text generation output."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw response from the provider. May include things like the
|
||||
reason for finishing or token log probabilities.
|
||||
"""
|
||||
type: Literal["Generation"] = "Generation"
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
# TODO: add log probs as separate attribute
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
|
||||
class GenerationChunk(Generation):
|
||||
"""A Generation chunk, which can be concatenated with other Generation chunks."""
|
||||
|
||||
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
||||
if isinstance(other, GenerationChunk):
|
||||
generation_info = (
|
||||
{**(self.generation_info or {}), **(other.generation_info or {})}
|
||||
if self.generation_info is not None or other.generation_info is not None
|
||||
else None
|
||||
)
|
||||
return GenerationChunk(
|
||||
text=self.text + other.text,
|
||||
generation_info=generation_info,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
|
||||
)
|
||||
65
libs/core/langchain_core/outputs/llm_result.py
Normal file
65
libs/core/langchain_core/outputs/llm_result.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""Class that contains all results for a batched LLM call."""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
"""List of generated outputs. This is a List[List[]] because
|
||||
each input could have multiple candidate generations."""
|
||||
llm_output: Optional[dict] = None
|
||||
"""Arbitrary LLM provider-specific output."""
|
||||
run: Optional[List[RunInfo]] = None
|
||||
"""List of metadata info for model call for each input."""
|
||||
|
||||
def flatten(self) -> List[LLMResult]:
|
||||
"""Flatten generations into a single list.
|
||||
|
||||
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
||||
contains only a single Generation. If token usage information is available,
|
||||
it is kept only for the LLMResult corresponding to the top-choice
|
||||
Generation, to avoid over-counting of token usage downstream.
|
||||
|
||||
Returns:
|
||||
List of LLMResults where each returned LLMResult contains a single
|
||||
Generation.
|
||||
"""
|
||||
llm_results = []
|
||||
for i, gen_list in enumerate(self.generations):
|
||||
# Avoid double counting tokens in OpenAICallback
|
||||
if i == 0:
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=self.llm_output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.llm_output is not None:
|
||||
llm_output = deepcopy(self.llm_output)
|
||||
llm_output["token_usage"] = dict()
|
||||
else:
|
||||
llm_output = None
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
)
|
||||
return llm_results
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.generations == other.generations
|
||||
and self.llm_output == other.llm_output
|
||||
)
|
||||
12
libs/core/langchain_core/outputs/run_info.py
Normal file
12
libs/core/langchain_core/outputs/run_info.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class RunInfo(BaseModel):
|
||||
"""Class that contains metadata for a single execution of a Chain or model."""
|
||||
|
||||
run_id: UUID
|
||||
"""A unique identifier for the model or chain run."""
|
||||
86
libs/core/langchain_core/prompt_values.py
Normal file
86
libs/core/langchain_core/prompt_values.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
class StringPromptValue(PromptValue):
|
||||
"""String prompt value."""
|
||||
|
||||
text: str
|
||||
"""Prompt text."""
|
||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "prompts", "base"]
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
"""Chat prompt value.
|
||||
|
||||
A type of a prompt value that is built from messages.
|
||||
"""
|
||||
|
||||
messages: Sequence[BaseMessage]
|
||||
"""List of messages."""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
For use in external schemas."""
|
||||
|
||||
messages: Sequence[AnyMessage]
|
||||
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
||||
74
libs/core/langchain_core/prompts/__init__.py
Normal file
74
libs/core/langchain_core/prompts/__init__.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""**Prompt** is the input to the model.
|
||||
|
||||
Prompt is often constructed
|
||||
from multiple components. Prompt classes and functions make constructing
|
||||
and working with prompts easy.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BasePromptTemplate --> PipelinePromptTemplate
|
||||
StringPromptTemplate --> PromptTemplate
|
||||
FewShotPromptTemplate
|
||||
FewShotPromptWithTemplates
|
||||
BaseChatPromptTemplate --> AutoGPTPrompt
|
||||
ChatPromptTemplate --> AgentScratchPadChatPromptTemplate
|
||||
|
||||
|
||||
|
||||
BaseMessagePromptTemplate --> MessagesPlaceholder
|
||||
BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate
|
||||
HumanMessagePromptTemplate
|
||||
AIMessagePromptTemplate
|
||||
SystemMessagePromptTemplate
|
||||
|
||||
""" # noqa: E501
|
||||
from langchain_core.prompts.base import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.few_shot import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||
from langchain_core.prompts.loading import load_prompt
|
||||
from langchain_core.prompts.pipeline import PipelinePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
jinja2_formatter,
|
||||
validate_jinja2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AIMessagePromptTemplate",
|
||||
"BaseChatPromptTemplate",
|
||||
"BasePromptTemplate",
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"HumanMessagePromptTemplate",
|
||||
"MessagesPlaceholder",
|
||||
"PipelinePromptTemplate",
|
||||
"PromptTemplate",
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"format_document",
|
||||
"check_valid_template",
|
||||
"get_template_variables",
|
||||
"jinja2_formatter",
|
||||
"validate_jinja2",
|
||||
]
|
||||
246
libs/core/langchain_core/prompts/base.py
Normal file
246
libs/core/langchain_core/prompts/base.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
PromptValue,
|
||||
StringPromptValue,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
input_types: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""A dictionary of the types of the variables the prompt template expects.
|
||||
If not provided, all variables are assumed to be strings."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"PromptInput",
|
||||
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||
) -> PromptValue:
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.format_prompt(
|
||||
**{key: inner_input[key] for key in self.input_variables}
|
||||
),
|
||||
input,
|
||||
config,
|
||||
run_type="prompt",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||
"""Return a partial of the prompt template."""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
prompt_dict["input_variables"] = list(
|
||||
set(self.input_variables).difference(kwargs)
|
||||
)
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if isinstance(v, str) else v()
|
||||
for k, v in self.partial_variables.items()
|
||||
}
|
||||
return {**partial_kwargs, **kwargs}
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of prompt."""
|
||||
prompt_dict = super().dict(**kwargs)
|
||||
try:
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return prompt_dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the prompt.
|
||||
|
||||
Args:
|
||||
file_path: Path to directory to save prompt to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt.save(file_path="path/prompt.yaml")
|
||||
"""
|
||||
if self.partial_variables:
|
||||
raise ValueError("Cannot save prompt with partial variables.")
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
if "_type" not in prompt_dict:
|
||||
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(prompt_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain"] + cls.__module__.split(".")[1:]
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
"""Format a document into a string based on a prompt template.
|
||||
|
||||
First, this pulls information from the document from two sources:
|
||||
|
||||
1. `page_content`:
|
||||
This takes the information from the `document.page_content`
|
||||
and assigns it to a variable named `page_content`.
|
||||
2. metadata:
|
||||
This takes information from `document.metadata` and assigns
|
||||
it to variables of the same name.
|
||||
|
||||
Those variables are then passed into the `prompt` to produce a formatted string.
|
||||
|
||||
Args:
|
||||
doc: Document, the page_content and metadata will be used to create
|
||||
the final string.
|
||||
prompt: BasePromptTemplate, will be used to format the page_content
|
||||
and metadata into the final string.
|
||||
|
||||
Returns:
|
||||
string of the document formatted.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core import Document
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
doc = Document(page_content="This is a joke", metadata={"page": "1"})
|
||||
prompt = PromptTemplate.from_template("Page {page}: {page_content}")
|
||||
format_document(doc, prompt)
|
||||
>>> "Page 1: This is a joke"
|
||||
"""
|
||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
iv for iv in prompt.input_variables if iv != "page_content"
|
||||
]
|
||||
raise ValueError(
|
||||
f"Document prompt requires documents to have metadata variables: "
|
||||
f"{required_metadata}. Received document with missing metadata: "
|
||||
f"{list(missing_metadata)}."
|
||||
)
|
||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||
return prompt.format(**document_info)
|
||||
733
libs/core/langchain_core/prompts/chat.py
Normal file
733
libs/core/langchain_core/prompts/chat.py
Normal file
@@ -0,0 +1,733 @@
|
||||
"""Chat prompt template."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
|
||||
|
||||
class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""Base class for message prompt templates."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variables.
|
||||
"""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
Args:
|
||||
other: Another prompt template.
|
||||
|
||||
Returns:
|
||||
Combined prompt template.
|
||||
"""
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
# For backwards compatibility.
|
||||
return ["langchain"] + cls.__module__.split(".")[1:]
|
||||
|
||||
|
||||
class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Prompt template that assumes variable is already list of messages."""
|
||||
|
||||
variable_name: str
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
def __init__(self, variable_name: str, **kwargs: Any):
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessage.
|
||||
"""
|
||||
value = kwargs[self.variable_name]
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
f"variable {self.variable_name} should be a list of base messages, "
|
||||
f"got {value}"
|
||||
)
|
||||
for v in value:
|
||||
if not isinstance(v, BaseMessage):
|
||||
raise ValueError(
|
||||
f"variable {self.variable_name} should be a list of base messages,"
|
||||
f" got {value}"
|
||||
)
|
||||
return value
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return [self.variable_name]
|
||||
|
||||
|
||||
MessagePromptTemplateT = TypeVar(
|
||||
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
||||
)
|
||||
"""Type variable for message prompt templates."""
|
||||
|
||||
|
||||
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""Base class for message prompt templates that use a string prompt template."""
|
||||
|
||||
prompt: StringPromptTemplate
|
||||
"""String prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[MessagePromptTemplateT],
|
||||
template: str,
|
||||
template_format: str = "f-string",
|
||||
partial_variables: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a string template.
|
||||
|
||||
Args:
|
||||
template: a template.
|
||||
template_format: format of the template.
|
||||
partial_variables: A dictionary of variables that can be used to partially
|
||||
fill in the template. For example, if the template is
|
||||
`"{variable1} {variable2}"`, and `partial_variables` is
|
||||
`{"variable1": "foo"}`, then the final prompt will be
|
||||
`"foo {variable2}"`.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(
|
||||
template,
|
||||
template_format=template_format,
|
||||
partial_variables=partial_variables,
|
||||
)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: Type[MessagePromptTemplateT],
|
||||
template_file: Union[str, Path],
|
||||
input_variables: List[str],
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a template file.
|
||||
|
||||
Args:
|
||||
template_file: path to a template file. String or Path.
|
||||
input_variables: list of input variables.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
return [self.format(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
|
||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Chat message prompt template."""
|
||||
|
||||
role: str
|
||||
"""Role of the message."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return ChatMessage(
|
||||
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
|
||||
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message sent from the AI."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""System message prompt template.
|
||||
This is a message that primes the AI's behavior and is not sent from the user.
|
||||
"""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
"""
|
||||
Return a dictionary of attributes that should be included in the
|
||||
serialized kwargs. These attributes must be accepted by the
|
||||
constructor.
|
||||
"""
|
||||
return {"input_variables": self.input_variables}
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
formatted string
|
||||
"""
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""
|
||||
Format prompt. Should return a PromptValue.
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
PromptValue.
|
||||
"""
|
||||
messages = self.format_messages(**kwargs)
|
||||
return ChatPromptValue(messages=messages)
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
||||
|
||||
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
MessageLike,
|
||||
Tuple[str, str],
|
||||
Tuple[Type, str],
|
||||
str,
|
||||
]
|
||||
|
||||
|
||||
class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
"""A prompt template for chat models.
|
||||
|
||||
Use to create flexible templated prompts for chat models.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
template = ChatPromptTemplate.from_messages([
|
||||
("system", "You are a helpful AI bot. Your name is {name}."),
|
||||
("human", "Hello, how are you doing?"),
|
||||
("ai", "I'm doing well, thanks!"),
|
||||
("human", "{user_input}"),
|
||||
])
|
||||
|
||||
messages = template.format_messages(
|
||||
name="Bob",
|
||||
user_input="What is your name?"
|
||||
)
|
||||
"""
|
||||
|
||||
input_variables: List[str]
|
||||
"""List of input variables in template messages. Used for validation."""
|
||||
messages: List[MessageLike]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
Args:
|
||||
other: Another prompt template.
|
||||
|
||||
Returns:
|
||||
Combined prompt template.
|
||||
"""
|
||||
# Allow for easy combining
|
||||
if isinstance(other, ChatPromptTemplate):
|
||||
return ChatPromptTemplate(messages=self.messages + other.messages)
|
||||
elif isinstance(
|
||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||
):
|
||||
return ChatPromptTemplate(messages=self.messages + [other])
|
||||
elif isinstance(other, (list, tuple)):
|
||||
_other = ChatPromptTemplate.from_messages(other)
|
||||
return ChatPromptTemplate(messages=self.messages + _other.messages)
|
||||
elif isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
return ChatPromptTemplate(messages=self.messages + [prompt])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
"""Validate input variables.
|
||||
|
||||
If input_variables is not set, it will be set to the union of
|
||||
all input variables in the messages.
|
||||
|
||||
Args:
|
||||
values: values to validate.
|
||||
|
||||
Returns:
|
||||
Validated values.
|
||||
"""
|
||||
messages = values["messages"]
|
||||
input_vars = set()
|
||||
input_types: Dict[str, Any] = values.get("input_types", {})
|
||||
for message in messages:
|
||||
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
|
||||
input_vars.update(message.input_variables)
|
||||
if isinstance(message, MessagesPlaceholder):
|
||||
if message.variable_name not in input_types:
|
||||
input_types[message.variable_name] = List[AnyMessage]
|
||||
if "partial_variables" in values:
|
||||
input_vars = input_vars - set(values["partial_variables"])
|
||||
if "input_variables" in values and values.get("validate_template"):
|
||||
if input_vars != set(values["input_variables"]):
|
||||
raise ValueError(
|
||||
"Got mismatched input_variables. "
|
||||
f"Expected: {input_vars}. "
|
||||
f"Got: {values['input_variables']}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(input_vars)
|
||||
values["input_types"] = input_types
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a template string.
|
||||
|
||||
Creates a chat template consisting of a single message assumed to be from
|
||||
the human.
|
||||
|
||||
Args:
|
||||
template: template string
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt_template = PromptTemplate.from_template(template, **kwargs)
|
||||
message = HumanMessagePromptTemplate(prompt=prompt_template)
|
||||
return cls.from_messages([message])
|
||||
|
||||
@classmethod
|
||||
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
|
||||
def from_role_strings(
|
||||
cls, string_messages: List[Tuple[str, str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a list of (role, template) tuples.
|
||||
|
||||
Args:
|
||||
string_messages: list of (role, template) tuples.
|
||||
|
||||
Returns:
|
||||
a chat prompt template
|
||||
"""
|
||||
return cls(
|
||||
messages=[
|
||||
ChatMessagePromptTemplate.from_template(template, role=role)
|
||||
for role, template in string_messages
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
|
||||
def from_strings(
|
||||
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a list of (role class, template) tuples.
|
||||
|
||||
Args:
|
||||
string_messages: list of (role class, template) tuples.
|
||||
|
||||
Returns:
|
||||
a chat prompt template
|
||||
"""
|
||||
return cls.from_messages(string_messages)
|
||||
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
Examples:
|
||||
|
||||
Instantiation from a list of message templates:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
template = ChatPromptTemplate.from_messages([
|
||||
("human", "Hello, how are you?"),
|
||||
("ai", "I'm doing well, thanks!"),
|
||||
("human", "That's good to hear."),
|
||||
])
|
||||
|
||||
Instantiation from mixed message formats:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
template = ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content="hello"),
|
||||
("human", "Hello, how are you?"),
|
||||
])
|
||||
|
||||
Args:
|
||||
messages: sequence of message representations.
|
||||
A message can be represented using the following formats:
|
||||
(1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
|
||||
(message type, template); e.g., ("human", "{user_input}"),
|
||||
(4) 2-tuple of (message class, template), (5) a string which is
|
||||
shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Returns:
|
||||
a chat prompt template
|
||||
"""
|
||||
_messages = [_convert_to_message(message) for message in messages]
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
for _message in _messages:
|
||||
if isinstance(
|
||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
||||
):
|
||||
input_vars.update(_message.input_variables)
|
||||
|
||||
return cls(input_variables=sorted(input_vars), messages=_messages)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
formatted string
|
||||
"""
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the chat template into a list of finalized messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
list of formatted messages
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
result = []
|
||||
for message_template in self.messages:
|
||||
if isinstance(message_template, BaseMessage):
|
||||
result.extend([message_template])
|
||||
elif isinstance(
|
||||
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||
):
|
||||
rel_params = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k in message_template.input_variables
|
||||
}
|
||||
message = message_template.format_messages(**rel_params)
|
||||
result.extend(message)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables. Ought
|
||||
to be a subset of the input variables.
|
||||
|
||||
Returns:
|
||||
A new ChatPromptTemplate.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
("human", "Hi I'm {user}"),
|
||||
("ai", "Hi there, {user}, I'm {name}."),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
template2 = template.partial(user="Lucy", name="R2D2")
|
||||
|
||||
template2.format_messages(input="hello")
|
||||
"""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
prompt_dict["input_variables"] = list(
|
||||
set(self.input_variables).difference(kwargs)
|
||||
)
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def append(self, message: MessageLikeRepresentation) -> None:
|
||||
"""Append message to the end of the chat template.
|
||||
|
||||
Args:
|
||||
message: representation of a message to append.
|
||||
"""
|
||||
self.messages.append(_convert_to_message(message))
|
||||
|
||||
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
|
||||
"""Extend the chat template with a sequence of messages."""
|
||||
self.messages.extend([_convert_to_message(message) for message in messages])
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> MessageLike:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> ChatPromptTemplate:
|
||||
...
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, slice]
|
||||
) -> Union[MessageLike, ChatPromptTemplate]:
|
||||
"""Use to index into the chat template."""
|
||||
if isinstance(index, slice):
|
||||
start, stop, step = index.indices(len(self.messages))
|
||||
messages = self.messages[start:stop:step]
|
||||
return ChatPromptTemplate.from_messages(messages)
|
||||
else:
|
||||
return self.messages[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get the length of the chat template."""
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Name of prompt type."""
|
||||
return "chat"
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save prompt to file.
|
||||
|
||||
Args:
|
||||
file_path: path to file.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _create_template_from_message_type(
|
||||
message_type: str, template: str
|
||||
) -> BaseMessagePromptTemplate:
|
||||
"""Create a message prompt template from a message type and template string.
|
||||
|
||||
Args:
|
||||
message_type: str the type of the message template (e.g., "human", "ai", etc.)
|
||||
template: str the template string.
|
||||
|
||||
Returns:
|
||||
a message prompt template of the appropriate type.
|
||||
"""
|
||||
if message_type in ("human", "user"):
|
||||
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
|
||||
template
|
||||
)
|
||||
elif message_type in ("ai", "assistant"):
|
||||
message = AIMessagePromptTemplate.from_template(template)
|
||||
elif message_type == "system":
|
||||
message = SystemMessagePromptTemplate.from_template(template)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected message type: {message_type}. Use one of 'human',"
|
||||
f" 'user', 'ai', 'assistant', or 'system'."
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
The message format can be one of the following:
|
||||
|
||||
- BaseMessagePromptTemplate
|
||||
- BaseMessage
|
||||
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
|
||||
- 2-tuple of (message class, template)
|
||||
- string: shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Args:
|
||||
message: a representation of a message in one of the supported formats
|
||||
|
||||
Returns:
|
||||
an instance of a message or a message template
|
||||
"""
|
||||
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
|
||||
_message: Union[
|
||||
BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate
|
||||
] = message
|
||||
elif isinstance(message, BaseMessage):
|
||||
_message = message
|
||||
elif isinstance(message, str):
|
||||
_message = _create_template_from_message_type("human", message)
|
||||
elif isinstance(message, tuple):
|
||||
if len(message) != 2:
|
||||
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
|
||||
message_type_str, template = message
|
||||
if isinstance(message_type_str, str):
|
||||
_message = _create_template_from_message_type(message_type_str, template)
|
||||
else:
|
||||
_message = message_type_str(prompt=PromptTemplate.from_template(template))
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
return _message
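# Usage sketch (illustrative): combining the message formats accepted by
# _convert_to_message above. Assumes the langchain_core.messages exports and
# the MessagesPlaceholder class referenced earlier in this file; the variable
# names ("history", "question") are made up for the example.
from langchain_core.messages import SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder

template = ChatPromptTemplate.from_messages([
    SystemMessage(content="You are a concise assistant."),  # a plain BaseMessage
    MessagesPlaceholder(variable_name="history"),            # expands to a list of messages
    ("human", "{question}"),                                  # (role, template) tuple
])
# `+` accepts another template, a message, a list, or a bare string
# (shorthand for a human message template).
template = template + "Answer in one sentence."
messages = template.format_messages(history=[], question="What is LangChain?")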
|
||||
342
libs/core/langchain_core/prompts/few_shot.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Any = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
if examples and example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
if examples is None and example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
"""Get the examples to use for formatting the prompt.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to be passed to the example selector.
|
||||
|
||||
Returns:
|
||||
List of examples.
|
||||
"""
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return self.example_selector.select_examples(kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
|
||||
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""PromptTemplate used to format an individual example."""
|
||||
|
||||
suffix: str
|
||||
"""A prompt template string to put after the examples."""
|
||||
|
||||
example_separator: str = "\n\n"
|
||||
"""String separator used to join the prefix, the examples, and suffix."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prompt template string to put before the examples."""
|
||||
|
||||
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
check_valid_template(
|
||||
values["prefix"] + values["suffix"],
|
||||
values["template_format"],
|
||||
values["input_variables"] + list(values["partial_variables"]),
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["prefix"] + values["suffix"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
**kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
# Get the examples to use.
|
||||
examples = self._get_examples(**kwargs)
|
||||
examples = [
|
||||
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
|
||||
]
|
||||
# Format the examples.
|
||||
example_strings = [
|
||||
self.example_prompt.format(**example) for example in examples
|
||||
]
|
||||
# Create the overall template.
|
||||
pieces = [self.prefix, *example_strings, self.suffix]
|
||||
template = self.example_separator.join([piece for piece in pieces if piece])
|
||||
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot"
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
return super().save(file_path)
|
||||
|
||||
|
||||
class FewShotChatMessagePromptTemplate(
|
||||
BaseChatPromptTemplate, _FewShotPromptTemplateMixin
|
||||
):
|
||||
"""Chat prompt template that supports few-shot examples.
|
||||
|
||||
The high-level structure produced by this prompt template is a list of messages
|
||||
consisting of prefix message(s), example message(s), and suffix message(s).
|
||||
|
||||
This structure enables creating a conversation with intermediate examples like:
|
||||
|
||||
System: You are a helpful AI Assistant
|
||||
Human: What is 2+2?
|
||||
AI: 4
|
||||
Human: What is 2+3?
|
||||
AI: 5
|
||||
Human: What is 4+4?
|
||||
|
||||
This prompt template can be used to generate a fixed list of examples or else
|
||||
to dynamically select examples based on the input.
|
||||
|
||||
Examples:
|
||||
|
||||
Prompt template with a fixed list of examples (matching the sample
|
||||
conversation above):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
ChatPromptTemplate
|
||||
)
|
||||
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
]
|
||||
|
||||
example_prompt = ChatPromptTemplate.from_messages(
|
||||
[('human', '{input}'), ('ai', '{output}')]
|
||||
)
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
examples=examples,
|
||||
# This is a prompt template used to format each individual example.
|
||||
example_prompt=example_prompt,
|
||||
)
|
||||
|
||||
final_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
('system', 'You are a helpful AI Assistant'),
|
||||
few_shot_prompt,
|
||||
('human', '{input}'),
|
||||
]
|
||||
)
|
||||
final_prompt.format(input="What is 4+4?")
|
||||
|
||||
Prompt template with dynamically selected examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.prompts import SemanticSimilarityExampleSelector
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
{"input": "2+4", "output": "6"},
|
||||
# ...
|
||||
]
|
||||
|
||||
to_vectorize = [
|
||||
" ".join(example.values())
|
||||
for example in examples
|
||||
]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Chroma.from_texts(
|
||||
to_vectorize, embeddings, metadatas=examples
|
||||
)
|
||||
example_selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=vectorstore
|
||||
)
|
||||
|
||||
from langchain_core.prompts import (
    AIMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
|
||||
from langchain_core.prompts import HumanMessagePromptTemplate
|
||||
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
# Which variable(s) will be passed to the example selector.
|
||||
input_variables=["input"],
|
||||
example_selector=example_selector,
|
||||
# Define how each example will be formatted.
|
||||
# In this case, each example will become 2 messages:
|
||||
# 1 human, and 1 AI
|
||||
example_prompt=(
|
||||
HumanMessagePromptTemplate.from_template("{input}")
|
||||
+ AIMessagePromptTemplate.from_template("{output}")
|
||||
),
|
||||
)
|
||||
# Define the overall prompt.
|
||||
final_prompt = (
|
||||
SystemMessagePromptTemplate.from_template(
|
||||
"You are a helpful AI Assistant"
|
||||
)
|
||||
+ few_shot_prompt
|
||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||
)
|
||||
# Show the prompt
|
||||
print(final_prompt.format_messages(input="What's 3+3?"))
|
||||
|
||||
# Use within an LLM
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
chain = final_prompt | ChatAnthropic()
|
||||
chain.invoke({"input": "What's 3+3?"})
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
input_variables: List[str] = Field(default_factory=list)
|
||||
"""A list of the names of the variables the prompt template will use
|
||||
to pass to the example_selector, if provided."""
|
||||
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
|
||||
"""The class to format each example."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in templates in messages.
|
||||
|
||||
Returns:
|
||||
A list of formatted messages with all template variables filled in.
|
||||
"""
|
||||
# Get the examples to use.
|
||||
examples = self._get_examples(**kwargs)
|
||||
examples = [
|
||||
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
|
||||
]
|
||||
# Format the examples.
|
||||
messages = [
|
||||
message
|
||||
for example in examples
|
||||
for message in self.example_prompt.format_messages(**example)
|
||||
]
|
||||
return messages
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with inputs generating a string.
|
||||
|
||||
Use this method to generate a string representation of a prompt consisting
|
||||
of chat messages.
|
||||
|
||||
Useful for feeding into a string-based completion language model or for debugging.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
A string representation of the prompt
|
||||
"""
|
||||
messages = self.format_messages(**kwargs)
|
||||
return get_buffer_string(messages)
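# Usage sketch (illustrative) for the string-flavored FewShotPromptTemplate
# defined above; the antonym examples are made up.
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate.from_template("Input: {word}\nOutput: {antonym}")
few_shot = FewShotPromptTemplate(
    examples=[
        {"word": "happy", "antonym": "sad"},
        {"word": "tall", "antonym": "short"},
    ],
    example_prompt=example_prompt,
    prefix="Give the antonym of every input.",
    suffix="Input: {word}\nOutput:",
    input_variables=["word"],
)
print(few_shot.format(word="big"))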
|
||||
155
libs/core/langchain_core/prompts/few_shot_with_templates.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
|
||||
|
||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Any = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""PromptTemplate used to format an individual example."""
|
||||
|
||||
suffix: StringPromptTemplate
|
||||
"""A PromptTemplate to put after the examples."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_separator: str = "\n\n"
|
||||
"""String separator used to join the prefix, the examples, and suffix."""
|
||||
|
||||
prefix: Optional[StringPromptTemplate] = None
|
||||
"""A PromptTemplate to put before the examples."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
if examples and example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
if examples is None and example_selector is None:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
input_variables = values["input_variables"]
|
||||
expected_input_variables = set(values["suffix"].input_variables)
|
||||
expected_input_variables |= set(values["partial_variables"])
|
||||
if values["prefix"] is not None:
|
||||
expected_input_variables |= set(values["prefix"].input_variables)
|
||||
missing_vars = expected_input_variables.difference(input_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Got input_variables={input_variables}, but based on "
|
||||
f"prefix/suffix expected {expected_input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(
|
||||
set(values["suffix"].input_variables)
|
||||
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||
- set(values["partial_variables"])
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return self.example_selector.select_examples(kwargs)
|
||||
else:
|
||||
raise ValueError(
    "One of 'examples' and 'example_selector' should be provided"
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
# Get the examples to use.
|
||||
examples = self._get_examples(**kwargs)
|
||||
# Format the examples.
|
||||
example_strings = [
|
||||
self.example_prompt.format(**example) for example in examples
|
||||
]
|
||||
# Create the overall prefix.
|
||||
if self.prefix is None:
|
||||
prefix = ""
|
||||
else:
|
||||
prefix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||
}
|
||||
for k in prefix_kwargs.keys():
|
||||
kwargs.pop(k)
|
||||
prefix = self.prefix.format(**prefix_kwargs)
|
||||
|
||||
# Create the overall suffix
|
||||
suffix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||
}
|
||||
for k in suffix_kwargs.keys():
|
||||
kwargs.pop(k)
|
||||
suffix = self.suffix.format(
|
||||
**suffix_kwargs,
|
||||
)
|
||||
|
||||
pieces = [prefix, *example_strings, suffix]
|
||||
template = self.example_separator.join([piece for piece in pieces if piece])
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot_with_templates"
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
return super().save(file_path)
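# Usage sketch (illustrative): FewShotPromptWithTemplates takes prompt
# templates (not plain strings) for its prefix and suffix.
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.prompt import PromptTemplate

prompt = FewShotPromptWithTemplates(
    examples=[{"question": "What is 2+2?", "answer": "4"}],
    example_prompt=PromptTemplate.from_template("Q: {question}\nA: {answer}"),
    prefix=PromptTemplate.from_template("You answer as {persona}."),
    suffix=PromptTemplate.from_template("Q: {question}\nA:"),
    input_variables=["persona", "question"],
)
print(prompt.format(persona="a pirate", question="What is 3+3?"))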
|
||||
160
libs/core/langchain_core/prompts/loading.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Load prompts."""
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||
"""Load prompt from Config Dict."""
|
||||
if "_type" not in config:
|
||||
logger.warning("No `_type` key found, defaulting to `prompt`.")
|
||||
config_type = config.pop("_type", "prompt")
|
||||
|
||||
if config_type not in type_to_loader_dict:
|
||||
raise ValueError(f"Loading {config_type} prompt not supported")
|
||||
|
||||
prompt_loader = type_to_loader_dict[config_type]
|
||||
return prompt_loader(config)
|
||||
|
||||
|
||||
def _load_template(var_name: str, config: dict) -> dict:
|
||||
"""Load template from the path if applicable."""
|
||||
# Check if template_path exists in config.
|
||||
if f"{var_name}_path" in config:
|
||||
# If it does, make sure template variable doesn't also exist.
|
||||
if var_name in config:
|
||||
raise ValueError(
|
||||
f"Both `{var_name}_path` and `{var_name}` cannot be provided."
|
||||
)
|
||||
# Pop the template path from the config.
|
||||
template_path = Path(config.pop(f"{var_name}_path"))
|
||||
# Load the template.
|
||||
if template_path.suffix == ".txt":
|
||||
with open(template_path) as f:
|
||||
template = f.read()
|
||||
else:
|
||||
raise ValueError(
    f"Unsupported template file extension: {template_path.suffix}. "
    "Only .txt files are supported."
)
|
||||
# Set the template variable to the extracted variable.
|
||||
config[var_name] = template
|
||||
return config
|
||||
|
||||
|
||||
def _load_examples(config: dict) -> dict:
|
||||
"""Load examples if necessary."""
|
||||
if isinstance(config["examples"], list):
|
||||
pass
|
||||
elif isinstance(config["examples"], str):
|
||||
with open(config["examples"]) as f:
|
||||
if config["examples"].endswith(".json"):
|
||||
examples = json.load(f)
|
||||
elif config["examples"].endswith((".yaml", ".yml")):
|
||||
examples = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid file format. Only json or yaml formats are supported."
|
||||
)
|
||||
config["examples"] = examples
|
||||
else:
|
||||
raise ValueError("Invalid examples format. Only list or string are supported.")
|
||||
return config
|
||||
|
||||
|
||||
def _load_output_parser(config: dict) -> dict:
|
||||
"""Load output parser."""
|
||||
if "output_parser" in config and config["output_parser"]:
|
||||
_config = config.pop("output_parser")
|
||||
output_parser_type = _config.pop("_type")
|
||||
if output_parser_type == "default":
|
||||
output_parser = StrOutputParser(**_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported output parser {output_parser_type}")
|
||||
config["output_parser"] = output_parser
|
||||
return config
|
||||
|
||||
|
||||
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
|
||||
"""Load the "few shot" prompt from the config."""
|
||||
# Load the suffix and prefix templates.
|
||||
config = _load_template("suffix", config)
|
||||
config = _load_template("prefix", config)
|
||||
# Load the example prompt.
|
||||
if "example_prompt_path" in config:
|
||||
if "example_prompt" in config:
|
||||
raise ValueError(
|
||||
"Only one of example_prompt and example_prompt_path should "
|
||||
"be specified."
|
||||
)
|
||||
config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
|
||||
else:
|
||||
config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
|
||||
# Load the examples.
|
||||
config = _load_examples(config)
|
||||
config = _load_output_parser(config)
|
||||
return FewShotPromptTemplate(**config)
|
||||
|
||||
|
||||
def _load_prompt(config: dict) -> PromptTemplate:
|
||||
"""Load the prompt template from config."""
|
||||
# Load the template from disk if necessary.
|
||||
config = _load_template("template", config)
|
||||
config = _load_output_parser(config)
|
||||
|
||||
template_format = config.get("template_format", "f-string")
|
||||
if template_format == "jinja2":
|
||||
# Disabled due to:
|
||||
# https://github.com/langchain-ai/langchain/issues/4394
|
||||
raise ValueError(
|
||||
f"Loading templates with '{template_format}' format is no longer supported "
|
||||
f"since it can lead to arbitrary code execution. Please migrate to using "
|
||||
f"the 'f-string' template format, which does not suffer from this issue."
|
||||
)
|
||||
|
||||
return PromptTemplate(**config)
|
||||
|
||||
|
||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
||||
if hub_result := try_load_from_hub(
|
||||
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
|
||||
):
|
||||
return hub_result
|
||||
else:
|
||||
return _load_prompt_from_file(path)
|
||||
|
||||
|
||||
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Load prompt from file."""
|
||||
# Convert file to a Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix == ".json":
|
||||
with open(file_path) as f:
|
||||
config = json.load(f)
|
||||
elif file_path.suffix == ".yaml":
|
||||
with open(file_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||
# Load the prompt from the config now.
|
||||
return load_prompt_from_config(config)
|
||||
|
||||
|
||||
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
|
||||
"prompt": _load_prompt,
|
||||
"few_shot": _load_few_shot_prompt,
|
||||
}
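# Usage sketch (illustrative): round-tripping a prompt through save() (defined
# on BasePromptTemplate) and load_prompt(). The file name is arbitrary.
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.prompt import PromptTemplate

PromptTemplate.from_template("Tell me a {adjective} joke.").save("joke_prompt.json")
reloaded = load_prompt("joke_prompt.json")
assert reloaded.format(adjective="funny") == "Tell me a funny joke."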
|
||||
57
libs/core/langchain_core/prompts/pipeline.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
return {k: inputs[k] for k in input_variables}
|
||||
|
||||
|
||||
class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""A prompt template for composing multiple prompt templates together.
|
||||
|
||||
This can be useful when you want to reuse parts of prompts.
|
||||
A PipelinePrompt consists of two main parts:
|
||||
- final_prompt: This is the final prompt that is returned
|
||||
- pipeline_prompts: This is a list of tuples, consisting
|
||||
of a string (`name`) and a Prompt Template.
|
||||
Each PromptTemplate will be formatted and then passed
|
||||
to future prompt templates as a variable with
|
||||
the same name as `name`.
|
||||
"""
|
||||
|
||||
final_prompt: BasePromptTemplate
|
||||
"""The final prompt that is returned."""
|
||||
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
|
||||
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Get input variables."""
|
||||
created_variables = set()
|
||||
all_variables = set()
|
||||
for k, prompt in values["pipeline_prompts"]:
|
||||
created_variables.add(k)
|
||||
all_variables.update(prompt.input_variables)
|
||||
values["input_variables"] = list(all_variables.difference(created_variables))
|
||||
return values
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
for k, prompt in self.pipeline_prompts:
|
||||
_inputs = _get_inputs(kwargs, prompt.input_variables)
|
||||
if isinstance(prompt, BaseChatPromptTemplate):
|
||||
kwargs[k] = prompt.format_messages(**_inputs)
|
||||
else:
|
||||
kwargs[k] = prompt.format(**_inputs)
|
||||
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
||||
return self.final_prompt.format_prompt(**_inputs)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise ValueError
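# Usage sketch (illustrative): composing two sub-prompts into a final prompt.
# The variable names ("intro", "request") are arbitrary.
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

pipeline = PipelinePromptTemplate(
    final_prompt=PromptTemplate.from_template("{intro}\n\n{request}"),
    pipeline_prompts=[
        ("intro", PromptTemplate.from_template("You are impersonating {person}.")),
        ("request", PromptTemplate.from_template("Answer this question: {question}")),
    ],
)
# Only the leaf variables remain as inputs.
assert set(pipeline.input_variables) == {"person", "question"}
print(pipeline.format(person="Ada Lovelace", question="What is an algorithm?"))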
|
||||
246
libs/core/langchain_core/prompts/prompt.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Prompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
class PromptTemplate(StringPromptTemplate):
|
||||
"""A prompt template for a language model.
|
||||
|
||||
A prompt template consists of a string template. It accepts a set of parameters
|
||||
from the user that can be used to generate a prompt for a language model.
|
||||
|
||||
The template can be formatted using either f-strings (default) or jinja2 syntax.
|
||||
|
||||
*Security warning*: Prefer using `template_format="f-string"` instead of
|
||||
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
|
||||
from untrusted sources as they may lead to arbitrary Python code execution.
|
||||
|
||||
As of LangChain 0.0.329, Jinja2 templates will be rendered using
|
||||
Jinja2's SandboxedEnvironment by default. This sand-boxing should
|
||||
be treated as a best-effort approach rather than a guarantee of security,
|
||||
as it is an opt-out rather than opt-in approach.
|
||||
|
||||
Despite the sand-boxing, we recommend never using jinja2 templates
|
||||
from untrusted sources.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
# Instantiation using from_template (recommended)
|
||||
prompt = PromptTemplate.from_template("Say {foo}")
|
||||
prompt.format(foo="bar")
|
||||
|
||||
# Instantiation using initializer
|
||||
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"template_format": self.template_format,
|
||||
}
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def __add__(self, other: Any) -> PromptTemplate:
|
||||
"""Override the + operator to allow for combining prompt templates."""
|
||||
# Allow for easy combining
|
||||
if isinstance(other, PromptTemplate):
|
||||
if self.template_format != "f-string":
|
||||
raise ValueError(
|
||||
"Adding prompt templates only supported for f-strings."
|
||||
)
|
||||
if other.template_format != "f-string":
|
||||
raise ValueError(
|
||||
"Adding prompt templates only supported for f-strings."
|
||||
)
|
||||
input_variables = list(
|
||||
set(self.input_variables) | set(other.input_variables)
|
||||
)
|
||||
template = self.template + other.template
|
||||
# If any do not want to validate, then don't
|
||||
validate_template = self.validate_template and other.validate_template
|
||||
partial_variables = {k: v for k, v in self.partial_variables.items()}
|
||||
for k, v in other.partial_variables.items():
|
||||
if k in partial_variables:
|
||||
raise ValueError("Cannot have same variable partialed twice.")
|
||||
else:
|
||||
partial_variables[k] = v
|
||||
return PromptTemplate(
|
||||
template=template,
|
||||
input_variables=input_variables,
|
||||
partial_variables=partial_variables,
|
||||
template_format="f-string",
|
||||
validate_template=validate_template,
|
||||
)
|
||||
elif isinstance(other, str):
|
||||
prompt = PromptTemplate.from_template(other)
|
||||
return self + prompt
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
return "prompt"
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["template"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[str],
|
||||
suffix: str,
|
||||
input_variables: List[str],
|
||||
example_separator: str = "\n\n",
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> PromptTemplate:
|
||||
"""Take examples in list format with prefix and suffix to create a prompt.
|
||||
|
||||
Intended to be used as a way to dynamically create a prompt from examples.
|
||||
|
||||
Args:
|
||||
examples: List of examples to use in the prompt.
|
||||
suffix: String to go after the list of examples. Should generally
|
||||
set up the user's input.
|
||||
input_variables: A list of variable names the final prompt template
|
||||
will expect.
|
||||
example_separator: The separator to use in between examples. Defaults
|
||||
to two new line characters.
|
||||
prefix: String that should go before any examples. Generally includes
|
||||
instructions. Defaults to an empty string.
|
||||
|
||||
Returns:
|
||||
The final prompt generated.
|
||||
"""
|
||||
template = example_separator.join([prefix, *examples, suffix])
|
||||
return cls(input_variables=input_variables, template=template, **kwargs)
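# Usage sketch (illustrative) for from_examples: the prefix, examples, and
# suffix are joined with the example separator into one f-string template.
from langchain_core.prompts.prompt import PromptTemplate

prompt = PromptTemplate.from_examples(
    examples=["happy -> sad", "tall -> short"],
    suffix="{word} ->",
    input_variables=["word"],
    prefix="Give the antonym of every input.",
)
print(prompt.format(word="big"))
# Give the antonym of every input.
#
# happy -> sad
#
# tall -> short
#
# big ->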
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
|
||||
) -> PromptTemplate:
|
||||
"""Load a prompt from a file.
|
||||
|
||||
Args:
|
||||
template_file: The path to the file containing the prompt template.
|
||||
input_variables: A list of variable names the final prompt template
|
||||
will expect.
|
||||
|
||||
Returns:
|
||||
The prompt loaded from the file.
|
||||
"""
|
||||
with open(str(template_file), "r") as f:
|
||||
template = f.read()
|
||||
return cls(input_variables=input_variables, template=template, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls,
|
||||
template: str,
|
||||
*,
|
||||
template_format: str = "f-string",
|
||||
partial_variables: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> PromptTemplate:
|
||||
"""Load a prompt template from a template.
|
||||
|
||||
*Security warning*: Prefer using `template_format="f-string"` instead of
|
||||
`template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
|
||||
from untrusted sources as they may lead to arbitrary Python code execution.
|
||||
|
||||
As of LangChain 0.0.329, Jinja2 templates will be rendered using
|
||||
Jinja2's SandboxedEnvironment by default. This sand-boxing should
|
||||
be treated as a best-effort approach rather than a guarantee of security,
|
||||
as it is an opt-out rather than opt-in approach.
|
||||
|
||||
Despite the sand-boxing, we recommend never using jinja2 templates
|
||||
from untrusted sources.
|
||||
|
||||
Args:
|
||||
template: The template to load.
|
||||
template_format: The format of the template. Use `jinja2` for jinja2,
|
||||
and `f-string` or None for f-strings.
|
||||
partial_variables: A dictionary of variables that can be used to partially
|
||||
fill in the template. For example, if the template is
|
||||
`"{variable1} {variable2}"`, and `partial_variables` is
|
||||
`{"variable1": "foo"}`, then the final prompt will be
|
||||
`"foo {variable2}"`.
|
||||
|
||||
Returns:
|
||||
The prompt template loaded from the template.
|
||||
"""
|
||||
|
||||
input_variables = get_template_variables(template, template_format)
|
||||
_partial_variables = partial_variables or {}
|
||||
|
||||
if _partial_variables:
|
||||
input_variables = [
|
||||
var for var in input_variables if var not in _partial_variables
|
||||
]
|
||||
|
||||
return cls(
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format=template_format,
|
||||
partial_variables=_partial_variables,
|
||||
**kwargs,
|
||||
)
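# Usage sketch (illustrative) for from_template with partial_variables, as
# described in the docstring above.
from langchain_core.prompts.prompt import PromptTemplate

prompt = PromptTemplate.from_template(
    "Tell me a {adjective} joke about {content}.",
    partial_variables={"adjective": "funny"},
)
assert prompt.input_variables == ["content"]
print(prompt.format(content="chickens"))  # -> Tell me a funny joke about chickens.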
|
||||
156
libs/core/langchain_core/prompts/string.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.utils.formatting import formatter
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""Format a template using jinja2.
|
||||
|
||||
*Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
|
||||
SandboxedEnvironment by default. However, this sand-boxing should
|
||||
be treated as a best-effort approach rather than a guarantee of security.
|
||||
Do not accept jinja2 templates from untrusted sources as they may lead
|
||||
to arbitrary Python code execution.
|
||||
|
||||
https://jinja.palletsprojects.com/en/3.1.x/sandbox/
|
||||
"""
|
||||
try:
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
"Please be cautious when using jinja2 templates. "
|
||||
"Do not expand jinja2 templates using unverified or user-controlled "
|
||||
"inputs as that can result in arbitrary Python code execution."
|
||||
)
|
||||
|
||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||
# Please treat this sand-boxing as a best-effort approach rather than
|
||||
# a guarantee of security.
|
||||
# We recommend never using jinja2 templates with untrusted inputs.
|
||||
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
|
||||
|
||||
return SandboxedEnvironment().from_string(template).render(**kwargs)
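# Usage sketch (illustrative), assuming jinja2 is installed:
#
#     jinja2_formatter("Hello {{ name }}!", name="World")  # -> "Hello World!"
#
# The sandboxed environment used above should reject unsafe attribute access
# such as "{{ ''.__class__ }}" with a jinja2 SecurityError instead of
# evaluating it.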
|
||||
|
||||
|
||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
"""
|
||||
Validate that the input variables are valid for the template.
|
||||
Issues a warning if missing or extra variables are found.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
input_variables: The input variables.
|
||||
"""
|
||||
input_variables_set = set(input_variables)
|
||||
valid_variables = _get_jinja2_variables_from_template(template)
|
||||
missing_variables = valid_variables - input_variables_set
|
||||
extra_variables = input_variables_set - valid_variables
|
||||
|
||||
warning_message = ""
|
||||
if missing_variables:
|
||||
warning_message += f"Missing variables: {missing_variables} "
|
||||
|
||||
if extra_variables:
|
||||
warning_message += f"Extra variables: {extra_variables}"
|
||||
|
||||
if warning_message:
|
||||
warnings.warn(warning_message.strip())
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
return variables
|
||||
|
||||
|
||||
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": formatter.format,
|
||||
"jinja2": jinja2_formatter,
|
||||
}
|
||||
|
||||
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": formatter.validate_input_variables,
|
||||
"jinja2": validate_jinja2,
|
||||
}
|
||||
|
||||
|
||||
def check_valid_template(
|
||||
template: str, template_format: str, input_variables: List[str]
|
||||
) -> None:
|
||||
"""Check that template string is valid.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
template_format: The template format. Should be one of "f-string" or "jinja2".
|
||||
input_variables: The input variables.
|
||||
|
||||
Raises:
|
||||
ValueError: If the template format is not supported.
|
||||
"""
|
||||
if template_format not in DEFAULT_FORMATTER_MAPPING:
|
||||
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
|
||||
raise ValueError(
|
||||
f"Invalid template format. Got `{template_format}`;"
|
||||
f" should be one of {valid_formats}"
|
||||
)
|
||||
try:
|
||||
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
|
||||
validator_func(template, input_variables)
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Invalid prompt schema; check for mismatched or missing input parameters. "
|
||||
+ str(e)
|
||||
)
|
||||
|
||||
|
||||
def get_template_variables(template: str, template_format: str) -> List[str]:
    """Get the variables from the template.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".

    Returns:
        The variables from the template.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format == "jinja2":
        # Get the variables for the template
        input_variables = _get_jinja2_variables_from_template(template)
    elif template_format == "f-string":
        input_variables = {
            v for _, v, _, _ in Formatter().parse(template) if v is not None
        }
    else:
        raise ValueError(f"Unsupported template format: {template_format}")

    return sorted(input_variables)
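
For both supported formats, get_template_variables returns a sorted list of the variable names it finds; for example:

    get_template_variables("{city}, {country}", "f-string")
    # -> ['city', 'country']

    get_template_variables("{{ city }}", "jinja2")
    # -> ['city']
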
class StringPromptTemplate(BasePromptTemplate, ABC):
    """String prompt that exposes the format method, returning a prompt."""

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt and return it as a PromptValue."""
        return StringPromptValue(text=self.format(**kwargs))
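
StringPromptTemplate leaves format() to subclasses and wraps its output in a StringPromptValue. A minimal sketch of a custom subclass (the class name below is a hypothetical example, not part of this diff, and assumes format() is the only abstract method left to implement):

    from typing import Any


    class GreetingPromptTemplate(StringPromptTemplate):
        # Hypothetical toy subclass: formats a fixed greeting from `name`.
        def format(self, **kwargs: Any) -> str:
            return f"Hello, {kwargs['name']}!"


    prompt = GreetingPromptTemplate(input_variables=["name"])
    prompt.format_prompt(name="Ada").to_string()
    # -> 'Hello, Ada!'
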
libs/core/langchain_core/pydantic_v1/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from importlib import metadata

## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
#   both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
#   unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.

try:
    from pydantic.v1 import *  # noqa: F403 # type: ignore
except ImportError:
    from pydantic import *  # noqa: F403 # type: ignore


try:
    _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
    _PYDANTIC_MAJOR_VERSION = 0
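
This shim gives downstream code a single import location for the pydantic v1 API regardless of which major version of pydantic is installed. A minimal sketch of the intended usage (the Book model is a hypothetical example):

    from langchain_core.pydantic_v1 import BaseModel, Field


    class Book(BaseModel):
        # Hypothetical model; validated via the v1 API on pydantic 1.x or 2.x.
        title: str
        pages: int = Field(gt=0)


    Book(title="Flatland", pages=96)
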
libs/core/langchain_core/pydantic_v1/dataclasses.py (new file, 4 lines)
@@ -0,0 +1,4 @@
try:
    from pydantic.v1.dataclasses import *  # noqa: F403
except ImportError:
    from pydantic.dataclasses import *  # noqa: F403

libs/core/langchain_core/pydantic_v1/main.py (new file, 4 lines)
@@ -0,0 +1,4 @@
try:
    from pydantic.v1.main import *  # noqa: F403
except ImportError:
    from pydantic.main import *  # noqa: F403

libs/core/langchain_core/retrievers.py (new file, 275 lines)
@@ -0,0 +1,275 @@
from __future__ import annotations

import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
        CallbackManagerForRetrieverRun,
        Callbacks,
    )
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
    the most 'relevant' Documents from some source.

    Example:
        .. code-block:: python

            class TFIDFRetriever(BaseRetriever, BaseModel):
                vectorizer: Any
                docs: List[Document]
                tfidf_array: Any
                k: int = 4

                class Config:
                    arbitrary_types_allowed = True

                def get_relevant_documents(self, query: str) -> List[Document]:
                    from sklearn.metrics.pairwise import cosine_similarity

                    # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
                    query_vec = self.vectorizer.transform([query])
                    # Op -- (n_docs,1) -- Cosine Sim with each doc
                    results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                    return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    _new_arg_supported: bool = False
    _expects_other_args: bool = False
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the retriever. Defaults to None
    These tags will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """
    metadata: Optional[Dict[str, Any]] = None
    """Optional metadata associated with the retriever. Defaults to None
    This metadata will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Version upgrade for old retrievers that implemented the public
        # methods directly.
        if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
            warnings.warn(
                "Retrievers must implement abstract `_get_relevant_documents` method"
                " instead of `get_relevant_documents`",
                DeprecationWarning,
            )
            swap = cls.get_relevant_documents
            cls.get_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.get_relevant_documents
            )
            cls._get_relevant_documents = swap  # type: ignore[assignment]
        if (
            hasattr(cls, "aget_relevant_documents")
            and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
        ):
            warnings.warn(
                "Retrievers must implement abstract `_aget_relevant_documents` method"
                " instead of `aget_relevant_documents`",
                DeprecationWarning,
            )
            aswap = cls.aget_relevant_documents
            cls.aget_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.aget_relevant_documents
            )
            cls._aget_relevant_documents = aswap  # type: ignore[assignment]
        parameters = signature(cls._get_relevant_documents).parameters
        cls._new_arg_supported = parameters.get("run_manager") is not None
        # If a V1 retriever broke the interface and expects additional arguments
        cls._expects_other_args = (
            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
        )
    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
        config = config or {}
        return self.get_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
        )

    async def ainvoke(
        self,
        input: str,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> List[Document]:
        config = config or {}
        return await self.aget_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
        )
    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._get_relevant_documents, run_manager=run_manager), query
        )
    def get_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import CallbackManager

        callback_manager = CallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            **kwargs,
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = self._get_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = self._get_relevant_documents(query, **_kwargs)
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise e
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            **kwargs,
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = await self._aget_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = await self._aget_relevant_documents(query, **_kwargs)
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise e
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
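
To round out the interface above: new retrievers subclass BaseRetriever, implement only _get_relevant_documents (and optionally _aget_relevant_documents), and are called through invoke / get_relevant_documents, which handle callbacks, tags, and metadata. A minimal sketch (KeywordRetriever and its fields are hypothetical, not part of this diff):

    from typing import List

    from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever


    class KeywordRetriever(BaseRetriever):
        # Hypothetical toy retriever: returns up to `k` docs whose text contains the query.
        docs: List[Document]
        k: int = 3

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            hits = [d for d in self.docs if query.lower() in d.page_content.lower()]
            return hits[: self.k]


    retriever = KeywordRetriever(docs=[Document(page_content="LangChain core retrievers")])
    retriever.invoke("retrievers")
    # -> [Document(page_content='LangChain core retrievers')]
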