mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-10 03:00:59 +00:00

Compare commits (49 commits)
langchain-... to langchain-...
Commits (SHA1):
c453b76579, f087ab43fd, 409f35363b, e8236e58f2, eef18dec44, 311f861547,
c77c28e631, 7d49ee9741, 28dd6564db, f91bdd12d2, 4d3d62c249, 60dc19da30,
55b641b761, 37b72023fe, 3fc0ea510e, a8561bc303, 4e0a6ebe7d, fd21ffe293,
7835c0651f, 85caaa773f, 8fb643a6e8, 03b9aca55d, acbb4e4701, e0c36afc3e,
9909354cd0, 84b831356c, a47b332841, 0f07cf61da, d158401e73, de58942618,
df38d5250f, b246052184, 52729ac0be, f62d454f36, 6fe2536c5a, 418b170f94,
c3b3f46cb8, e2245fac82, 1a8e9023de, 1a62f9850f, 6ed50e78c9, 5ced41bf50,
c6bdd6f482, 3a99467ccb, 2ef4c9466f, 194adc485c, 97b05d70e6, e1d113ea84,
7c05f71e0f
.github/DISCUSSION_TEMPLATE/q-a.yml (vendored): 27 changes
@@ -96,22 +96,27 @@ body:
  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: |
-        Please share your system info with us. Do NOT skip this step and please don't trim
-        the output. Most users don't include enough information here and it makes it harder
-        for us to help you.
+        Please share your system info with us.

        Run the following command in your terminal and paste the output here:

        "pip freeze | grep langchain"
        platform (windows / linux / mac)
        python version

+        OR if you're on a recent version of langchain-core you can paste the output of:
+
+        python -m langchain_core.sys_info
+
+        or if you have an existing python interpreter running:
+
+        from langchain_core import sys_info
+        sys_info.print_sys_info()
+
+        alternatively, put the entire output of `pip freeze` here.
      placeholder: |
        "pip freeze | grep langchain"
        platform
        python version
+
+        Alternatively, if you're on a recent version of langchain-core you can paste the output of:
+
+        python -m langchain_core.sys_info
+
+        These will only surface LangChain packages, don't forget to include any other relevant
+        packages you're using (if you're not sure what's relevant, you can paste the entire output of `pip freeze`).
    validations:
      required: true
.github/workflows/_release.yml (vendored): 2 changes

@@ -85,7 +85,7 @@ jobs:
          path: langchain
          sparse-checkout: | # this only grabs files for relevant dir
            ${{ inputs.working-directory }}
-          ref: master # this scopes to just master branch
+          ref: ${{ github.ref }} # this scopes to just ref'd branch
          fetch-depth: 0 # this fetches entire commit history
      - name: Check Tags
        id: check-tags
@@ -15,7 +15,7 @@
   :member-order: groupwise
   :show-inheritance: True
   :special-members: __call__
-  :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, with_listeners, with_alisteners, with_config, with_fallbacks, with_types, with_retry, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, bind, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema
+  :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema, predict, apredict, predict_messages, apredict_messages, generate, generate_prompt, agenerate, agenerate_prompt, call_as_llm

   .. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>`. 🏃
@@ -13,7 +13,7 @@
    "\n",
    "This sample demonstrates the use of `Amazon Textract` in combination with LangChain as a DocumentLoader.\n",
    "\n",
-    "`Textract` supports`PDF`, `TIF`F, `PNG` and `JPEG` format.\n",
+    "`Textract` supports`PDF`, `TIFF`, `PNG` and `JPEG` format.\n",
    "\n",
    "`Textract` supports these [document sizes, languages and characters](https://docs.aws.amazon.com/textract/latest/dg/limits-document.html)."
   ]
@@ -6,7 +6,7 @@
   "source": [
    "# Google Speech-to-Text Audio Transcripts\n",
    "\n",
-    "The `GoogleSpeechToTextLoader` allows to transcribe audio files with the [Google Cloud Speech-to-Text API](https://cloud.google.com/speech-to-text) and loads the transcribed text into documents.\n",
+    "The `SpeechToTextLoader` allows to transcribe audio files with the [Google Cloud Speech-to-Text API](https://cloud.google.com/speech-to-text) and loads the transcribed text into documents.\n",
    "\n",
    "To use it, you should have the `google-cloud-speech` python package installed, and a Google Cloud project with the [Speech-to-Text API enabled](https://cloud.google.com/speech-to-text/v2/docs/transcribe-client-libraries#before_you_begin).\n",
    "\n",
@@ -41,7 +41,7 @@
   "source": [
    "## Example\n",
    "\n",
-    "The `GoogleSpeechToTextLoader` must include the `project_id` and `file_path` arguments. Audio files can be specified as a Google Cloud Storage URI (`gs://...`) or a local file path.\n",
+    "The `SpeechToTextLoader` must include the `project_id` and `file_path` arguments. Audio files can be specified as a Google Cloud Storage URI (`gs://...`) or a local file path.\n",
    "\n",
    "Only synchronous requests are supported by the loader, which has a [limit of 60 seconds or 10MB](https://cloud.google.com/speech-to-text/v2/docs/sync-recognize#:~:text=60%20seconds%20and/or%2010%20MB) per audio file."
   ]
@@ -52,13 +52,13 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from langchain_google_community import GoogleSpeechToTextLoader\n",
+    "from langchain_google_community import SpeechToTextLoader\n",
    "\n",
    "project_id = \"<PROJECT_ID>\"\n",
    "file_path = \"gs://cloud-samples-data/speech/audio.flac\"\n",
    "# or a local file path: file_path = \"./audio.wav\"\n",
    "\n",
-    "loader = GoogleSpeechToTextLoader(project_id=project_id, file_path=file_path)\n",
+    "loader = SpeechToTextLoader(project_id=project_id, file_path=file_path)\n",
    "\n",
    "docs = loader.load()"
   ]
@@ -152,7 +152,7 @@
    "    RecognitionConfig,\n",
    "    RecognitionFeatures,\n",
    ")\n",
-    "from langchain_google_community import GoogleSpeechToTextLoader\n",
+    "from langchain_google_community import SpeechToTextLoader\n",
    "\n",
    "project_id = \"<PROJECT_ID>\"\n",
    "location = \"global\"\n",
@@ -171,7 +171,7 @@
    "    ),\n",
    ")\n",
    "\n",
-    "loader = GoogleSpeechToTextLoader(\n",
+    "loader = SpeechToTextLoader(\n",
    "    project_id=project_id,\n",
    "    location=location,\n",
    "    recognizer_id=recognizer_id,\n",
@@ -16,7 +16,7 @@
    "\n",
    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/document_loaders/file_loaders/unstructured/)|\n",
    "| :--- | :--- | :---: | :---: | :---: |\n",
-    "| [UnstructuredLoader](https://python.langchain.com/api_reference/unstructured/document_loaders/langchain_unstructured.document_loaders.UnstructuredLoader.html) | [langchain_community](https://python.langchain.com/api_reference/unstructured/index.html) | ✅ | ❌ | ✅ | \n",
+    "| [UnstructuredLoader](https://python.langchain.com/api_reference/unstructured/document_loaders/langchain_unstructured.document_loaders.UnstructuredLoader.html) | [langchain_unstructured](https://python.langchain.com/api_reference/unstructured/index.html) | ✅ | ❌ | ✅ | \n",
    "### Loader features\n",
    "| Source | Document Lazy Loading | Native Async Support\n",
    "| :---: | :---: | :---: | \n",
@@ -519,6 +519,47 @@
    "print(\"Length of text in the document:\", len(docs[0].page_content))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ec3c22d-02cd-498b-921f-b839d1404f32",
   "metadata": {},
   "source": [
    "## Loading web pages\n",
    "\n",
    "`UnstructuredLoader` accepts a `web_url` kwarg when run locally that populates the `url` parameter of the underlying Unstructured [partition](https://docs.unstructured.io/open-source/core-functionality/partitioning). This allows for the parsing of remotely hosted documents, such as HTML web pages.\n",
    "\n",
    "Example usage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bf9a8546-659d-4861-bff2-fdf1ad93ac65",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "page_content='Example Domain' metadata={'category_depth': 0, 'languages': ['eng'], 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'Title', 'element_id': 'fdaa78d856f9d143aeeed85bf23f58f8'}\n",
      "\n",
      "page_content='This domain is for use in illustrative examples in documents. You may use this domain in literature without prior coordination or asking for permission.' metadata={'languages': ['eng'], 'parent_id': 'fdaa78d856f9d143aeeed85bf23f58f8', 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'NarrativeText', 'element_id': '3652b8458b0688639f973fe36253c992'}\n",
      "\n",
      "page_content='More information...' metadata={'category_depth': 0, 'link_texts': ['More information...'], 'link_urls': ['https://www.iana.org/domains/example'], 'languages': ['eng'], 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'Title', 'element_id': '793ab98565d6f6d6f3a6d614e3ace2a9'}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from langchain_unstructured import UnstructuredLoader\n",
    "\n",
    "loader = UnstructuredLoader(web_url=\"https://www.example.com\")\n",
    "docs = loader.load()\n",
    "\n",
    "for doc in docs:\n",
    "    print(f\"{doc}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce01aa40",
@@ -546,7 +587,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.10.4"
  }
 },
 "nbformat": 4,
@@ -6,129 +6,11 @@
   "source": [
    "# SambaNova\n",
    "\n",
-    "**[SambaNova](https://sambanova.ai/)'s** [Sambaverse](https://sambaverse.sambanova.ai/) and [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) are platforms for running your own open-source models\n",
+    "**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
    "\n",
    "This example goes over how to use LangChain to interact with SambaNova models"
   ]
  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Sambaverse"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
-    " **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "An API key is required to access Sambaverse models. To get a key, create an account at [sambaverse.sambanova.ai](https://sambaverse.sambanova.ai/)\n",
-    "\n",
-    "The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%pip install --quiet sseclient-py==1.8.0"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Register your API key as an environment variable:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "\n",
-    "sambaverse_api_key = \"<Your sambaverse API key>\"\n",
-    "\n",
-    "# Set the environment variables\n",
-    "os.environ[\"SAMBAVERSE_API_KEY\"] = sambaverse_api_key"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Call Sambaverse models directly from LangChain!"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from langchain_community.llms.sambanova import Sambaverse\n",
-    "\n",
-    "llm = Sambaverse(\n",
-    "    sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
-    "    streaming=False,\n",
-    "    model_kwargs={\n",
-    "        \"do_sample\": True,\n",
-    "        \"max_tokens_to_generate\": 1000,\n",
-    "        \"temperature\": 0.01,\n",
-    "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
-    "        \"process_prompt\": False,\n",
-    "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
-    "        # \"repetition_penalty\": 1.0,\n",
-    "        # \"top_k\": 50,\n",
-    "        # \"top_p\": 1.0\n",
-    "    },\n",
-    ")\n",
-    "\n",
-    "print(llm.invoke(\"Why should I use open source models?\"))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Streaming response\n",
-    "\n",
-    "from langchain_community.llms.sambanova import Sambaverse\n",
-    "\n",
-    "llm = Sambaverse(\n",
-    "    sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
-    "    streaming=True,\n",
-    "    model_kwargs={\n",
-    "        \"do_sample\": True,\n",
-    "        \"max_tokens_to_generate\": 1000,\n",
-    "        \"temperature\": 0.01,\n",
-    "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
-    "        \"process_prompt\": False,\n",
-    "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
-    "        # \"repetition_penalty\": 1.0,\n",
-    "        # \"top_k\": 50,\n",
-    "        # \"top_p\": 1.0\n",
-    "    },\n",
-    ")\n",
-    "\n",
-    "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
-    "    print(chunk, end=\"\", flush=True)"
-   ]
-  },
  {
   "cell_type": "markdown",
   "metadata": {},
@@ -1,12 +1,12 @@
-# MLflow Deployments for LLMs
+# MLflow AI Gateway for LLMs

->[The MLflow Deployments for LLMs](https://www.mlflow.org/docs/latest/llms/deployments/index.html) is a powerful tool designed to streamline the usage and management of various large
+>[The MLflow AI Gateway for LLMs](https://www.mlflow.org/docs/latest/llms/deployments/index.html) is a powerful tool designed to streamline the usage and management of various large
> language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface
> that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests.

## Installation and Setup

-Install `mlflow` with MLflow Deployments dependencies:
+Install `mlflow` with MLflow GenAI dependencies:

```sh
pip install 'mlflow[genai]'
```

@@ -39,10 +39,10 @@ endpoints:
      openai_api_key: $OPENAI_API_KEY
```

-Start the deployments server:
+Start the gateway server:

```sh
-mlflow deployments start-server --config-path /path/to/config.yaml
+mlflow gateway start --config-path /path/to/config.yaml
```
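Once the gateway is up, a minimal sketch of querying it from LangChain (assumes a locally running server on port 5000 and the `Mlflow` wrapper from `langchain_community`; the endpoint name mirrors the config above and is not part of this diff):

```python
from langchain_community.llms import Mlflow

# Point the wrapper at the locally running gateway and the "completions"
# endpoint defined in config.yaml (names assumed from the setup above).
llm = Mlflow(target_uri="http://127.0.0.1:5000", endpoint="completions")
print(llm.invoke("What is MLflow?"))
```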
## Example provided by `MLflow`
@@ -1,160 +0,0 @@
# MLflow AI Gateway

:::warning

MLflow AI Gateway has been deprecated. Please use [MLflow Deployments for LLMs](/docs/integrations/providers/mlflow/) instead.

:::

>[The MLflow AI Gateway](https://www.mlflow.org/docs/latest/index.html) service is a powerful tool designed to streamline the usage and management of various large
> language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface
> that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests.

## Installation and Setup

Install `mlflow` with MLflow AI Gateway dependencies:

```sh
pip install 'mlflow[gateway]'
```

Set the OpenAI API key as an environment variable:

```sh
export OPENAI_API_KEY=...
```

Create a configuration file:

```yaml
routes:
  - name: completions
    route_type: llm/v1/completions
    model:
      provider: openai
      name: text-davinci-003
      config:
        openai_api_key: $OPENAI_API_KEY

  - name: embeddings
    route_type: llm/v1/embeddings
    model:
      provider: openai
      name: text-embedding-ada-002
      config:
        openai_api_key: $OPENAI_API_KEY
```

Start the Gateway server:

```sh
mlflow gateway start --config-path /path/to/config.yaml
```

## Example provided by `MLflow`

>The `mlflow.langchain` module provides an API for logging and loading `LangChain` models.
> This module exports multivariate LangChain models in the langchain flavor and univariate LangChain
> models in the pyfunc flavor.

See the [API documentation and examples](https://www.mlflow.org/docs/latest/python_api/mlflow.langchain.html?highlight=langchain#module-mlflow.langchain).

## Completions Example

```python
import mlflow
from langchain.chains import LLMChain, PromptTemplate
from langchain_community.llms import MlflowAIGateway

gateway = MlflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",
    route="completions",
    params={
        "temperature": 0.0,
        "top_p": 0.1,
    },
)

llm_chain = LLMChain(
    llm=gateway,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
result = llm_chain.run(adjective="funny")
print(result)

with mlflow.start_run():
    model_info = mlflow.langchain.log_model(chain, "model")

model = mlflow.pyfunc.load_model(model_info.model_uri)
print(model.predict([{"adjective": "funny"}]))
```

## Embeddings Example

```python
from langchain_community.embeddings import MlflowAIGatewayEmbeddings

embeddings = MlflowAIGatewayEmbeddings(
    gateway_uri="http://127.0.0.1:5000",
    route="embeddings",
)

print(embeddings.embed_query("hello"))
print(embeddings.embed_documents(["hello"]))
```

## Chat Example

```python
from langchain_community.chat_models import ChatMLflowAIGateway
from langchain_core.messages import HumanMessage, SystemMessage

chat = ChatMLflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",
    route="chat",
    params={
        "temperature": 0.1
    }
)

messages = [
    SystemMessage(
        content="You are a helpful assistant that translates English to French."
    ),
    HumanMessage(
        content="Translate this sentence from English to French: I love programming."
    ),
]
print(chat(messages))
```

## Databricks MLflow AI Gateway

Databricks MLflow AI Gateway is in private preview.
Please contact a Databricks representative to enroll in the preview.

```python
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import MlflowAIGateway

gateway = MlflowAIGateway(
    gateway_uri="databricks",
    route="completions",
)

llm_chain = LLMChain(
    llm=gateway,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
result = llm_chain.run(adjective="funny")
print(result)
```
File diff suppressed because one or more lines are too long
@@ -400,18 +400,29 @@
    "def hybrid_query(search_query: str) -> Dict:\n",
    "    vector = embeddings.embed_query(search_query)  # same embeddings as for indexing\n",
    "    return {\n",
-    "        \"query\": {\n",
-    "            \"match\": {\n",
-    "                text_field: search_query,\n",
-    "            },\n",
-    "        },\n",
-    "        \"knn\": {\n",
-    "            \"field\": dense_vector_field,\n",
-    "            \"query_vector\": vector,\n",
-    "            \"k\": 5,\n",
-    "            \"num_candidates\": 10,\n",
-    "        },\n",
-    "        \"rank\": {\"rrf\": {}},\n",
+    "        \"retriever\": {\n",
+    "            \"rrf\": {\n",
+    "                \"retrievers\": [\n",
+    "                    {\n",
+    "                        \"standard\": {\n",
+    "                            \"query\": {\n",
+    "                                \"match\": {\n",
+    "                                    text_field: search_query,\n",
+    "                                }\n",
+    "                            }\n",
+    "                        }\n",
+    "                    },\n",
+    "                    {\n",
+    "                        \"knn\": {\n",
+    "                            \"field\": dense_vector_field,\n",
+    "                            \"query_vector\": vector,\n",
+    "                            \"k\": 5,\n",
+    "                            \"num_candidates\": 10,\n",
+    "                        }\n",
+    "                    },\n",
+    "                ]\n",
+    "            }\n",
+    "        }\n",
    "    }\n",
    "\n",
    "\n",
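For context, a sketch of how the updated `hybrid_query` body might be issued against a cluster (assumes an `Elasticsearch` client plus the `embeddings`, `text_field`, and `dense_vector_field` names defined earlier in the notebook; the index name is a placeholder, not part of the diff):

```python
from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")  # assumed local cluster
body = hybrid_query("What is LangChain?")
# The retriever-based RRF body above is passed as the search request body.
resp = client.search(index="my-index", body=body)  # "my-index" is a placeholder
for hit in resp["hits"]["hits"]:
    print(hit["_score"], hit["_source"].get(text_field))
```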
@@ -380,7 +380,7 @@
   "source": [
    "## API reference\n",
    "\n",
-    "For detailed documentation of all `AstraDBVectorStore` features and configurations head to the API reference:https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.clickhouse.Clickhouse.html"
+    "For detailed documentation of all `Clickhouse` features and configurations head to the API reference:https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.clickhouse.Clickhouse.html"
   ]
  }
 ],
@@ -90,7 +90,7 @@
   "source": [
    "  </TabItem>\n",
    "  <TabItem value=\"conda\" label=\"Conda\">\n",
-    "    <CodeBlock language=\"bash\">conda install langchain langchain_community langchain_chroma -c conda-forge</CodeBlock>\n",
+    "    <CodeBlock language=\"bash\">conda install langchain langchain-community langchain-chroma -c conda-forge</CodeBlock>\n",
    "  </TabItem>\n",
    "</Tabs>\n",
    "\n",
@@ -168,52 +168,43 @@ const config = {
          label: "Integrations",
        },
        {
-          type: "dropdown",
-          label: "API reference",
-          position: "left",
-          items: [
-            {
-              label: "Latest",
-              to: "https://python.langchain.com/api_reference/reference.html",
-            },
-            {
-              label: "Legacy",
-              href: "https://api.python.langchain.com/"
-            }
-          ]
+          label: "API Reference",
+          to: "https://python.langchain.com/api_reference/",
        },
        {
          type: "dropdown",
          label: "More",
          position: "left",
          items: [
            {
              type: "doc",
              docId: "people",
              label: "People",
            },
            {
              type: "doc",
              docId: "contributing/index",
              label: "Contributing",
            },
            {
              label: "Cookbooks",
              href: "https://github.com/langchain-ai/langchain/blob/master/cookbook/README.md"
            },
            {
              type: "doc",
-              docId: "additional_resources/tutorials",
-              label: "3rd party tutorials"
+              docId: "people",
+              label: "People",
            },
            {
              type: "doc",
-              docId: "additional_resources/youtube",
-              label: "YouTube"
+              type: 'html',
+              value: '<hr class="dropdown-separator" style="margin-top: 0.5rem; margin-bottom: 0.5rem">',
            },
            {
-              to: "/docs/additional_resources/arxiv_references",
-              label: "arXiv"
+              href: "https://docs.smith.langchain.com",
+              label: "LangSmith",
            },
            {
              href: "https://langchain-ai.github.io/langgraph/",
              label: "LangGraph",
            },
            {
              href: "https://smith.langchain.com/hub",
              label: "LangChain Hub",
            },
            {
              href: "https://js.langchain.com",
              label: "LangChain JS/TS",
            },
          ]
        },
@@ -237,30 +228,7 @@ const config = {
          ]
        },
-        {
-          type: "dropdown",
-          label: "🦜️🔗",
-          position: "right",
-          items: [
-            {
-              href: "https://smith.langchain.com",
-              label: "LangSmith",
-            },
-            {
-              href: "https://docs.smith.langchain.com/",
-              label: "LangSmith Docs",
-            },
-            {
-              href: "https://smith.langchain.com/hub",
-              label: "LangChain Hub",
-            },
-            {
-              href: "https://js.langchain.com",
-              label: "JS/TS Docs",
-            },
-          ]
-        },
        {
-          href: "https://chat.langchain.com",
+          to: "https://chat.langchain.com",
          label: "💬",
          position: "right",
        },
@@ -38,6 +38,9 @@
  --ifm-menu-link-padding-horizontal: 0.5rem;
  --ifm-menu-link-padding-vertical: 0.5rem;
  --doc-sidebar-width: 275px !important;
+
+  /* Code block syntax highlighting */
+  --docusaurus-highlighted-code-line-bg: rgb(176, 227, 199);
 }

 /* For readability concerns, you should choose a lighter palette in dark mode. */
@@ -49,6 +52,9 @@
  --ifm-color-primary-light: #29d5b0;
  --ifm-color-primary-lighter: #32d8b4;
  --ifm-color-primary-lightest: #4fddbf;
+
+  /* Code block syntax highlighting */
+  --docusaurus-highlighted-code-line-bg: rgb(14, 73, 60);
 }

 nav, h1, h2, h3, h4 {
@@ -354,7 +354,7 @@ const FEATURE_TABLES = {
    },
    {
        name: "Nomic",
-        link: "cohere",
+        link: "nomic",
        package: "langchain-nomic",
        apiLink: "https://python.langchain.com/api_reference/nomic/embeddings/langchain_nomic.embeddings.NomicEmbeddings.html"
    },
@@ -886,7 +886,7 @@ const FEATURE_TABLES = {
        apiLink: "https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.html_bs.BSHTMLLoader.html"
    },
    {
-        name: "UnstrucutredXMLLoader",
+        name: "UnstructuredXMLLoader",
        link: "xml",
        source: "XML files",
        apiLink: "https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.xml.UnstructuredXMLLoader.html"
@@ -26,6 +26,10 @@
    }
  ],
  "redirects": [
+    {
+      "source": "/v0.3/docs/:path(.*/?)*",
+      "destination": "/docs/:path*"
+    },
    {
      "source": "/docs/modules/agents/tools/custom_tools(/?)",
      "destination": "/docs/how_to/custom_tools/"
@@ -73,6 +77,10 @@
    {
      "source": "/v0.2/docs/templates/:path(.*/?)*",
      "destination": "https://github.com/langchain-ai/langchain/tree/master/templates/:path*"
    },
+    {
+      "source": "/docs/integrations/providers/mlflow_ai_gateway(/?)",
+      "destination": "/docs/integrations/providers/mlflow/"
+    }
  ]
}
@@ -15,7 +15,7 @@ LangChain Community contains third-party integrations that implement the base in

For full documentation see the [API reference](https://api.python.langchain.com/en/stable/community_api_reference.html).

-
+

## 📕 Releases & Versioning
@@ -8,6 +8,18 @@ from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration, LLMResult

 MODEL_COST_PER_1K_TOKENS = {
+    # OpenAI o1-preview input
+    "o1-preview": 0.015,
+    "o1-preview-2024-09-12": 0.015,
+    # OpenAI o1-preview output
+    "o1-preview-completion": 0.06,
+    "o1-preview-2024-09-12-completion": 0.06,
+    # OpenAI o1-mini input
+    "o1-mini": 0.003,
+    "o1-mini-2024-09-12": 0.003,
+    # OpenAI o1-mini output
+    "o1-mini-completion": 0.012,
+    "o1-mini-2024-09-12-completion": 0.012,
     # GPT-4o-mini input
     "gpt-4o-mini": 0.00015,
     "gpt-4o-mini-2024-07-18": 0.00015,
@@ -153,6 +165,7 @@ def standardize_model_name(
         model_name.startswith("gpt-4")
         or model_name.startswith("gpt-3.5")
         or model_name.startswith("gpt-35")
+        or model_name.startswith("o1-")
         or ("finetuned" in model_name and "legacy" not in model_name)
     ):
         return model_name + "-completion"
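For orientation, a minimal sketch of how these per-1K rates are applied; the helper name `estimate_cost` here is hypothetical, but it mirrors the table's convention that completion (output) prices live under a `-completion` suffix, which is exactly what the `standardize_model_name` change above now produces for `o1-*` models:

```python
# Hypothetical helper illustrating the table's pricing convention.
def estimate_cost(model_name: str, num_tokens: int, is_completion: bool = False) -> float:
    # Completion tokens are priced under the "-completion" key.
    key = model_name + "-completion" if is_completion else model_name
    return MODEL_COST_PER_1K_TOKENS[key] * (num_tokens / 1000)

# e.g. 1,000 prompt tokens + 500 completion tokens for o1-mini:
total = estimate_cost("o1-mini", 1000) + estimate_cost("o1-mini", 500, is_completion=True)
print(f"${total:.4f}")  # 0.003 + 0.006 = $0.0090
```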
@@ -53,13 +53,15 @@ class LLMThoughtLabeler:
     labeling logic.
     """

-    def get_initial_label(self) -> str:
+    @staticmethod
+    def get_initial_label() -> str:
         """Return the markdown label for a new LLMThought that doesn't have
         an associated tool yet.
         """
         return f"{THINKING_EMOJI} **Thinking...**"

-    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
+    @staticmethod
+    def get_tool_label(tool: ToolRecord, is_complete: bool) -> str:
         """Return the label for an LLMThought that has an associated
         tool.

@@ -91,13 +93,15 @@ class LLMThoughtLabeler:
         label = f"{emoji} **{name}:** {input}"
         return label

-    def get_history_label(self) -> str:
+    @staticmethod
+    def get_history_label() -> str:
         """Return a markdown label for the special 'history' container
         that contains overflow thoughts.
         """
         return f"{HISTORY_EMOJI} **History**"

-    def get_final_agent_thought_label(self) -> str:
+    @staticmethod
+    def get_final_agent_thought_label() -> str:
         """Return the markdown label for the agent's final thought -
         the "Now I have the answer" thought, that doesn't involve
         a tool.
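The practical effect of the `@staticmethod` change, in a small sketch: the default labels become readable without constructing a labeler, while instance-style calls (and subclass overrides) keep working unchanged:

```python
# Class-level access now works (the exact emoji comes from THINKING_EMOJI):
print(LLMThoughtLabeler.get_initial_label())
# Instance access is unchanged, so existing custom-labeler subclasses still behave the same:
print(LLMThoughtLabeler().get_history_label())
```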
@@ -359,6 +359,7 @@ if TYPE_CHECKING:
     )
     from langchain_community.document_loaders.pebblo import (
         PebbloSafeLoader,
+        PebbloTextLoader,
     )
     from langchain_community.document_loaders.polars_dataframe import (
         PolarsDataFrameLoader,
@@ -650,6 +651,7 @@ _module_lookup = {
     "PDFPlumberLoader": "langchain_community.document_loaders.pdf",
     "PagedPDFSplitter": "langchain_community.document_loaders.pdf",
     "PebbloSafeLoader": "langchain_community.document_loaders.pebblo",
+    "PebbloTextLoader": "langchain_community.document_loaders.pebblo",
     "PlaywrightURLLoader": "langchain_community.document_loaders.url_playwright",
     "PolarsDataFrameLoader": "langchain_community.document_loaders.polars_dataframe",
     "PsychicLoader": "langchain_community.document_loaders.psychic",
@@ -855,6 +857,7 @@ __all__ = [
     "PDFPlumberLoader",
     "PagedPDFSplitter",
     "PebbloSafeLoader",
+    "PebbloTextLoader",
     "PlaywrightURLLoader",
     "PolarsDataFrameLoader",
     "PsychicLoader",
@@ -267,6 +267,7 @@ class PyMuPDFParser(BaseBlobParser):

     def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
         """Lazily parse the blob."""
+
         import fitz

         with blob.as_bytes_io() as file_path:  # type: ignore[attr-defined]
@@ -277,25 +278,49 @@

             yield from [
                 Document(
-                    page_content=page.get_text(**self.text_kwargs)
-                    + self._extract_images_from_page(doc, page),
-                    metadata=dict(
-                        {
-                            "source": blob.source,  # type: ignore[attr-defined]
-                            "file_path": blob.source,  # type: ignore[attr-defined]
-                            "page": page.number,
-                            "total_pages": len(doc),
-                        },
-                        **{
-                            k: doc.metadata[k]
-                            for k in doc.metadata
-                            if type(doc.metadata[k]) in [str, int]
-                        },
-                    ),
+                    page_content=self._get_page_content(doc, page, blob),
+                    metadata=self._extract_metadata(doc, page, blob),
                 )
                 for page in doc
             ]

+    def _get_page_content(
+        self, doc: fitz.fitz.Document, page: fitz.fitz.Page, blob: Blob
+    ) -> str:
+        """
+        Get the text of the page using PyMuPDF and RapidOCR and issue a warning
+        if it is empty.
+        """
+        content = page.get_text(**self.text_kwargs) + self._extract_images_from_page(
+            doc, page
+        )
+
+        if not content:
+            warnings.warn(
+                f"Warning: Empty content on page "
+                f"{page.number} of document {blob.source}"
+            )
+
+        return content
+
+    def _extract_metadata(
+        self, doc: fitz.fitz.Document, page: fitz.fitz.Page, blob: Blob
+    ) -> dict:
+        """Extract metadata from the document and page."""
+        return dict(
+            {
+                "source": blob.source,  # type: ignore[attr-defined]
+                "file_path": blob.source,  # type: ignore[attr-defined]
+                "page": page.number,
+                "total_pages": len(doc),
+            },
+            **{
+                k: doc.metadata[k]
+                for k in doc.metadata
+                if isinstance(doc.metadata[k], (str, int))
+            },
+        )
+
     def _extract_images_from_page(
         self, doc: fitz.fitz.Document, page: fitz.fitz.Page
     ) -> str:
@@ -4,7 +4,7 @@ import logging
 import os
 import uuid
 from importlib.metadata import version
-from typing import Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterable, Iterator, List, Optional

 from langchain_core.documents import Document
@@ -271,3 +271,67 @@ class PebbloSafeLoader(BaseLoader):
             doc_metadata["pb_checksum"] = classified_docs.get(doc.pb_id, {}).get(
                 "pb_checksum", None
             )


class PebbloTextLoader(BaseLoader):
    """
    Loader for text data.

    Since PebbloSafeLoader is a wrapper around document loaders, this loader is
    used to load text data directly into Documents.
    """

    def __init__(
        self,
        texts: Iterable[str],
        *,
        source: Optional[str] = None,
        ids: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """
        Args:
            texts: Iterable of text data.
            source: Source of the text data.
                Optional. Defaults to None.
            ids: List of unique identifiers for each text.
                Optional. Defaults to None.
            metadata: Metadata for all texts.
                Optional. Defaults to None.
            metadatas: List of metadata for each text.
                Optional. Defaults to None.
        """
        self.texts = texts
        self.source = source
        self.ids = ids
        self.metadata = metadata
        self.metadatas = metadatas

    def lazy_load(self) -> Iterator[Document]:
        """
        Lazy load text data into Documents.

        Returns:
            Iterator of Documents
        """
        for i, text in enumerate(self.texts):
            _id = None
            metadata = self.metadata or {}
            if self.metadatas and i < len(self.metadatas) and self.metadatas[i]:
                metadata.update(self.metadatas[i])
            if self.ids and i < len(self.ids):
                _id = self.ids[i]
            yield Document(id=_id, page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """
        Load text data into Documents.

        Returns:
            List of Documents
        """
        documents = []
        for doc in self.lazy_load():
            documents.append(doc)
        return documents
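A quick usage sketch of the new loader, importing it from the package root as wired up in the `__init__` diff above (the texts and metadata values are illustrative):

```python
from langchain_community.document_loaders import PebbloTextLoader

# Load two in-memory strings as Documents, attaching shared metadata.
loader = PebbloTextLoader(
    texts=["first snippet", "second snippet"],
    source="in-memory",  # illustrative source label
    metadata={"origin": "example"},
)
docs = loader.load()
print(docs[0].page_content, docs[0].metadata)
```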
@@ -227,7 +227,7 @@ class RecursiveUrlLoader(BaseLoader):
                 "https://docs.python.org/3.9/",
                 prevent_outside=True,
                 base_url="https://docs.python.org",
-                link_regex=r'<a\s+(?:[^>]*?\s+)?href="([^"]*(?=index)[^"]*)"',
+                link_regex=r'<a\\s+(?:[^>]*?\\s+)?href="([^"]*(?=index)[^"]*)"',
                 exclude_dirs=['https://docs.python.org/3.9/faq']
             )
             docs = loader.load()
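The doubled backslashes matter only because this snippet lives inside a docstring, where `\s` would otherwise be mangled before the example is rendered; the pattern the reader should end up with is the single-backslash form. A sketch of that rendered pattern in isolation (illustrative, not part of the diff):

```python
import re

# The pattern as it should read once the docstring is rendered:
link_regex = r'<a\s+(?:[^>]*?\s+)?href="([^"]*(?=index)[^"]*)"'
print(re.findall(link_regex, '<a class="x" href="genindex.html">Index</a>'))
# ['genindex.html']
```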
@@ -132,6 +132,7 @@ class BeautifulSoupTransformer(BaseDocumentTransformer):
         Args:
             html_content: The original HTML content string.
             tags: A list of tags to be extracted from the HTML.
+            remove_comments: If set to True, the comments will be removed.

         Returns:
             A string combining the content of the extracted tags.
@@ -184,6 +185,7 @@ def get_navigable_strings(

     Args:
         element: A BeautifulSoup element.
+        remove_comments: If set to True, the comments will be removed.

     Returns:
         A generator of strings.
@@ -213,7 +213,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
             )
             try:
                 if params.get("select_expert"):
-                    embedding = response.json()["predictions"][0]
+                    embedding = response.json()["predictions"]
                 else:
                     embedding = response.json()["predictions"]
                 embeddings.extend(embedding)
@@ -299,7 +299,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
             )
             try:
                 if params.get("select_expert"):
-                    embedding = response.json()["predictions"][0][0]
+                    embedding = response.json()["predictions"][0]
                 else:
                     embedding = response.json()["predictions"][0]
             except KeyError:
@@ -1,7 +1,708 @@
-from langchain_core.graph_vectorstores.base import (
-    GraphVectorStore,
-    GraphVectorStoreRetriever,
-    Node,
-)
-
-__all__ = ["GraphVectorStore", "GraphVectorStoreRetriever", "Node"]
from __future__ import annotations

from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
    Any,
    ClassVar,
    Optional,
)

from langchain_core._api import beta
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from pydantic import Field

from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link


def _has_next(iterator: Iterator) -> bool:
    """Checks if the iterator has more elements.
    Warning: consumes an element from the iterator"""
    sentinel = object()
    return next(iterator, sentinel) is not sentinel


@beta()
class Node(Serializable):
    """Node in the GraphVectorStore.

    Edges exist from nodes with an outgoing link to nodes with a matching incoming link.

    For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
    would look like:

    .. code-block:: python

        [
            Node(
                id="a",
                text="some text a",
                links= [
                    Link(kind="hyperlink", tag="https://some-url", direction="incoming")
                ],
            ),
            Node(
                id="b",
                text="some text b",
                links= [
                    Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
                ],
            )
        ]
    """

    id: Optional[str] = None
    """Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
    text: str
    """Text contained by the node."""
    metadata: dict = Field(default_factory=dict)
    """Metadata for the node."""
    links: list[Link] = Field(default_factory=list)
    """Links associated with the node."""


def _texts_to_nodes(
    texts: Iterable[str],
    metadatas: Optional[Iterable[dict]],
    ids: Optional[Iterable[str]],
) -> Iterator[Node]:
    metadatas_it = iter(metadatas) if metadatas else None
    ids_it = iter(ids) if ids else None
    for text in texts:
        try:
            _metadata = next(metadatas_it).copy() if metadatas_it else {}
        except StopIteration as e:
            raise ValueError("texts iterable longer than metadatas") from e
        try:
            _id = next(ids_it) if ids_it else None
        except StopIteration as e:
            raise ValueError("texts iterable longer than ids") from e

        links = _metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(links, list):
            links = list(links)
        yield Node(
            id=_id,
            metadata=_metadata,
            text=text,
            links=links,
        )
    if ids_it and _has_next(ids_it):
        raise ValueError("ids iterable longer than texts")
    if metadatas_it and _has_next(metadatas_it):
        raise ValueError("metadatas iterable longer than texts")


def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
    for doc in documents:
        metadata = doc.metadata.copy()
        links = metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(links, list):
            links = list(links)
        yield Node(
            id=doc.id,
            metadata=metadata,
            text=doc.page_content,
            links=links,
        )


@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
    """Convert nodes to documents.

    Args:
        nodes: The nodes to convert to documents.
    Returns:
        The documents generated from the nodes.
    """
    for node in nodes:
        metadata = node.metadata.copy()
        metadata[METADATA_LINKS_KEY] = [
            # Convert the core `Link` (from the node) back to the local `Link`.
            Link(kind=link.kind, direction=link.direction, tag=link.tag)
            for link in node.links
        ]

        yield Document(
            id=node.id,
            page_content=node.text,
            metadata=metadata,
        )
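A small round-trip sketch of the helpers above (uses the `Link.incoming` constructor from `langchain_community.graph_vectorstores.links`, as in the docstrings in this module; the node values are illustrative):

```python
from langchain_community.graph_vectorstores.links import Link

# Build a node carrying one incoming hyperlink, then convert it back to a
# Document; the link survives under the METADATA_LINKS_KEY metadata field.
node = Node(
    id="a",
    text="some text a",
    links=[Link.incoming(kind="hyperlink", tag="https://some-url")],
)
doc = next(nodes_to_documents([node]))
print(doc.id, doc.page_content, doc.metadata[METADATA_LINKS_KEY])
```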
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
|
||||
class GraphVectorStore(VectorStore):
|
||||
"""A hybrid vector-and-graph graph store.
|
||||
|
||||
Document chunks support vector-similarity search as well as edges linking
|
||||
chunks based on structural and semantic properties.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_nodes(
|
||||
self,
|
||||
nodes: Iterable[Node],
|
||||
**kwargs: Any,
|
||||
) -> Iterable[str]:
|
||||
"""Add nodes to the graph store.
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
"""
|
||||
|
||||
async def aadd_nodes(
|
||||
self,
|
||||
nodes: Iterable[Node],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[str]:
|
||||
"""Add nodes to the graph store.
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
"""
|
||||
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_texts(
|
||||
ids=["a", "b"],
|
||||
texts=["some text a", "some text b"],
|
||||
metadatas=[
|
||||
{
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
{
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return list(self.add_nodes(nodes, **kwargs))
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
await store.aadd_texts(
|
||||
ids=["a", "b"],
|
||||
texts=["some text a", "some text b"],
|
||||
metadatas=[
|
||||
{
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
{
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_documents(
|
||||
[
|
||||
Document(
|
||||
id="a",
|
||||
page_content="some text a",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
Document(
|
||||
id="b",
|
||||
page_content="some text b",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
|
||||
Returns:
|
||||
List of IDs of the added texts.
|
||||
"""
|
||||
nodes = _documents_to_nodes(documents)
|
||||
return list(self.add_nodes(nodes, **kwargs))
|
||||
|
||||
async def aadd_documents(
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_documents(
|
||||
[
|
||||
Document(
|
||||
id="a",
|
||||
page_content="some text a",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
Document(
|
||||
id="b",
|
||||
page_content="some text b",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
|
||||
Returns:
|
||||
List of IDs of the added texts.
|
||||
"""
|
||||
nodes = _documents_to_nodes(documents)
|
||||
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
|
||||
|
||||
@abstractmethod
|
||||
def traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
|
||||
First, `k` nodes are retrieved using a search for each `query` string.
|
||||
Then, additional nodes are discovered up to the given `depth` from those
|
||||
starting nodes.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
"""
|
||||
|
||||
async def atraversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
|
||||
First, `k` nodes are retrieved using a search for each `query` string.
|
||||
Then, additional nodes are discovered up to the given `depth` from those
|
||||
starting nodes.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None, self.traversal_search, query, k=k, depth=depth, **kwargs
|
||||
)
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
||||
@abstractmethod
|
||||
def mmr_traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
|
||||
This strategy first retrieves the top `fetch_k` results by similarity to
|
||||
the question. It then selects the top `k` results based on
|
||||
maximum-marginal relevance using the given `lambda_mult`.
|
||||
|
||||
At each step, it considers the (remaining) documents from `fetch_k` as
|
||||
well as any documents connected by edges to a selected document
|
||||
retrieved based on similarity (a "root").
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
adjacent_k: Number of adjacent Documents to fetch.
|
||||
Defaults to 10.
|
||||
depth: Maximum depth of a node (number of edges) from a node
|
||||
retrieved via similarity. Defaults to 2.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding to maximum
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
"""
|
||||
|
||||
async def ammr_traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
|
||||
This strategy first retrieves the top `fetch_k` results by similarity to
|
||||
the question. It then selects the top `k` results based on
|
||||
maximum-marginal relevance using the given `lambda_mult`.
|
||||
|
||||
At each step, it considers the (remaining) documents from `fetch_k` as
|
||||
well as any documents connected by edges to a selected document
|
||||
retrieved based on similarity (a "root").
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
adjacent_k: Number of adjacent Documents to fetch.
|
||||
Defaults to 10.
|
||||
depth: Maximum depth of a node (number of edges) from a node
|
||||
retrieved via similarity. Defaults to 2.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding to maximum
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None,
|
||||
self.mmr_traversal_search,
|
||||
query,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
adjacent_k=adjacent_k,
|
||||
depth=depth,
|
||||
lambda_mult=lambda_mult,
|
||||
score_threshold=score_threshold,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        return list(self.traversal_search(query, k=k, depth=0))

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        return list(
            self.mmr_traversal_search(
                query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
            )
        )

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]

    def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
        if search_type == "similarity":
            return self.similarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = self.similarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return self.max_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return list(self.traversal_search(query, **kwargs))
        elif search_type == "mmr_traversal":
            return list(self.mmr_traversal_search(query, **kwargs))
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr', 'traversal' or 'mmr_traversal'."
            )

    async def asearch(
        self, query: str, search_type: str, **kwargs: Any
    ) -> list[Document]:
        if search_type == "similarity":
            return await self.asimilarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return await self.amax_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return [doc async for doc in self.atraversal_search(query, **kwargs)]
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr' or 'traversal'."
            )

    def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
        """Return GraphVectorStoreRetriever initialized from this GraphVectorStore.

        Args:
            **kwargs: Keyword arguments to pass to the search function.
                Can include:

                - search_type (Optional[str]): Defines the type of search that
                  the Retriever should perform.
                  Can be ``traversal`` (default), ``similarity``, ``mmr``,
                  ``mmr_traversal``, or ``similarity_score_threshold``.
                - search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                  search function. Can include things like:

                  - k(int): Amount of documents to return (Default: 4).
                  - depth(int): The maximum depth of edges to traverse (Default: 1).
                  - score_threshold(float): Minimum relevance threshold
                    for similarity_score_threshold.
                  - fetch_k(int): Amount of documents to pass to MMR algorithm
                    (Default: 20).
                  - lambda_mult(float): Diversity of results returned by MMR;
                    1 for minimum diversity and 0 for maximum. (Default: 0.5).

        Returns:
            Retriever for this GraphVectorStore.

        Examples:

        .. code-block:: python

            # Retrieve documents traversing edges
            docsearch.as_retriever(
                search_type="traversal",
                search_kwargs={'k': 6, 'depth': 3}
            )

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

        """
        return GraphVectorStoreRetriever(vectorstore=self, **kwargs)


class GraphVectorStoreRetriever(VectorStoreRetriever):
    """Retriever class for GraphVectorStore."""

    vectorstore: GraphVectorStore
    """GraphVectorStore to use for retrieval."""
    search_type: str = "traversal"
    """Type of search to perform. Defaults to "traversal"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "traversal",
        "mmr_traversal",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        if self.search_type == "traversal":
            return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
        elif self.search_type == "mmr_traversal":
            return list(
                self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
            )
        else:
            return super()._get_relevant_documents(query, run_manager=run_manager)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        if self.search_type == "traversal":
            return [
                doc
                async for doc in self.vectorstore.atraversal_search(
                    query, **self.search_kwargs
                )
            ]
        elif self.search_type == "mmr_traversal":
            return [
                doc
                async for doc in self.vectorstore.ammr_traversal_search(
                    query, **self.search_kwargs
                )
            ]
        else:
            return await super()._aget_relevant_documents(
                query, run_manager=run_manager
            )
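
Taken together, the pieces above wire graph-aware search into the standard retriever interface. A minimal usage sketch (hedged: it assumes a cassio-initialized Cassandra cluster and some Embeddings implementation; the names and query text are illustrative, not from this diff)::

    from langchain_community.graph_vectorstores import CassandraGraphVectorStore
    from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

    # Assumes cassio.init(...) has already configured the cluster/keyspace.
    store = CassandraGraphVectorStore(embedding=OpenAIEmbeddings())

    # "mmr_traversal" = similarity seeding + edge traversal + MMR re-ranking.
    retriever = store.as_retriever(
        search_type="mmr_traversal",
        search_kwargs={"k": 4, "fetch_k": 100, "adjacent_k": 10, "depth": 2},
    )
    docs = retriever.invoke("What did the president say about Ukraine?")
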
@@ -12,12 +12,12 @@ from typing import (
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.graph_vectorstores.base import (
+from langchain_community.graph_vectorstores.base import (
     GraphVectorStore,
     Node,
     nodes_to_documents,
 )
 from langchain_community.utilities.cassandra import SetupMode
 
 if TYPE_CHECKING:
@@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.links import Link
 
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
+from langchain_community.graph_vectorstores.links import Link
 
 # TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
 GLiNERInput = Union[str, Document]
@@ -34,7 +34,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
     .. seealso::
 
         - :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
-        - :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
+        - :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
 
     How to link Documents on common named entities
     ==============================================
@@ -59,12 +59,12 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
 
     We can use :meth:`extract_one` on a document to get the links and add the links
     to the document metadata with
-    :meth:`~langchain_core.graph_vectorstores.links.add_links`::
+    :meth:`~langchain_community.graph_vectorstores.links.add_links`::
 
         from langchain_community.document_loaders import TextLoader
         from langchain_community.graph_vectorstores import CassandraGraphVectorStore
         from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
-        from langchain_core.graph_vectorstores.links import add_links
+        from langchain_community.graph_vectorstores.links import add_links
         from langchain_text_splitters import CharacterTextSplitter
 
         loader = TextLoader("state_of_the_union.txt")
@@ -87,7 +87,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
     Using LinkExtractorTransformer
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-    Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
+    Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
     we can simplify the link extraction::
 
         from langchain_community.document_loaders import TextLoader
@@ -113,7 +113,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
 
         {'source': 'state_of_the_union.txt', 'links': [Link(kind='entity:Person', direction='bidir', tag='President Zelenskyy'), Link(kind='entity:Person', direction='bidir', tag='Vladimir Putin')]}
 
-    The documents with named entity links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
+    The documents with named entity links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
 
         from langchain_community.graph_vectorstores import CassandraGraphVectorStore
@@ -2,7 +2,6 @@ from typing import Callable, List, Set
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.links import Link
 
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
@@ -10,6 +9,7 @@ from langchain_community.graph_vectorstores.extractors.link_extractor import (
 from langchain_community.graph_vectorstores.extractors.link_extractor_adapter import (
     LinkExtractorAdapter,
 )
+from langchain_community.graph_vectorstores.links import Link
 
 # TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
 HierarchyInput = List[str]
@@ -6,8 +6,8 @@ from urllib.parse import urldefrag, urljoin, urlparse
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
-from langchain_core.graph_vectorstores import Link
 
+from langchain_community.graph_vectorstores import Link
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
@@ -77,7 +77,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
     .. seealso::
 
         - :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
-        - :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
+        - :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
 
     How to link Documents on hyperlinks in HTML
     ===========================================
@@ -103,7 +103,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
 
     We can use :meth:`extract_one` on a document to get the links and add the links
     to the document metadata with
-    :meth:`~langchain_core.graph_vectorstores.links.add_links`::
+    :meth:`~langchain_community.graph_vectorstores.links.add_links`::
 
         from langchain_community.document_loaders import AsyncHtmlLoader
         from langchain_community.graph_vectorstores.extractors import (
@@ -148,7 +148,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
 
         from langchain_community.document_loaders import AsyncHtmlLoader
         from langchain_community.graph_vectorstores.extractors import HtmlLinkExtractor
-        from langchain_core.graph_vectorstores.links import add_links
+        from langchain_community.graph_vectorstores.links import add_links
 
         loader = AsyncHtmlLoader(
             [
@@ -176,7 +176,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
     Using LinkExtractorTransformer
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-    Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
+    Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
     we can simplify the link extraction::
 
         from langchain_community.document_loaders import AsyncHtmlLoader
@@ -227,7 +227,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
 
         Found link from https://python.langchain.com/v0.2/docs/integrations/providers/astradb/ to https://docs.datastax.com/en/astra/home/astra.html.
 
-    The documents with URL links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
+    The documents with URL links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
 
         from langchain_community.graph_vectorstores import CassandraGraphVectorStore
@@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, Optional, Set, Union
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.links import Link
 
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
+from langchain_community.graph_vectorstores.links import Link
 
 KeybertInput = Union[str, Document]
 
@@ -37,7 +37,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
     .. seealso::
 
         - :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
-        - :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
+        - :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
 
     How to link Documents on common keywords using Keybert
     ======================================================
@@ -62,12 +62,12 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
 
     We can use :meth:`extract_one` on a document to get the links and add the links
     to the document metadata with
-    :meth:`~langchain_core.graph_vectorstores.links.add_links`::
+    :meth:`~langchain_community.graph_vectorstores.links.add_links`::
 
         from langchain_community.document_loaders import TextLoader
         from langchain_community.graph_vectorstores import CassandraGraphVectorStore
         from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
-        from langchain_core.graph_vectorstores.links import add_links
+        from langchain_community.graph_vectorstores.links import add_links
         from langchain_text_splitters import CharacterTextSplitter
 
         loader = TextLoader("state_of_the_union.txt")
@@ -91,7 +91,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
     Using LinkExtractorTransformer
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-    Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
+    Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
     we can simplify the link extraction::
 
         from langchain_community.document_loaders import TextLoader
@@ -116,7 +116,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
 
         {'source': 'state_of_the_union.txt', 'links': [Link(kind='kw', direction='bidir', tag='ukraine'), Link(kind='kw', direction='bidir', tag='ukrainian'), Link(kind='kw', direction='bidir', tag='putin'), Link(kind='kw', direction='bidir', tag='vladimir'), Link(kind='kw', direction='bidir', tag='russia')]}
 
-    The documents with keyword links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
+    The documents with keyword links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
 
         from langchain_community.graph_vectorstores import CassandraGraphVectorStore
@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
 from typing import Generic, Iterable, Set, TypeVar
 
 from langchain_core._api import beta
-from langchain_core.graph_vectorstores import Link
+
+from langchain_community.graph_vectorstores import Link
 
 InputT = TypeVar("InputT")


@@ -1,8 +1,8 @@
 from typing import Callable, Iterable, Set, TypeVar
 
 from langchain_core._api import beta
-from langchain_core.graph_vectorstores import Link
 
+from langchain_community.graph_vectorstores import Link
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
@@ -3,11 +3,11 @@ from typing import Any, Sequence
 from langchain_core._api import beta
 from langchain_core.documents import Document
 from langchain_core.documents.transformers import BaseDocumentTransformer
-from langchain_core.graph_vectorstores.links import copy_with_links
 
 from langchain_community.graph_vectorstores.extractors.link_extractor import (
     LinkExtractor,
 )
+from langchain_community.graph_vectorstores.links import copy_with_links
 
 
 @beta()
@@ -1,8 +1,102 @@
-from langchain_core.graph_vectorstores.links import (
-    Link,
-    add_links,
-    copy_with_links,
-    get_links,
-)
-
-__all__ = ["Link", "add_links", "get_links", "copy_with_links"]
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Literal, Union
+
+from langchain_core._api import beta
+from langchain_core.documents import Document
+
+
+@beta()
+@dataclass(frozen=True)
+class Link:
+    """A link to/from a tag of a given kind.
+
+    Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
+    """
+
+    kind: str
+    """The kind of link. Allows different extractors to use the same tag name without
+    creating collisions between extractors. For example “keyword” vs “url”."""
+    direction: Literal["in", "out", "bidir"]
+    """The direction of the link."""
+    tag: str
+    """The tag of the link."""
+
+    @staticmethod
+    def incoming(kind: str, tag: str) -> "Link":
+        """Create an incoming link."""
+        return Link(kind=kind, direction="in", tag=tag)
+
+    @staticmethod
+    def outgoing(kind: str, tag: str) -> "Link":
+        """Create an outgoing link."""
+        return Link(kind=kind, direction="out", tag=tag)
+
+    @staticmethod
+    def bidir(kind: str, tag: str) -> "Link":
+        """Create a bidirectional link."""
+        return Link(kind=kind, direction="bidir", tag=tag)
+
+
+METADATA_LINKS_KEY = "links"
+
+
+@beta()
+def get_links(doc: Document) -> list[Link]:
+    """Get the links from a document.
+
+    Args:
+        doc: The document to get the link tags from.
+    Returns:
+        The list of links from the document.
+    """
+    links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
+    if not isinstance(links, list):
+        # Convert to a list and remember that.
+        links = list(links)
+        doc.metadata[METADATA_LINKS_KEY] = links
+    return links
+
+
+@beta()
+def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
+    """Add links to the given metadata.
+
+    Args:
+        doc: The document to add the links to.
+        *links: The links to add to the document.
+    """
+    links_in_metadata = get_links(doc)
+    for link in links:
+        if isinstance(link, Iterable):
+            links_in_metadata.extend(link)
+        else:
+            links_in_metadata.append(link)
+
+
+@beta()
+def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
+    """Return a document with the given links added.
+
+    Args:
+        doc: The document to add the links to.
+        *links: The links to add to the document.
+
+    Returns:
+        A document with a shallow-copy of the metadata with the links added.
+    """
+    new_links = set(get_links(doc))
+    for link in links:
+        if isinstance(link, Iterable):
+            new_links.update(link)
+        else:
+            new_links.add(link)
+
+    return Document(
+        page_content=doc.page_content,
+        metadata={
+            **doc.metadata,
+            METADATA_LINKS_KEY: list(new_links),
+        },
+    )
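
The module above is self-contained, so its intended use can be shown directly; a short sketch (the printed output comment reflects the list accumulated in ``doc.metadata["links"]``)::

    from langchain_core.documents import Document
    from langchain_community.graph_vectorstores.links import Link, add_links, get_links

    doc = Document(page_content="LangChain documentation home")
    # One outgoing "hyperlink" edge, plus two bidirectional keyword edges.
    add_links(doc, Link.outgoing("hyperlink", "https://python.langchain.com"))
    add_links(doc, [Link.bidir("kw", "langchain"), Link.bidir("kw", "llm")])

    print(get_links(doc))
    # [Link(kind='hyperlink', direction='out', tag='https://python.langchain.com'),
    #  Link(kind='kw', direction='bidir', tag='langchain'),
    #  Link(kind='kw', direction='bidir', tag='llm')]
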
@@ -411,7 +411,9 @@ class Neo4jGraph(GraphStore):
         return self.structured_schema
 
     def query(
-        self, query: str, params: dict = {}, retry_on_session_expired: bool = True
+        self,
+        query: str,
+        params: dict = {},
     ) -> List[Dict[str, Any]]:
         """Query Neo4j database.
 
@@ -423,26 +425,44 @@ class Neo4jGraph(GraphStore):
             List[Dict[str, Any]]: The list of dictionaries containing the query results.
         """
         from neo4j import Query
-        from neo4j.exceptions import CypherSyntaxError, SessionExpired
+        from neo4j.exceptions import Neo4jError
 
-        with self._driver.session(database=self._database) as session:
-            try:
-                data = session.run(Query(text=query, timeout=self.timeout), params)
-                json_data = [r.data() for r in data]
-                if self.sanitize:
-                    json_data = [value_sanitize(el) for el in json_data]
-                return json_data
-            except CypherSyntaxError as e:
-                raise ValueError(f"Generated Cypher Statement is not valid\n{e}")
-            except (
-                SessionExpired
-            ) as e:  # Session expired is a transient error that can be retried
-                if retry_on_session_expired:
-                    return self.query(
-                        query, params=params, retry_on_session_expired=False
-                    )
-                else:
-                    raise e
+        try:
+            data, _, _ = self._driver.execute_query(
+                Query(text=query, timeout=self.timeout),
+                database=self._database,
+                parameters_=params,
+            )
+            json_data = [r.data() for r in data]
+            if self.sanitize:
+                json_data = [value_sanitize(el) for el in json_data]
+            return json_data
+        except Neo4jError as e:
+            if not (
+                (
+                    (  # isCallInTransactionError
+                        e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
+                        or e.code
+                        == "Neo.DatabaseError.Transaction.TransactionStartFailed"
+                    )
+                    and "in an implicit transaction" in e.message
+                )
+                or (  # isPeriodicCommitError
+                    e.code == "Neo.ClientError.Statement.SemanticError"
+                    and (
+                        "in an open transaction is not possible" in e.message
+                        or "tried to execute in an explicit transaction" in e.message
+                    )
+                )
+            ):
+                raise
+            # fallback to allow implicit transactions
+            with self._driver.session() as session:
+                data = session.run(Query(text=query, timeout=self.timeout), params)
+                json_data = [r.data() for r in data]
+                if self.sanitize:
+                    json_data = [value_sanitize(el) for el in json_data]
+                return json_data
 
     def refresh_schema(self) -> None:
         """
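
To make the new control flow concrete, a hedged usage sketch (connection details are placeholders): ordinary statements run through the driver's managed ``execute_query``, while statements that refuse to run inside a managed transaction raise one of the error codes checked above and are re-executed in an implicit (auto-commit) session::

    from langchain_community.graphs import Neo4jGraph

    graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")

    # Ordinary reads/writes go through driver.execute_query().
    rows = graph.query("MATCH (n) RETURN count(n) AS node_count")

    # `CALL { ... } IN TRANSACTIONS` cannot run inside a managed transaction;
    # the resulting Neo4jError matches the codes checked above, so query()
    # re-runs the statement in an implicit transaction via a plain session.
    graph.query(
        "MATCH (n:Stale) CALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF 1000 ROWS"
    )
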
@@ -510,12 +510,6 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
     return SagemakerEndpoint
 
 
-def _import_sambaverse() -> Type[BaseLLM]:
-    from langchain_community.llms.sambanova import Sambaverse
-
-    return Sambaverse
-
-
 def _import_sambastudio() -> Type[BaseLLM]:
     from langchain_community.llms.sambanova import SambaStudio
 
@@ -817,8 +811,6 @@ def __getattr__(name: str) -> Any:
         return _import_rwkv()
     elif name == "SagemakerEndpoint":
         return _import_sagemaker_endpoint()
-    elif name == "Sambaverse":
-        return _import_sambaverse()
     elif name == "SambaStudio":
         return _import_sambastudio()
     elif name == "SelfHostedPipeline":
@@ -954,7 +946,6 @@ __all__ = [
     "RWKV",
     "Replicate",
     "SagemakerEndpoint",
-    "Sambaverse",
     "SambaStudio",
     "SelfHostedHuggingFaceLLM",
     "SelfHostedPipeline",
@@ -1051,7 +1042,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
     "replicate": _import_replicate,
     "rwkv": _import_rwkv,
     "sagemaker_endpoint": _import_sagemaker_endpoint,
-    "sambaverse": _import_sambaverse,
     "sambastudio": _import_sambastudio,
     "self_hosted": _import_self_hosted,
     "self_hosted_hugging_face": _import_self_hosted_hugging_face,
@@ -9,464 +9,6 @@ from langchain_core.utils import get_from_dict_or_env, pre_init
 from pydantic import ConfigDict
 
 
-class SVEndpointHandler:
-    """
-    SambaNova Systems Interface for Sambaverse endpoint.
-
-    :param str host_url: Base URL of the DaaS API service
-    """
-
-    API_BASE_PATH: str = "/api/predict"
-
-    def __init__(self, host_url: str):
-        """
-        Initialize the SVEndpointHandler.
-
-        :param str host_url: Base URL of the DaaS API service
-        """
-        self.host_url = host_url
-        self.http_session = requests.Session()
-
-    @staticmethod
-    def _process_response(response: requests.Response) -> Dict:
-        """
-        Processes the API response and returns the resulting dict.
-
-        All resulting dicts, regardless of success or failure, will contain the
-        `status_code` key with the API response status code.
-
-        If the API returned an error, the resulting dict will contain the key
-        `detail` with the error message.
-
-        If the API call was successful, the resulting dict will contain the key
-        `data` with the response data.
-
-        :param requests.Response response: the response object to process
-        :return: the response dict
-        :type: dict
-        """
-        result: Dict[str, Any] = {}
-        try:
-            lines_result = response.text.strip().split("\n")
-            text_result = lines_result[-1]
-            if response.status_code == 200 and json.loads(text_result).get("error"):
-                completion = ""
-                for line in lines_result[:-1]:
-                    completion += json.loads(line)["result"]["responses"][0][
-                        "stream_token"
-                    ]
-                text_result = lines_result[-2]
-                result = json.loads(text_result)
-                result["result"]["responses"][0]["completion"] = completion
-            else:
-                result = json.loads(text_result)
-        except Exception as e:
-            result["detail"] = str(e)
-        if "status_code" not in result:
-            result["status_code"] = response.status_code
-        return result
-
-    @staticmethod
-    def _process_streaming_response(
-        response: requests.Response,
-    ) -> Generator[Dict, None, None]:
-        """Process the streaming response"""
-        try:
-            for line in response.iter_lines():
-                chunk = json.loads(line)
-                if "status_code" not in chunk:
-                    chunk["status_code"] = response.status_code
-                if chunk["status_code"] == 200 and chunk.get("error"):
-                    chunk["result"] = {"responses": [{"stream_token": ""}]}
-                    return chunk
-                yield chunk
-        except Exception as e:
-            raise RuntimeError(f"Error processing streaming response: {e}")
-
-    def _get_full_url(self) -> str:
-        """
-        Return the full API URL for a given path.
-        :returns: the full API URL for the sub-path
-        :type: str
-        """
-        return f"{self.host_url}{self.API_BASE_PATH}"
-
-    def nlp_predict(
-        self,
-        key: str,
-        sambaverse_model_name: Optional[str],
-        input: Union[List[str], str],
-        params: Optional[str] = "",
-        stream: bool = False,
-    ) -> Dict:
-        """
-        NLP predict using inline input string.
-
-        :param str project: Project ID in which the endpoint exists
-        :param str endpoint: Endpoint ID
-        :param str key: API Key
-        :param str input_str: Input string
-        :param str params: Input params string
-        :returns: Prediction results
-        :type: dict
-        """
-        if params:
-            data = {"instance": input, "params": json.loads(params)}
-        else:
-            data = {"instance": input}
-        response = self.http_session.post(
-            self._get_full_url(),
-            headers={
-                "key": key,
-                "Content-Type": "application/json",
-                "modelName": sambaverse_model_name,
-            },
-            json=data,
-        )
-        return SVEndpointHandler._process_response(response)
-
-    def nlp_predict_stream(
-        self,
-        key: str,
-        sambaverse_model_name: Optional[str],
-        input: Union[List[str], str],
-        params: Optional[str] = "",
-    ) -> Iterator[Dict]:
-        """
-        NLP predict using inline input string.
-
-        :param str project: Project ID in which the endpoint exists
-        :param str endpoint: Endpoint ID
-        :param str key: API Key
-        :param str input_str: Input string
-        :param str params: Input params string
-        :returns: Prediction results
-        :type: dict
-        """
-        if params:
-            data = {"instance": input, "params": json.loads(params)}
-        else:
-            data = {"instance": input}
-        # Streaming output
-        response = self.http_session.post(
-            self._get_full_url(),
-            headers={
-                "key": key,
-                "Content-Type": "application/json",
-                "modelName": sambaverse_model_name,
-            },
-            json=data,
-            stream=True,
-        )
-        for chunk in SVEndpointHandler._process_streaming_response(response):
-            yield chunk
-
-
-class Sambaverse(LLM):
-    """
-    Sambaverse large language models.
-
-    To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
-    set with your API key.
-
-    Get one at https://sambaverse.sambanova.ai; extra documentation is at
-    https://docs.sambanova.ai/sambaverse/latest/index.html
-
-    Example:
-        .. code-block:: python
-
-            from langchain_community.llms.sambanova import Sambaverse
-
-            Sambaverse(
-                sambaverse_url="https://sambaverse.sambanova.ai",
-                sambaverse_api_key="your-sambaverse-api-key",
-                sambaverse_model_name="Meta/llama-2-7b-chat-hf",
-                streaming=False,
-                model_kwargs={
-                    "select_expert": "llama-2-7b-chat-hf",
-                    "do_sample": False,
-                    "max_tokens_to_generate": 100,
-                    "temperature": 0.7,
-                    "top_p": 1.0,
-                    "repetition_penalty": 1.0,
-                    "top_k": 50,
-                    "process_prompt": False
-                },
-            )
-    """
-
-    sambaverse_url: str = ""
-    """Sambaverse url to use"""
-
-    sambaverse_api_key: str = ""
-    """sambaverse api key"""
-
-    sambaverse_model_name: Optional[str] = None
-    """sambaverse expert model to use"""
-
-    model_kwargs: Optional[dict] = None
-    """Key word arguments to pass to the model."""
-
-    streaming: Optional[bool] = False
-    """Streaming flag to get streamed response."""
-
-    model_config = ConfigDict(
-        extra="forbid",
-    )
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return True
-
-    @pre_init
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key exists in environment."""
-        values["sambaverse_url"] = get_from_dict_or_env(
-            values,
-            "sambaverse_url",
-            "SAMBAVERSE_URL",
-            default="https://sambaverse.sambanova.ai",
-        )
-        values["sambaverse_api_key"] = get_from_dict_or_env(
-            values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
-        )
-        values["sambaverse_model_name"] = get_from_dict_or_env(
-            values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
-        )
-        return values
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {**{"model_kwargs": self.model_kwargs}}
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "Sambaverse LLM"
-
-    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
-        """
-        Get the tuning parameters to use when calling the LLM.
-
-        Args:
-            stop: Stop words to use when generating. Model output is cut off at the
-                first occurrence of any of the stop substrings.
-
-        Returns:
-            The tuning parameters as a JSON string.
-        """
-        _model_kwargs = self.model_kwargs or {}
-        _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
-        _stop_sequences = stop or _kwarg_stop_sequences
-        if not _kwarg_stop_sequences:
-            _model_kwargs["stop_sequences"] = ",".join(
-                f'"{x}"' for x in _stop_sequences
-            )
-        tuning_params_dict = {
-            k: {"type": type(v).__name__, "value": str(v)}
-            for k, v in (_model_kwargs.items())
-        }
-        _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
-        tuning_params = json.dumps(tuning_params_dict)
-        return tuning_params
-
-    def _handle_nlp_predict(
-        self,
-        sdk: SVEndpointHandler,
-        prompt: Union[List[str], str],
-        tuning_params: str,
-    ) -> str:
-        """
-        Perform an NLP prediction using the Sambaverse endpoint handler.
-
-        Args:
-            sdk: The SVEndpointHandler to use for the prediction.
-            prompt: The prompt to use for the prediction.
-            tuning_params: The tuning parameters to use for the prediction.
-
-        Returns:
-            The prediction result.
-
-        Raises:
-            ValueError: If the prediction fails.
-        """
-        response = sdk.nlp_predict(
-            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
-        )
-        if response["status_code"] != 200:
-            error = response.get("error")
-            if error:
-                optional_code = error.get("code")
-                optional_details = error.get("details")
-                optional_message = error.get("message")
-                raise RuntimeError(
-                    f"Sambanova /complete call failed with status code "
-                    f"{response['status_code']}.\n"
-                    f"Message: {optional_message}\n"
-                    f"Details: {optional_details}\n"
-                    f"Code: {optional_code}\n"
-                )
-            else:
-                raise RuntimeError(
-                    f"Sambanova /complete call failed with status code "
-                    f"{response['status_code']}."
-                    f"{response}."
-                )
-        return response["result"]["responses"][0]["completion"]
-
-    def _handle_completion_requests(
-        self, prompt: Union[List[str], str], stop: Optional[List[str]]
-    ) -> str:
-        """
-        Perform a prediction using the Sambaverse endpoint handler.
-
-        Args:
-            prompt: The prompt to use for the prediction.
-            stop: stop sequences.
-
-        Returns:
-            The prediction result.
-
-        Raises:
-            ValueError: If the prediction fails.
-        """
-        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
-        tuning_params = self._get_tuning_params(stop)
-        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
-
-    def _handle_nlp_predict_stream(
-        self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str
-    ) -> Iterator[GenerationChunk]:
-        """
-        Perform a streaming request to the LLM.
-
-        Args:
-            sdk: The SVEndpointHandler to use for the prediction.
-            prompt: The prompt to use for the prediction.
-            tuning_params: The tuning parameters to use for the prediction.
-
-        Returns:
-            An iterator of GenerationChunks.
-        """
-        for chunk in sdk.nlp_predict_stream(
-            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
-        ):
-            if chunk["status_code"] != 200:
-                error = chunk.get("error")
-                if error:
-                    optional_code = error.get("code")
-                    optional_details = error.get("details")
-                    optional_message = error.get("message")
-                    raise ValueError(
-                        f"Sambanova /complete call failed with status code "
-                        f"{chunk['status_code']}.\n"
-                        f"Message: {optional_message}\n"
-                        f"Details: {optional_details}\n"
-                        f"Code: {optional_code}\n"
-                    )
-                else:
-                    raise RuntimeError(
-                        f"Sambanova /complete call failed with status code "
-                        f"{chunk['status_code']}."
-                        f"{chunk}."
-                    )
-            text = chunk["result"]["responses"][0]["stream_token"]
-            generated_chunk = GenerationChunk(text=text)
-            yield generated_chunk
-
-    def _stream(
-        self,
-        prompt: Union[List[str], str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        """Stream the Sambaverse LLM on the given prompt.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-            run_manager: Callback manager for the run.
-            kwargs: Additional keyword arguments, directly passed
-                to the sambaverse model in the API call.
-
-        Returns:
-            An iterator of GenerationChunks.
-        """
-        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
-        tuning_params = self._get_tuning_params(stop)
-        try:
-            if self.streaming:
-                for chunk in self._handle_nlp_predict_stream(
-                    ss_endpoint, prompt, tuning_params
-                ):
-                    if run_manager:
-                        run_manager.on_llm_new_token(chunk.text)
-                    yield chunk
-            else:
-                return
-        except Exception as e:
-            # Handle any errors raised by the inference endpoint
-            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
-
-    def _handle_stream_request(
-        self,
-        prompt: Union[List[str], str],
-        stop: Optional[List[str]],
-        run_manager: Optional[CallbackManagerForLLMRun],
-        kwargs: Dict[str, Any],
-    ) -> str:
-        """
-        Perform a streaming request to the LLM.
-
-        Args:
-            prompt: The prompt to generate from.
-            stop: Stop words to use when generating. Model output is cut off at the
-                first occurrence of any of the stop substrings.
-            run_manager: Callback manager for the run.
-            kwargs: Additional keyword arguments, directly passed
-                to the sambaverse model in the API call.
-
-        Returns:
-            The model output as a string.
-        """
-        completion = ""
-        for chunk in self._stream(
-            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
-        ):
-            completion += chunk.text
-        return completion
-
-    def _call(
-        self,
-        prompt: Union[List[str], str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given input.
-
-        Args:
-            prompt: The prompt to generate from.
-            stop: Stop words to use when generating. Model output is cut off at the
-                first occurrence of any of the stop substrings.
-            run_manager: Callback manager for the run.
-            kwargs: Additional keyword arguments, directly passed
-                to the sambaverse model in the API call.
-
-        Returns:
-            The model output as a string.
-        """
-        try:
-            if self.streaming:
-                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
-            return self._handle_completion_requests(prompt, stop)
-        except Exception as e:
-            # Handle any errors raised by the inference endpoint
-            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
-
-
 class SSEndpointHandler:
     """
     SambaNova Systems Interface for SambaStudio model endpoints.
@@ -975,7 +517,7 @@ class SambaStudio(LLM):
                 first occurrence of any of the stop substrings.
             run_manager: Callback manager for the run.
             kwargs: Additional keyword arguments. directly passed
-                to the sambaverse model in API call.
+                to the sambastudio model in API call.
 
         Returns:
             The model output as a string.
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, create_model
 from typing_extensions import Self
 
 if TYPE_CHECKING:
-    from databricks.sdk import WorkspaceClient
     from databricks.sdk.service.catalog import FunctionInfo
 
 from pydantic import ConfigDict
@@ -121,7 +120,7 @@ def _get_tool_name(function: "FunctionInfo") -> str:
     return tool_name
 
 
-def _get_default_workspace_client() -> "WorkspaceClient":
+def _get_default_workspace_client() -> Any:
     try:
         from databricks.sdk import WorkspaceClient
     except ImportError as e:
@@ -137,7 +136,7 @@ class UCFunctionToolkit(BaseToolkit):
         description="The ID of a Databricks SQL Warehouse to execute functions."
     )
 
-    workspace_client: "WorkspaceClient" = Field(
+    workspace_client: Any = Field(
        default_factory=_get_default_workspace_client,
        description="Databricks workspace client.",
    )
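
A sketch of the pattern behind this change (illustrative names, not taken from the diff): annotating the optional-dependency field as ``Any`` lets pydantic build the model class without resolving a forward reference to a type that may not be importable, and the import cost is paid only inside the default factory::

    from typing import Any

    from pydantic import BaseModel, Field


    def _default_client() -> Any:
        # Import deferred so the module (and this model) can be defined without
        # databricks-sdk installed; the error surfaces only on first use.
        try:
            from databricks.sdk import WorkspaceClient
        except ImportError as e:
            raise ImportError("Please install databricks-sdk to use this toolkit.") from e
        return WorkspaceClient()


    class Toolkit(BaseModel):
        # Typed as Any (not "WorkspaceClient") so pydantic never needs to
        # resolve a forward reference when the SDK is absent.
        workspace_client: Any = Field(default_factory=_default_client)
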
@@ -443,6 +443,12 @@ class AzureSearch(VectorStore):
             logger.debug("Nothing to insert, skipping.")
             return []
 
+        # when `keys` are not passed in and there is `ids` in kwargs, use those instead
+        # base class expects `ids` passed in rather than `keys`
+        # https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
+        if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
+            keys = kwargs["ids"]
+
         return self.add_embeddings(zip(texts, embeddings), metadatas, keys=keys)
 
     async def aadd_texts(
@@ -467,6 +473,12 @@ class AzureSearch(VectorStore):
             logger.debug("Nothing to insert, skipping.")
             return []
 
+        # when `keys` are not passed in and there is `ids` in kwargs, use those instead
+        # base class expects `ids` passed in rather than `keys`
+        # https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
+        if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
+            keys = kwargs["ids"]
+
         return await self.aadd_embeddings(zip(texts, embeddings), metadatas, keys=keys)
 
     def add_embeddings(
@@ -483,9 +495,13 @@ class AzureSearch(VectorStore):
         data = []
         for i, (text, embedding) in enumerate(text_embeddings):
-            # Use provided key otherwise use default key
-            key = keys[i] if keys else str(uuid.uuid4())
-            # Encoding key for Azure Search valid characters
-            key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
+            if keys:
+                key = keys[i]
+            else:
+                key = str(uuid.uuid4())
+                # Encoding key for Azure Search valid characters
+                key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
 
             metadata = metadatas[i] if metadatas else {}
             # Add data to index
             # Additional metadata to fields mapping
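
Mirroring the unit test added later in this changeset, the effect of the ``ids`` passthrough can be sketched as follows (assumes an existing ``AzureSearch`` instance named ``vector_store``; document contents are placeholders)::

    from langchain_core.documents import Document

    docs = [
        Document(page_content="page zero Lorem Ipsum", metadata={"id": "ID-document-1"}),
        Document(page_content="page one Lorem Ipsum", metadata={"id": "ID-document-2"}),
    ]
    ids = [d.metadata["id"] for d in docs]

    # add_documents() forwards ids via kwargs; with the change above, add_texts()
    # adopts them as Azure Search document keys instead of generating UUIDs.
    returned_ids = vector_store.add_documents(docs, ids=ids)
    assert returned_ids == ids
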
@@ -595,11 +595,8 @@ class Neo4jVector(VectorStore):
         query: str,
         *,
         params: Optional[dict] = None,
-        retry_on_session_expired: bool = True,
     ) -> List[Dict[str, Any]]:
-        """
-        This method sends a Cypher query to the connected Neo4j database
-        and returns the results as a list of dictionaries.
+        """Query Neo4j database with retries and exponential backoff.
 
         Args:
             query (str): The Cypher query to execute.
@@ -608,24 +605,38 @@ class Neo4jVector(VectorStore):
         Returns:
             List[Dict[str, Any]]: List of dictionaries containing the query results.
         """
-        from neo4j.exceptions import CypherSyntaxError, SessionExpired
+        from neo4j import Query
+        from neo4j.exceptions import Neo4jError
 
         params = params or {}
-        with self._driver.session(database=self._database) as session:
-            try:
-                data = session.run(query, params)
-                return [r.data() for r in data]
-            except CypherSyntaxError as e:
-                raise ValueError(f"Cypher Statement is not valid\n{e}")
-            except (
-                SessionExpired
-            ) as e:  # Session expired is a transient error that can be retried
-                if retry_on_session_expired:
-                    return self.query(
-                        query, params=params, retry_on_session_expired=False
-                    )
-                else:
-                    raise e
+        try:
+            data, _, _ = self._driver.execute_query(
+                query, database=self._database, parameters_=params
+            )
+            return [r.data() for r in data]
+        except Neo4jError as e:
+            if not (
+                (
+                    (  # isCallInTransactionError
+                        e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
+                        or e.code
+                        == "Neo.DatabaseError.Transaction.TransactionStartFailed"
+                    )
+                    and "in an implicit transaction" in e.message
+                )
+                or (  # isPeriodicCommitError
+                    e.code == "Neo.ClientError.Statement.SemanticError"
+                    and (
+                        "in an open transaction is not possible" in e.message
+                        or "tried to execute in an explicit transaction" in e.message
+                    )
+                )
+            ):
+                raise
+            # Fallback to allow implicit transactions
+            with self._driver.session() as session:
+                data = session.run(Query(text=query), params)
+                return [r.data() for r in data]
 
     def verify_version(self) -> None:
         """
@@ -144,7 +144,7 @@ class TencentVectorDB(VectorStore):
 
     In order to use this you need to have a database instance.
     See the following documentation for details:
-    https://cloud.tencent.com/document/product/1709/94951
+    https://cloud.tencent.com/document/product/1709/104489
     """
 
     field_id: str = "id"
@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
 # PRs that increase the current count will not be accepted.
 # PRs that decrease update the code in the repository
 # and allow decreasing the count of are welcome!
-current_count=129
+current_count=128
 
 if [ "$count" -gt "$current_count" ]; then
   echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
@@ -1,4 +1,3 @@
-import re
 from pathlib import Path
 from typing import Sequence, Union
 
@@ -11,7 +10,6 @@ from langchain_community.document_loaders import (
     PDFMinerPDFasHTMLLoader,
     PyMuPDFLoader,
     PyPDFium2Loader,
-    PyPDFLoader,
     UnstructuredPDFLoader,
 )
 
@@ -86,37 +84,6 @@ def test_pdfminer_pdf_as_html_loader() -> None:
     assert len(docs) == 1
 
 
-def test_pypdf_loader() -> None:
-    """Test PyPDFLoader."""
-    file_path = Path(__file__).parent.parent / "examples/hello.pdf"
-    loader = PyPDFLoader(str(file_path))
-    docs = loader.load()
-
-    assert len(docs) == 1
-
-    file_path = Path(__file__).parent.parent / "examples/layout-parser-paper.pdf"
-    loader = PyPDFLoader(str(file_path))
-
-    docs = loader.load()
-    assert len(docs) == 16
-
-
-def test_pypdf_loader_with_layout() -> None:
-    """Test PyPDFLoader with layout mode."""
-    file_path = Path(__file__).parent.parent / "examples/layout-parser-paper.pdf"
-    loader = PyPDFLoader(str(file_path), extraction_mode="layout")
-
-    docs = loader.load()
-    first_page = docs[0].page_content
-
-    expected = (
-        Path(__file__).parent.parent / "examples/layout-parser-paper-page-1.txt"
-    ).read_text(encoding="utf-8")
-    cleaned_first_page = re.sub(r"\x00", "", first_page)
-    cleaned_expected = re.sub(r"\x00", "", expected)
-    assert cleaned_first_page == cleaned_expected
-
-
 def test_pypdfium2_loader() -> None:
     """Test PyPDFium2Loader."""
     file_path = Path(__file__).parent.parent / "examples/hello.pdf"
@@ -1,7 +1,7 @@
 import pytest
-from langchain_core.graph_vectorstores.links import Link
 
 from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
+from langchain_community.graph_vectorstores.links import Link
 
 PAGE_1 = """
 Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃ'tjɐnu
 
@@ -1,7 +1,7 @@
 import pytest
-from langchain_core.graph_vectorstores.links import Link
 
 from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
+from langchain_community.graph_vectorstores.links import Link
 
 PAGE_1 = """
 Supervised learning is the machine learning task of learning a function that
 
@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Type
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
 
 from langchain_community.graph_vectorstores import CassandraGraphVectorStore
+from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
 
 CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
@@ -1,28 +1,17 @@
 """Test sambanova API wrapper.
 
-In order to run this test, you need to have an sambaverse api key,
-and a sambaverse base url, project id, endpoint id, and api key.
-You'll then need to set SAMBAVERSE_API_KEY, SAMBASTUDIO_BASE_URL,
+In order to run this test, you need to have a sambastudio base url,
+project id, endpoint id, and api key.
+You'll then need to set SAMBASTUDIO_BASE_URL, SAMBASTUDIO_BASE_URI,
 SAMBASTUDIO_PROJECT_ID, SAMBASTUDIO_ENDPOINT_ID, and SAMBASTUDIO_API_KEY
 environment variables.
 """
 
-from langchain_community.llms.sambanova import SambaStudio, Sambaverse
-
-
-def test_sambaverse_call() -> None:
-    """Test simple non-streaming call to sambaverse."""
-    llm = Sambaverse(
-        sambaverse_model_name="Meta/llama-2-7b-chat-hf",
-        model_kwargs={"select_expert": "llama-2-7b-chat-hf"},
-    )
-    output = llm.invoke("What is LangChain")
-    assert output
-    assert isinstance(output, str)
+from langchain_community.llms.sambanova import SambaStudio
 
 
 def test_sambastudio_call() -> None:
-    """Test simple non-streaming call to sambaverse."""
+    """Test simple non-streaming call to sambastudio."""
     llm = SambaStudio()
     output = llm.invoke("What is LangChain")
     assert output
@@ -55,6 +55,7 @@ EXPECTED_ALL = [
     "DedocFileLoader",
     "DedocPDFLoader",
     "PebbloSafeLoader",
+    "PebbloTextLoader",
     "DiffbotLoader",
     "DirectoryLoader",
     "DiscordChatLoader",
libs/community/tests/unit_tests/document_loaders/test_pdf.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import re
+from pathlib import Path
+
+import pytest
+
+from langchain_community.document_loaders import PyPDFLoader
+
+path_to_simple_pdf = (
+    Path(__file__).parent.parent.parent / "integration_tests/examples/hello.pdf"
+)
+path_to_layout_pdf = (
+    Path(__file__).parent.parent
+    / "document_loaders/sample_documents/layout-parser-paper.pdf"
+)
+path_to_layout_pdf_txt = (
+    Path(__file__).parent.parent.parent
+    / "integration_tests/examples/layout-parser-paper-page-1.txt"
+)
+
+
+@pytest.mark.requires("pypdf")
+def test_pypdf_loader() -> None:
+    """Test PyPDFLoader."""
+    loader = PyPDFLoader(str(path_to_simple_pdf))
+    docs = loader.load()
+
+    assert len(docs) == 1
+
+    loader = PyPDFLoader(str(path_to_layout_pdf))
+
+    docs = loader.load()
+    assert len(docs) == 16
+    for page, doc in enumerate(docs):
+        assert doc.metadata["page"] == page
+        assert doc.metadata["source"].endswith("layout-parser-paper.pdf")
+        assert len(doc.page_content) > 10
+
+    first_page = docs[0].page_content
+    for expected in ["LayoutParser", "A Unified Toolkit"]:
+        assert expected in first_page
+
+
+@pytest.mark.requires("pypdf")
+def test_pypdf_loader_with_layout() -> None:
+    """Test PyPDFLoader with layout mode."""
+    loader = PyPDFLoader(str(path_to_layout_pdf), extraction_mode="layout")
+
+    docs = loader.load()
+    assert len(docs) == 16
+    for page, doc in enumerate(docs):
+        assert doc.metadata["page"] == page
+        assert doc.metadata["source"].endswith("layout-parser-paper.pdf")
+        assert len(doc.page_content) > 10
+
+    first_page = docs[0].page_content
+    for expected in ["LayoutParser", "A Unified Toolkit"]:
+        assert expected in first_page
+
+    expected = path_to_layout_pdf_txt.read_text(encoding="utf-8")
+    cleaned_first_page = re.sub(r"\x00", "", first_page)
+    cleaned_expected = re.sub(r"\x00", "", expected)
+    assert cleaned_first_page == cleaned_expected
@@ -25,6 +25,11 @@ def test_pebblo_import() -> None:
     from langchain_community.document_loaders import PebbloSafeLoader  # noqa: F401
 
 
+def test_pebblo_text_loader_import() -> None:
+    """Test that the Pebblo text loader can be imported."""
+    from langchain_community.document_loaders import PebbloTextLoader  # noqa: F401
+
+
 def test_empty_filebased_loader(mocker: MockerFixture) -> None:
     """Test basic file based csv loader."""
     # Setup
@@ -146,3 +151,42 @@ def test_pebblo_safe_loader_api_key() -> None:
     # Assert
     assert loader.pb_client.api_key == api_key
     assert loader.pb_client.classifier_location == "local"
+
+
+def test_pebblo_text_loader(mocker: MockerFixture) -> None:
+    """
+    Test loading in-memory text with PebbloTextLoader and PebbloSafeLoader.
+    """
+    # Setup
+    from langchain_community.document_loaders import PebbloSafeLoader, PebbloTextLoader
+
+    mocker.patch.multiple(
+        "requests",
+        get=MockResponse(json_data={"data": ""}, status_code=200),
+        post=MockResponse(json_data={"data": ""}, status_code=200),
+    )
+
+    text = "This is a test text."
+    source = "fake_source"
+    expected_docs = [
+        Document(
+            metadata={
+                "full_path": source,
+                "pb_checksum": None,
+            },
+            page_content=text,
+        ),
+    ]
+
+    # Exercise
+    texts = [text]
+    loader = PebbloSafeLoader(
+        PebbloTextLoader(texts, source=source),
+        "dummy_app_name",
+        "dummy_owner",
+        "dummy_description",
+    )
+    result = loader.load()
+
+    # Assert
+    assert result == expected_docs
@@ -1,6 +1,5 @@
-from langchain_core.graph_vectorstores.links import Link
-
 from langchain_community.graph_vectorstores.extractors import HierarchyLinkExtractor
+from langchain_community.graph_vectorstores.links import Link
 
 PATH_1 = ["Root", "H1", "h2"]
@@ -1,6 +1,6 @@
 import pytest
-from langchain_core.graph_vectorstores import Link
 
+from langchain_community.graph_vectorstores import Link
 from langchain_community.graph_vectorstores.extractors import (
     HtmlInput,
     HtmlLinkExtractor,
@@ -1,12 +1,12 @@
from typing import Set

from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.links import Link, get_links

from langchain_community.graph_vectorstores.extractors import (
    LinkExtractor,
    LinkExtractorTransformer,
)
+from langchain_community.graph_vectorstores.links import Link, get_links

TEXT1 = "Text1"
TEXT2 = "Text2"

@@ -77,7 +77,6 @@ EXPECT_ALL = [
    "RWKV",
    "Replicate",
    "SagemakerEndpoint",
-   "Sambaverse",
    "SambaStudio",
    "SelfHostedHuggingFaceLLM",
    "SelfHostedPipeline",

@@ -1,12 +1,12 @@
import pytest

from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.base import (
+from langchain_community.graph_vectorstores.base import (
    Node,
    _documents_to_nodes,
    _texts_to_nodes,
)
-from langchain_core.graph_vectorstores.links import Link
+from langchain_community.graph_vectorstores.links import Link


def test_texts_to_nodes() -> None:

@@ -190,3 +190,40 @@ def test_additional_search_options() -> None:
    )
    assert vector_store.client is not None
    assert vector_store.client._api_version == "test"
+
+
+@pytest.mark.requires("azure.search.documents")
+def test_ids_used_correctly() -> None:
+    """Check whether the vector store uses the document ids when provided with them."""
+    from azure.search.documents import SearchClient
+    from azure.search.documents.indexes import SearchIndexClient
+    from langchain_core.documents import Document
+
+    class Response:
+        def __init__(self) -> None:
+            self.succeeded: bool = True
+
+    def mock_upload_documents(self, documents: List[object]) -> List[Response]:  # type: ignore[no-untyped-def]
+        # assume all documents uploaded successfully
+        response = [Response() for _ in documents]
+        return response
+
+    documents = [
+        Document(
+            page_content="page zero Lorem Ipsum",
+            metadata={"source": "document.pdf", "page": 0, "id": "ID-document-1"},
+        ),
+        Document(
+            page_content="page one Lorem Ipsum",
+            metadata={"source": "document.pdf", "page": 1, "id": "ID-document-2"},
+        ),
+    ]
+    ids_provided = [i.metadata.get("id") for i in documents]
+
+    with patch.object(
+        SearchClient, "upload_documents", mock_upload_documents
+    ), patch.object(SearchIndexClient, "get_index", mock_default_index):
+        vector_store = create_vector_store()
+        ids_used_at_upload = vector_store.add_documents(documents, ids=ids_provided)
+        assert len(ids_provided) == len(ids_used_at_upload)
+        assert ids_provided == ids_used_at_upload

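The contract being tested: when ids are passed to add_documents, exactly those ids come back. A hedged sketch of the same contract using langchain-core's in-memory vector store instead of the mocked Azure client (assumes a recent langchain-core and numpy):

from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=8))
docs = [Document(page_content="page zero"), Document(page_content="page one")]
ids = store.add_documents(docs, ids=["ID-document-1", "ID-document-2"])
assert ids == ["ID-document-1", "ID-document-2"]  # provided ids are used as-is
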
@@ -53,7 +53,7 @@ LangChain Core compiles LCEL sequences to an _optimized execution plan_, with au

For more check out the [LCEL docs](https://python.langchain.com/docs/expression_language/).

-![Diagram outlining the hierarchical organization of the LangChain framework, displaying the interconnected parts across multiple layers.](../docs/static/img/langchain_stack_062024.svg "LangChain Framework Overview")
+![Diagram outlining the hierarchical organization of the LangChain framework, displaying the interconnected parts across multiple layers.](../docs/static/img/langchain_stack.png "LangChain Framework Overview")

For more advanced use cases, also check out [LangGraph](https://github.com/langchain-ai/langgraph), which is a graph-based runner for cyclic and recursive LLM workflows.

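To make the LCEL composition the README describes concrete, here is a minimal sketch using only langchain-core primitives; the echo "model" is a RunnableLambda stand-in rather than a real chat model:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")

# Stand-in for a chat model: echoes the formatted prompt back as a string.
fake_model = RunnableLambda(lambda prompt_value: prompt_value.to_string())

chain = prompt | fake_model | StrOutputParser()
print(chain.invoke({"topic": "bears"}))
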
@@ -14,7 +14,8 @@ import contextlib
import functools
import inspect
import warnings
-from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
+from collections.abc import Generator
+from typing import Any, Callable, TypeVar, Union, cast

from langchain_core._api.internal import is_caller_internal

@@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
# PUBLIC API


-T = TypeVar("T", bound=Union[Callable[..., Any], Type])
+T = TypeVar("T", bound=Union[Callable[..., Any], type])


def beta(

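The pattern running through these hunks is PEP 585: on Python 3.9+ the builtin containers are themselves generic, so typing.List, typing.Dict, typing.Type, and friends can be replaced by list, dict, and type, and the abstract container types (Sequence, Iterator, Generator, ...) move to collections.abc. A small before/after sketch:

# Old style: aliases imported from the typing module
from typing import Dict, List

def count_tokens_old(texts: List[str]) -> Dict[str, int]:
    return {t: len(t.split()) for t in texts}

# New style (PEP 585): builtin generics, no typing import needed
def count_tokens_new(texts: list[str]) -> dict[str, int]:
    return {t: len(t.split()) for t in texts}
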
@@ -154,7 +155,7 @@ def beta(
        _name = _name or obj.fget.__qualname__
        old_doc = obj.__doc__

-        class _beta_property(property):
+        class _BetaProperty(property):
            """A beta property."""

            def __init__(self, fget=None, fset=None, fdel=None, doc=None):
@@ -185,7 +186,7 @@ def beta(

        def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
            """Finalize the property."""
-            return _beta_property(
+            return _BetaProperty(
                fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
            )

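Beyond the rename to PEP 8 CapWords for the helper class, the decorator's public behavior is unchanged. A hedged usage sketch of @beta itself, assuming langchain-core is installed (the helper name is illustrative):

from langchain_core._api import beta

@beta(message="experimental_helper is in beta; the API may change.")
def experimental_helper() -> str:
    return "ok"

experimental_helper()  # emits a LangChainBetaWarning when called
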
@@ -14,11 +14,10 @@ import contextlib
import functools
import inspect
import warnings
+from collections.abc import Generator
from typing import (
    Any,
    Callable,
-    Generator,
-    Type,
    TypeVar,
    Union,
    cast,
@@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):


# Last Any should be FieldInfoV1 but this leads to circular imports
-T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
+T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])


def _validate_deprecation_params(
@@ -262,10 +261,10 @@ def deprecated(
        if not _obj_type:
            _obj_type = "attribute"
        wrapped = None
-        _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
+        _name = _name or cast(Union[type, Callable], obj.fget).__qualname__
        old_doc = obj.__doc__

-        class _deprecated_property(property):
+        class _DeprecatedProperty(property):
            """A deprecated property."""

            def __init__(self, fget=None, fset=None, fdel=None, doc=None):  # type: ignore[no-untyped-def]
@@ -298,13 +297,13 @@ def deprecated(
            """Finalize the property."""
            return cast(
                T,
-                _deprecated_property(
+                _DeprecatedProperty(
                    fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
                ),
            )

    else:
-        _name = _name or cast(Union[Type, Callable], obj).__qualname__
+        _name = _name or cast(Union[type, Callable], obj).__qualname__
        if not _obj_type:
            # edge case: when a function is within another function
            # within a test, this will call it a "method" not a "function"
@@ -333,9 +332,26 @@ def deprecated(
        old_doc = ""

    # Modify the docstring to include a deprecation notice.
+    if (
+        _alternative
+        and _alternative.split(".")[-1].lower() == _alternative.split(".")[-1]
+    ):
+        _alternative = f":meth:`~{_alternative}`"
+    elif _alternative:
+        _alternative = f":class:`~{_alternative}`"
+
+    if (
+        _alternative_import
+        and _alternative_import.split(".")[-1].lower()
+        == _alternative_import.split(".")[-1]
+    ):
+        _alternative_import = f":meth:`~{_alternative_import}`"
+    elif _alternative_import:
+        _alternative_import = f":class:`~{_alternative_import}`"
+
    components = [
        _message,
-        f"Use ``{_alternative}`` instead." if _alternative else "",
+        f"Use {_alternative} instead." if _alternative else "",
        f"Use ``{_alternative_import}`` instead." if _alternative_import else "",
        _addendum,
    ]

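As a usage sketch of the decorator these hunks modify (the helper names and version numbers are illustrative):

from langchain_core._api import deprecated

@deprecated(since="0.2.0", alternative="new_helper", removal="1.0")
def old_helper() -> str:
    return "ok"

old_helper()  # emits a LangChainDeprecationWarning pointing at new_helper
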
@@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
from __future__ import annotations

import json
-from typing import Any, List, Literal, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Literal, Union

from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
@@ -71,7 +72,7 @@ class AgentAction(Serializable):
        return True

    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        Default is ["langchain", "schema", "agent"]."""
        return ["langchain", "schema", "agent"]
@@ -145,7 +146,7 @@ class AgentFinish(Serializable):
        return True

    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "agent"]

@@ -1,19 +1,13 @@
import asyncio
import threading
from collections import defaultdict
+from collections.abc import Awaitable, Mapping, Sequence
from functools import partial
from itertools import groupby
from typing import (
    Any,
-    Awaitable,
    Callable,
-    DefaultDict,
-    Dict,
-    List,
-    Mapping,
    Optional,
-    Sequence,
-    Type,
    TypeVar,
    Union,
)
@@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output

T = TypeVar("T")
-Values = Dict[Union[asyncio.Event, threading.Event], Any]
+Values = dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
@@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:

def _config_with_context(
    config: RunnableConfig,
-    steps: List[Runnable],
+    steps: list[Runnable],
    setter: Callable,
    getter: Callable,
-    event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
+    event_cls: Union[type[threading.Event], type[asyncio.Event]],
) -> RunnableConfig:
    if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
        return config
@@ -99,10 +93,10 @@ def _config_with_context(
    }

    values: Values = {}
-    events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
+    events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
        event_cls
    )
-    context_funcs: Dict[str, Callable[[], Any]] = {}
+    context_funcs: dict[str, Callable[[], Any]] = {}
    for key, group in grouped_by_key.items():
        getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
        setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
@@ -129,7 +123,7 @@ def _config_with_context(

def aconfig_with_context(
    config: RunnableConfig,
-    steps: List[Runnable],
+    steps: list[Runnable],
) -> RunnableConfig:
    """Asynchronously patch a runnable config with context getters and setters.

@@ -145,7 +139,7 @@ def aconfig_with_context(

def config_with_context(
    config: RunnableConfig,
-    steps: List[Runnable],
+    steps: list[Runnable],
) -> RunnableConfig:
    """Patch a runnable config with context getters and setters.

@@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):

    prefix: str = ""

-    key: Union[str, List[str]]
+    key: Union[str, list[str]]

    def __str__(self) -> str:
        return f"ContextGet({_print_keys(self.key)})"

    @property
-    def ids(self) -> List[str]:
+    def ids(self) -> list[str]:
        prefix = self.prefix + "/" if self.prefix else ""
        keys = self.key if isinstance(self.key, list) else [self.key]
        return [
@@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
        ]

    @property
-    def config_specs(self) -> List[ConfigurableFieldSpec]:
+    def config_specs(self) -> list[ConfigurableFieldSpec]:
        return super().config_specs + [
            ConfigurableFieldSpec(
                id=id_,
@@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
        return f"ContextSet({_print_keys(list(self.keys.keys()))})"

    @property
-    def ids(self) -> List[str]:
+    def ids(self) -> list[str]:
        prefix = self.prefix + "/" if self.prefix else ""
        return [
            f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
@@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
        ]

    @property
-    def config_specs(self) -> List[ConfigurableFieldSpec]:
+    def config_specs(self) -> list[ConfigurableFieldSpec]:
        mapper_config_specs = [
            s
            for mapper in self.keys.values()
@@ -364,7 +358,7 @@ class Context:
        return PrefixContext(prefix=scope)

    @staticmethod
-    def getter(key: Union[str, List[str]], /) -> ContextGet:
+    def getter(key: Union[str, list[str]], /) -> ContextGet:
        return ContextGet(key=key)

    @staticmethod
@@ -385,7 +379,7 @@ class PrefixContext:
    def __init__(self, prefix: str = ""):
        self.prefix = prefix

-    def getter(self, key: Union[str, List[str]], /) -> ContextGet:
+    def getter(self, key: Union[str, list[str]], /) -> ContextGet:
        return ContextGet(key=key, prefix=self.prefix)

    def setter(

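A hedged sketch of what these Context primitives do inside a chain — stash a value at an early step and read it back at a later step of the same run (this is a beta API; the import path is assumed from langchain_core.beta):

from langchain_core.beta.runnables.context import Context
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

chain = (
    Context.setter("original")                     # passes input through, stores it
    | RunnableLambda(lambda text: text.upper())
    | {"upper": RunnablePassthrough(), "original": Context.getter("original")}
)
print(chain.invoke("hello"))  # {'upper': 'HELLO', 'original': 'hello'}
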
@@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional

from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor
@@ -157,7 +158,7 @@ class InMemoryCache(BaseCache):
        Raises:
            ValueError: If maxsize is less than or equal to 0.
        """
-        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
+        self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {}
        if maxsize is not None and maxsize <= 0:
            raise ValueError("maxsize must be greater than 0")
        self._maxsize = maxsize

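The annotation above encodes the cache's key shape: a (prompt, llm_string) tuple. A small sketch of the InMemoryCache API, assuming the maxsize keyword shown in the hunk:

from langchain_core.caches import InMemoryCache
from langchain_core.outputs import Generation

cache = InMemoryCache(maxsize=2)  # oldest entry is evicted beyond maxsize
cache.update("prompt", "model-config", [Generation(text="cached answer")])
print(cache.lookup("prompt", "model-config"))  # [Generation(text='cached answer')]
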
@@ -3,7 +3,8 @@
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import UUID

from tenacity import RetryCallState
@@ -118,7 +119,7 @@ class ChainManagerMixin:

    def on_chain_end(
        self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -222,13 +223,13 @@ class CallbackManagerMixin:

    def on_llm_start(
        self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM starts running.
@@ -249,13 +250,13 @@ class CallbackManagerMixin:

    def on_chat_model_start(
        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running.
@@ -280,13 +281,13 @@ class CallbackManagerMixin:

    def on_retriever_start(
        self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when the Retriever starts running.
@@ -303,13 +304,13 @@ class CallbackManagerMixin:

    def on_chain_start(
        self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when a chain starts running.
@@ -326,14 +327,14 @@ class CallbackManagerMixin:

    def on_tool_start(
        self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when the tool starts running.
@@ -393,8 +394,8 @@ class RunManagerMixin:
        data: Any,
        *,
        run_id: UUID,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Override to define a handler for a custom event.
@@ -470,13 +471,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_llm_start(
        self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running.
@@ -497,13 +498,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_chat_model_start(
        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running.
@@ -533,7 +534,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled.
@@ -554,7 +555,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM ends running.
@@ -573,7 +574,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors.
@@ -590,13 +591,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_chain_start(
        self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when a chain starts running.
@@ -613,11 +614,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_chain_end(
        self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when a chain ends running.
@@ -636,7 +637,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain errors.
@@ -651,14 +652,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_tool_start(
        self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when the tool starts running.
@@ -680,7 +681,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when the tool ends running.
@@ -699,7 +700,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors.
@@ -718,7 +719,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on an arbitrary text.
@@ -754,7 +755,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on agent action.
@@ -773,7 +774,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on the agent end.
@@ -788,13 +789,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_retriever_start(
        self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on the retriever start.
@@ -815,7 +816,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on the retriever end.
@@ -833,7 +834,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever error.
@@ -852,8 +853,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        data: Any,
        *,
        run_id: UUID,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Override to define a handler for a custom event.
@@ -880,14 +881,14 @@ class BaseCallbackManager(CallbackManagerMixin):

    def __init__(
        self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
-        tags: Optional[List[str]] = None,
-        inheritable_tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Initialize callback manager.

@@ -901,8 +902,8 @@ class BaseCallbackManager(CallbackManagerMixin):
                Default is None.
            metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
        """
-        self.handlers: List[BaseCallbackHandler] = handlers
-        self.inheritable_handlers: List[BaseCallbackHandler] = (
+        self.handlers: list[BaseCallbackHandler] = handlers
+        self.inheritable_handlers: list[BaseCallbackHandler] = (
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[UUID] = parent_run_id
@@ -1002,7 +1003,7 @@ class BaseCallbackManager(CallbackManagerMixin):
            self.inheritable_handlers.remove(handler)

    def set_handlers(
-        self, handlers: List[BaseCallbackHandler], inherit: bool = True
+        self, handlers: list[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager.

@@ -1024,7 +1025,7 @@ class BaseCallbackManager(CallbackManagerMixin):
        """
        self.set_handlers([handler], inherit=inherit)

-    def add_tags(self, tags: List[str], inherit: bool = True) -> None:
+    def add_tags(self, tags: list[str], inherit: bool = True) -> None:
        """Add tags to the callback manager.

        Args:
@@ -1038,7 +1039,7 @@ class BaseCallbackManager(CallbackManagerMixin):
        if inherit:
            self.inheritable_tags.extend(tags)

-    def remove_tags(self, tags: List[str]) -> None:
+    def remove_tags(self, tags: list[str]) -> None:
        """Remove tags from the callback manager.

        Args:
@@ -1048,7 +1049,7 @@ class BaseCallbackManager(CallbackManagerMixin):
            self.tags.remove(tag)
            self.inheritable_tags.remove(tag)

-    def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
+    def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
        """Add metadata to the callback manager.

        Args:
@@ -1059,7 +1060,7 @@ class BaseCallbackManager(CallbackManagerMixin):
        if inherit:
            self.inheritable_metadata.update(metadata)

-    def remove_metadata(self, keys: List[str]) -> None:
+    def remove_metadata(self, keys: list[str]) -> None:
        """Remove metadata from the callback manager.

        Args:
@@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
            self.inheritable_metadata.pop(key)


-Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
+Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]

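Taken together, these signatures define the hook surface a custom handler can override. A minimal sketch of a synchronous handler, assuming langchain-core (every hook not overridden keeps its no-op default):

from typing import Any
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler

class PrintingHandler(BaseCallbackHandler):
    """Logs LLM starts; all other hooks keep their defaults."""

    def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Any:
        print(f"LLM run {run_id} starting with {len(prompts)} prompt(s)")
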
@@ -2,7 +2,7 @@

from __future__ import annotations

-from typing import Any, Dict, Optional, TextIO, cast
+from typing import Any, Optional, TextIO, cast

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
@@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler):
        self.file.close()

    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain.

@@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler):
            file=self.file,
        )

-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain.

        Args:

@@ -5,21 +5,15 @@ import functools
import logging
import uuid
from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
from typing import (
    TYPE_CHECKING,
    Any,
-    AsyncGenerator,
    Callable,
-    Coroutine,
-    Dict,
-    Generator,
-    List,
    Optional,
-    Sequence,
-    Type,
    TypeVar,
    Union,
    cast,
@@ -64,12 +58,12 @@ def trace_as_chain_group(
    group_name: str,
    callback_manager: Optional[CallbackManager] = None,
    *,
-    inputs: Optional[Dict[str, Any]] = None,
+    inputs: Optional[dict[str, Any]] = None,
    project_name: Optional[str] = None,
    example_id: Optional[Union[str, UUID]] = None,
    run_id: Optional[UUID] = None,
-    tags: Optional[List[str]] = None,
-    metadata: Optional[Dict[str, Any]] = None,
+    tags: Optional[list[str]] = None,
+    metadata: Optional[dict[str, Any]] = None,
) -> Generator[CallbackManagerForChainGroup, None, None]:
    """Get a callback manager for a chain group in a context manager.
    Useful for grouping different calls together as a single run even if
@@ -144,12 +138,12 @@ async def atrace_as_chain_group(
    group_name: str,
    callback_manager: Optional[AsyncCallbackManager] = None,
    *,
-    inputs: Optional[Dict[str, Any]] = None,
+    inputs: Optional[dict[str, Any]] = None,
    project_name: Optional[str] = None,
    example_id: Optional[Union[str, UUID]] = None,
    run_id: Optional[UUID] = None,
-    tags: Optional[List[str]] = None,
-    metadata: Optional[Dict[str, Any]] = None,
+    tags: Optional[list[str]] = None,
+    metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
    """Get an async callback manager for a chain group in a context manager.
    Useful for grouping different async calls together as a single run even if
@@ -240,7 +234,7 @@ def shielded(func: Func) -> Func:


def handle_event(
-    handlers: List[BaseCallbackHandler],
+    handlers: list[BaseCallbackHandler],
    event_name: str,
    ignore_condition_name: Optional[str],
    *args: Any,
@@ -258,10 +252,10 @@ def handle_event(
        *args: The arguments to pass to the event handler.
        **kwargs: The keyword arguments to pass to the event handler
    """
-    coros: List[Coroutine[Any, Any, Any]] = []
+    coros: list[Coroutine[Any, Any, Any]] = []

    try:
-        message_strings: Optional[List[str]] = None
+        message_strings: Optional[list[str]] = None
        for handler in handlers:
            try:
                if ignore_condition_name is None or not getattr(
@@ -318,7 +312,7 @@ def handle_event(
    _run_coros(coros)


-def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
+def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
    if hasattr(asyncio, "Runner"):
        # Python 3.11+
        # Run the coroutines in a new event loop, taking care to
@@ -399,7 +393,7 @@ async def _ahandle_event_for_handler(


async def ahandle_event(
-    handlers: List[BaseCallbackHandler],
+    handlers: list[BaseCallbackHandler],
    event_name: str,
    ignore_condition_name: Optional[str],
    *args: Any,
@@ -446,13 +440,13 @@ class BaseRunManager(RunManagerMixin):
        self,
        *,
        run_id: UUID,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: List[BaseCallbackHandler],
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: list[BaseCallbackHandler],
        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        inheritable_tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Initialize the run manager.

@@ -481,7 +475,7 @@ class BaseRunManager(RunManagerMixin):
        self.inheritable_metadata = inheritable_metadata or {}

    @classmethod
-    def get_noop_manager(cls: Type[BRM]) -> BRM:
+    def get_noop_manager(cls: type[BRM]) -> BRM:
        """Return a manager that doesn't perform any operations.

        Returns:
@@ -824,7 +818,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
    """Callback manager for chain run."""

-    def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Args:
@@ -929,7 +923,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):

    @shielded
    async def on_chain_end(
-        self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
+        self, outputs: Union[dict[str, Any], Any], **kwargs: Any
    ) -> None:
        """Run when a chain ends running.

@@ -1248,11 +1242,11 @@ class CallbackManager(BaseCallbackManager):

    def on_llm_start(
        self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
-    ) -> List[CallbackManagerForLLMRun]:
+    ) -> list[CallbackManagerForLLMRun]:
        """Run when LLM starts running.

        Args:
@@ -1299,11 +1293,11 @@ class CallbackManager(BaseCallbackManager):

    def on_chat_model_start(
        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
-    ) -> List[CallbackManagerForLLMRun]:
+    ) -> list[CallbackManagerForLLMRun]:
        """Run when LLM starts running.

        Args:
@@ -1354,8 +1348,8 @@ class CallbackManager(BaseCallbackManager):

    def on_chain_start(
        self,
-        serialized: Optional[Dict[str, Any]],
-        inputs: Union[Dict[str, Any], Any],
+        serialized: Optional[dict[str, Any]],
+        inputs: Union[dict[str, Any], Any],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> CallbackManagerForChainRun:
@@ -1398,11 +1392,11 @@ class CallbackManager(BaseCallbackManager):

    def on_tool_start(
        self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
        input_str: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> CallbackManagerForToolRun:
        """Run when tool starts running.
@@ -1453,7 +1447,7 @@ class CallbackManager(BaseCallbackManager):

    def on_retriever_start(
        self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
        query: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
@@ -1541,10 +1535,10 @@ class CallbackManager(BaseCallbackManager):
        inheritable_callbacks: Callbacks = None,
        local_callbacks: Callbacks = None,
        verbose: bool = False,
-        inheritable_tags: Optional[List[str]] = None,
-        local_tags: Optional[List[str]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
-        local_metadata: Optional[Dict[str, Any]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        local_tags: Optional[list[str]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
+        local_metadata: Optional[dict[str, Any]] = None,
    ) -> CallbackManager:
        """Configure the callback manager.

@@ -1583,8 +1577,8 @@ class CallbackManagerForChainGroup(CallbackManager):

    def __init__(
        self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
        parent_run_manager: CallbackManagerForChainRun,
@@ -1681,7 +1675,7 @@ class CallbackManagerForChainGroup(CallbackManager):
        manager.add_handler(handler, inherit=True)
        return manager

-    def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
        """Run when traced chain group ends.

        Args:
@@ -1716,11 +1710,11 @@ class AsyncCallbackManager(BaseCallbackManager):

    async def on_llm_start(
        self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
-    ) -> List[AsyncCallbackManagerForLLMRun]:
+    ) -> list[AsyncCallbackManagerForLLMRun]:
        """Run when LLM starts running.

        Args:
@@ -1779,11 +1773,11 @@ class AsyncCallbackManager(BaseCallbackManager):

    async def on_chat_model_start(
        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
-    ) -> List[AsyncCallbackManagerForLLMRun]:
+    ) -> list[AsyncCallbackManagerForLLMRun]:
        """Async run when LLM starts running.

        Args:
@@ -1840,8 +1834,8 @@ class AsyncCallbackManager(BaseCallbackManager):

    async def on_chain_start(
        self,
-        serialized: Optional[Dict[str, Any]],
-        inputs: Union[Dict[str, Any], Any],
+        serialized: Optional[dict[str, Any]],
+        inputs: Union[dict[str, Any], Any],
        run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForChainRun:
@@ -1886,7 +1880,7 @@ class AsyncCallbackManager(BaseCallbackManager):

    async def on_tool_start(
        self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
        input_str: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
@@ -1975,7 +1969,7 @@ class AsyncCallbackManager(BaseCallbackManager):

    async def on_retriever_start(
        self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
        query: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
@@ -2027,10 +2021,10 @@ class AsyncCallbackManager(BaseCallbackManager):
        inheritable_callbacks: Callbacks = None,
        local_callbacks: Callbacks = None,
        verbose: bool = False,
-        inheritable_tags: Optional[List[str]] = None,
-        local_tags: Optional[List[str]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
-        local_metadata: Optional[Dict[str, Any]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        local_tags: Optional[list[str]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
+        local_metadata: Optional[dict[str, Any]] = None,
    ) -> AsyncCallbackManager:
        """Configure the async callback manager.

@@ -2069,8 +2063,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):

    def __init__(
        self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
        parent_run_manager: AsyncCallbackManagerForChainRun,
@@ -2169,7 +2163,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
        return manager

    async def on_chain_end(
-        self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
+        self, outputs: Union[dict[str, Any], Any], **kwargs: Any
    ) -> None:
        """Run when traced chain group ends.

@@ -2202,14 +2196,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)


def _configure(
-    callback_manager_cls: Type[T],
+    callback_manager_cls: type[T],
    inheritable_callbacks: Callbacks = None,
    local_callbacks: Callbacks = None,
    verbose: bool = False,
-    inheritable_tags: Optional[List[str]] = None,
-    local_tags: Optional[List[str]] = None,
-    inheritable_metadata: Optional[Dict[str, Any]] = None,
-    local_metadata: Optional[Dict[str, Any]] = None,
+    inheritable_tags: Optional[list[str]] = None,
+    local_tags: Optional[list[str]] = None,
+    inheritable_metadata: Optional[dict[str, Any]] = None,
+    local_metadata: Optional[dict[str, Any]] = None,
) -> T:
    """Configure the callback manager.

@@ -2354,7 +2348,7 @@ def _configure(
            and handler_class is not None
        )
        if var.get() is not None or create_one:
-            var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
+            var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
            if handler_class is None:
                if not any(
                    handler is var_handler  # direct pointer comparison

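A brief sketch of the chain-group tracer declared above, which groups several calls under one traced run (the group name and inputs are illustrative):

from langchain_core.callbacks import trace_as_chain_group

with trace_as_chain_group("qa_session", inputs={"question": "hi"}) as manager:
    # pass `manager` as the callbacks argument to calls made inside the group,
    # then close the group with its final outputs
    manager.on_chain_end({"answer": "hello"})
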
@@ -2,7 +2,7 @@

from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.utils import print_text
@@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
        self.color = color

    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain.

@@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
        class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")  # noqa: T201

-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain.

        Args:

@@ -3,7 +3,7 @@

from __future__ import annotations

import sys
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any

from langchain_core.callbacks.base import BaseCallbackHandler

@@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running.

@@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):

    def on_chat_model_start(
        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running.
@@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
        """

    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when a chain starts running.

@@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
            **kwargs (Any): Additional keyword arguments.
        """

-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        """Run when a chain ends running.

        Args:
@@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
        """

    def on_tool_start(
-        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when the tool starts running.

@@ -18,7 +18,8 @@

from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import List, Sequence, Union
+from collections.abc import Sequence
+from typing import Union

from pydantic import BaseModel, Field

@@ -87,7 +88,7 @@ class BaseChatMessageHistory(ABC):
                    f.write("[]")
    """

-    messages: List[BaseMessage]
+    messages: list[BaseMessage]
    """A property or attribute that returns a list of messages.

    In general, getting the messages may involve IO to the underlying
@@ -95,7 +96,7 @@ class BaseChatMessageHistory(ABC):
    latency.
    """

-    async def aget_messages(self) -> List[BaseMessage]:
+    async def aget_messages(self) -> list[BaseMessage]:
        """Async version of getting messages.

        Can override this method to provide an efficient async implementation.
@@ -204,10 +205,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
    Stores messages in a memory list.
    """

-    messages: List[BaseMessage] = Field(default_factory=list)
+    messages: list[BaseMessage] = Field(default_factory=list)
    """A list of messages stored in memory."""

-    async def aget_messages(self) -> List[BaseMessage]:
+    async def aget_messages(self) -> list[BaseMessage]:
        """Async version of getting messages.

        Can override this method to provide an efficient async implementation.

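A quick sketch of the in-memory implementation in use:

from langchain_core.chat_history import InMemoryChatMessageHistory

history = InMemoryChatMessageHistory()
history.add_user_message("hi")   # convenience wrappers from the base class
history.add_ai_message("hello")
print([m.type for m in history.messages])  # ['human', 'ai']
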
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Iterator, List
+from collections.abc import Iterator

from langchain_core.chat_sessions import ChatSession

@@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
            An iterator of chat sessions.
        """

-    def load(self) -> List[ChatSession]:
+    def load(self) -> list[ChatSession]:
        """Eagerly load the chat sessions into memory.

        Returns:

@@ -1,6 +1,7 @@
"""**Chat Sessions** are a collection of messages and function calls."""

-from typing import Sequence, TypedDict
+from collections.abc import Sequence
+from typing import TypedDict

from langchain_core.messages import BaseMessage

@@ -3,7 +3,8 @@

from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
+from collections.abc import AsyncIterator, Iterator
+from typing import TYPE_CHECKING, Optional

from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor
@@ -25,17 +26,17 @@ class BaseLoader(ABC):  # noqa: B024

    # Sub-classes should not implement this method directly. Instead, they
    # should implement the lazy load method.
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
        """Load data into Document objects."""
        return list(self.lazy_load())

-    async def aload(self) -> List[Document]:
+    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]

    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
-    ) -> List[Document]:
+    ) -> list[Document]:
        """Load Documents and split into chunks. Chunks are returned as Documents.

        Do not override this method. It should be considered to be deprecated!
@@ -108,7 +109,7 @@ class BaseBlobParser(ABC):
            Generator of documents
        """

-    def parse(self, blob: Blob) -> List[Document]:
+    def parse(self, blob: Blob) -> list[Document]:
        """Eagerly parse the blob into a document or documents.

        This is a convenience method for interactive development environments.

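Because load() is just list(self.lazy_load()), a custom loader only needs the lazy method. A minimal sketch:

from collections.abc import Iterator

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

class TwoDocLoader(BaseLoader):
    """Toy loader: implement lazy_load and load()/aload() come for free."""

    def lazy_load(self) -> Iterator[Document]:
        yield Document(page_content="first")
        yield Document(page_content="second")

docs = TwoDocLoader().load()  # list(self.lazy_load()) under the hood
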
@@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Iterable
+from collections.abc import Iterable

# Re-export Blob and PathLike for backwards compatibility
from langchain_core.documents.base import Blob as Blob

@@ -1,7 +1,8 @@
import datetime
import json
import uuid
-from typing import Any, Callable, Iterator, Optional, Sequence, Union
+from collections.abc import Iterator, Sequence
+from typing import Any, Callable, Optional, Union

from langsmith import Client as LangSmithClient

@@ -2,9 +2,10 @@ from __future__ import annotations

import contextlib
import mimetypes
+from collections.abc import Generator
from io import BufferedReader, BytesIO
from pathlib import PurePath
-from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
+from typing import Any, Literal, Optional, Union, cast

from pydantic import ConfigDict, Field, field_validator, model_validator

@@ -138,7 +139,7 @@ class Blob(BaseMedia):

    @model_validator(mode="before")
    @classmethod
-    def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
+    def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
        """Verify that either data or path is provided."""
        if "data" not in values and "path" not in values:
            raise ValueError("Either data or path must be provided")
@@ -285,7 +286,7 @@ class Document(BaseMedia):
        return True

    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "document"]

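The validator above is why a Blob needs either data or a path. A small sketch of each construction route:

from langchain_core.documents.base import Blob

blob = Blob.from_data(b"hello world", mime_type="text/plain")
print(blob.as_string())  # "hello world"

# Blob(path=...) defers IO until the bytes are actually requested
lazy_blob = Blob.from_path("README.md")  # the path is illustrative
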
@@ -1,7 +1,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Optional, Sequence
+from collections.abc import Sequence
+from typing import Optional

from pydantic import BaseModel

@@ -1,7 +1,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any

from langchain_core.runnables.config import run_in_executor

@@ -1,7 +1,6 @@
"""**Embeddings** interface."""

from abc import ABC, abstractmethod
-from typing import List

from langchain_core.runnables.config import run_in_executor

@@ -35,7 +34,7 @@ class Embeddings(ABC):
    """

    @abstractmethod
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs.

        Args:
@@ -46,7 +45,7 @@ class Embeddings(ABC):
        """

    @abstractmethod
-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        Args:
@@ -56,7 +55,7 @@ class Embeddings(ABC):
            Embedding.
        """

-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronous Embed search docs.

        Args:
@@ -67,7 +66,7 @@ class Embeddings(ABC):
        """
        return await run_in_executor(None, self.embed_documents, texts)

-    async def aembed_query(self, text: str) -> List[float]:
+    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronous Embed query text.

        Args:

@@ -2,7 +2,6 @@

 # Please do not add additional fake embedding model implementations here.
 import hashlib
-from typing import List

 from pydantic import BaseModel

@@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
     size: int
     """The size of the embedding vector."""

-    def _get_embedding(self) -> List[float]:
+    def _get_embedding(self) -> list[float]:
         import numpy as np  # type: ignore[import-not-found, import-untyped]

         return list(np.random.normal(size=self.size))

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self._get_embedding() for _ in texts]

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         return self._get_embedding()


@@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
     size: int
     """The size of the embedding vector."""

-    def _get_embedding(self, seed: int) -> List[float]:
+    def _get_embedding(self, seed: int) -> list[float]:
         import numpy as np  # type: ignore[import-not-found, import-untyped]

         # set the seed for the random generator
@@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
         """Get a seed for the random generator, using the hash of the text."""
         return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         return self._get_embedding(seed=self._get_seed(text))

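The deterministic variant is handy in tests because equal inputs hash to equal seeds and therefore equal vectors. A quick sketch (numpy must be installed, since the methods above import it lazily):

    from langchain_core.embeddings import DeterministicFakeEmbedding

    emb = DeterministicFakeEmbedding(size=8)
    v1 = emb.embed_query("hello")
    v2 = emb.embed_query("hello")
    assert v1 == v2      # same text, same seed, same vector
    assert len(v1) == 8  # dimensionality follows `size`
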
@@ -1,7 +1,7 @@
 """Interface for selecting examples to include in prompts."""

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any

 from langchain_core.runnables import run_in_executor

@@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
     """Interface for selecting examples to include in prompts."""

     @abstractmethod
-    def add_example(self, example: Dict[str, str]) -> Any:
+    def add_example(self, example: dict[str, str]) -> Any:
         """Add new example to store.

         Args:
             example: A dictionary with keys as input variables
                 and values as their values."""

-    async def aadd_example(self, example: Dict[str, str]) -> Any:
+    async def aadd_example(self, example: dict[str, str]) -> Any:
         """Async add new example to store.

         Args:
@@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
         return await run_in_executor(None, self.add_example, example)

     @abstractmethod
-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select which examples to use based on the inputs.

         Args:
             input_variables: A dictionary with keys as input variables
                 and values as their values."""

-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Async select which examples to use based on the inputs.

         Args:

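A custom selector only needs the two sync abstract methods; the async ones default to run_in_executor as shown above. A minimal sketch (the first-two policy is a toy, chosen only to illustrate the interface):

    from typing import Any

    from langchain_core.example_selectors import BaseExampleSelector

    class FirstNExampleSelector(BaseExampleSelector):
        """Toy selector returning the first two stored examples."""

        def __init__(self) -> None:
            self.examples: list[dict] = []

        def add_example(self, example: dict[str, str]) -> Any:
            self.examples.append(example)

        def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
            # Ignore the inputs and return a fixed prefix of the stored examples
            return self.examples[:2]
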
@@ -1,7 +1,7 @@
 """Select examples based on length."""

 import re
-from typing import Callable, Dict, List
+from typing import Callable

 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self
@@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
 class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
     """Select examples based on length."""

-    examples: List[dict]
+    examples: list[dict]
     """A list of the examples that the prompt template expects."""

     example_prompt: PromptTemplate
@@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
     max_length: int = 2048
     """Max length for the prompt, beyond which examples are cut."""

-    example_text_lengths: List[int] = Field(default_factory=list)  # :meta private:
+    example_text_lengths: list[int] = Field(default_factory=list)  # :meta private:
     """Length of each example."""

-    def add_example(self, example: Dict[str, str]) -> None:
+    def add_example(self, example: dict[str, str]) -> None:
         """Add new example to list.

         Args:
@@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
         string_example = self.example_prompt.format(**example)
         self.example_text_lengths.append(self.get_text_length(string_example))

-    async def aadd_example(self, example: Dict[str, str]) -> None:
+    async def aadd_example(self, example: dict[str, str]) -> None:
         """Async add new example to list.

         Args:
@@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
         self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
         return self

-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select which examples to use based on the input lengths.

         Args:
@@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
             i += 1
         return examples

-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Async select which examples to use based on the input lengths.

         Args:

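A sketch of the length-based selector in use, assuming the public exports from langchain_core; the tiny max_length is deliberate so the truncation behavior is visible:

    from langchain_core.example_selectors import LengthBasedExampleSelector
    from langchain_core.prompts import PromptTemplate

    example_prompt = PromptTemplate.from_template("Input: {input}\nOutput: {output}")
    selector = LengthBasedExampleSelector(
        examples=[
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
        ],
        example_prompt=example_prompt,
        max_length=25,  # word budget; longer inputs leave room for fewer examples
    )
    print(selector.select_examples({"input": "big"}))
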
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional

 from pydantic import BaseModel, ConfigDict

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     from langchain_core.embeddings import Embeddings


-def sorted_values(values: Dict[str, str]) -> List[Any]:
+def sorted_values(values: dict[str, str]) -> list[Any]:
     """Return a list of values in dict sorted by key.

     Args:
@@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     """VectorStore that contains information about examples."""
     k: int = 4
     """Number of examples to select."""
-    example_keys: Optional[List[str]] = None
+    example_keys: Optional[list[str]] = None
     """Optional keys to filter examples to."""
-    input_keys: Optional[List[str]] = None
+    input_keys: Optional[list[str]] = None
     """Optional keys to filter input to. If provided, the search is based on
     the input variables instead of all variables."""
-    vectorstore_kwargs: Optional[Dict[str, Any]] = None
+    vectorstore_kwargs: Optional[dict[str, Any]] = None
     """Extra arguments passed to similarity_search function of the vectorstore."""

     model_config = ConfigDict(
@@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):

     @staticmethod
     def _example_to_text(
-        example: Dict[str, str], input_keys: Optional[List[str]]
+        example: dict[str, str], input_keys: Optional[list[str]]
     ) -> str:
         if input_keys:
             return " ".join(sorted_values({key: example[key] for key in input_keys}))
         else:
             return " ".join(sorted_values(example))

-    def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
+    def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
         # Get the examples from the metadata.
         # This assumes that examples are stored in metadata.
         examples = [dict(e.metadata) for e in documents]
@@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
             examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
         return examples

-    def add_example(self, example: Dict[str, str]) -> str:
+    def add_example(self, example: dict[str, str]) -> str:
         """Add a new example to vectorstore.

         Args:
@@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
         )
         return ids[0]

-    async def aadd_example(self, example: Dict[str, str]) -> str:
+    async def aadd_example(self, example: dict[str, str]) -> str:
         """Async add new example to vectorstore.

         Args:
@@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
 class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     """Select examples based on semantic similarity."""

-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select examples based on semantic similarity.

         Args:
@@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
         )
         return self._documents_to_examples(example_docs)

-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Asynchronously select examples based on semantic similarity.

         Args:
@@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     def from_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         *,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> SemanticSimilarityExampleSelector:
@@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     async def afrom_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         *,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> SemanticSimilarityExampleSelector:
@@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     fetch_k: int = 20
     """Number of examples to fetch to rerank."""

-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select examples based on Max Marginal Relevance.

         Args:
@@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
         )
         return self._documents_to_examples(example_docs)

-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Asynchronously select examples based on Max Marginal Relevance.

         Args:
@@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     def from_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         fetch_k: int = 20,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:
@@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     async def afrom_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         *,
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         fetch_k: int = 20,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:

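A sketch of the semantic selector built via from_examples, using the fake embedding and in-memory vector store from langchain_core so it runs without external services (the specific examples are illustrative):

    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.example_selectors import SemanticSimilarityExampleSelector
    from langchain_core.vectorstores import InMemoryVectorStore

    selector = SemanticSimilarityExampleSelector.from_examples(
        examples=[
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
        ],
        embeddings=DeterministicFakeEmbedding(size=16),
        vectorstore_cls=InMemoryVectorStore,
        k=1,  # return the single nearest example
    )
    print(selector.select_examples({"input": "cheerful"}))
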
@@ -3,7 +3,7 @@
 from typing import Any, Optional


-class LangChainException(Exception):
+class LangChainException(Exception):  # noqa: N818
     """General LangChain exception."""


@@ -11,7 +11,7 @@ class TracerException(LangChainException):
     """Base class for exceptions in tracers module."""


-class OutputParserException(ValueError, LangChainException):
+class OutputParserException(ValueError, LangChainException):  # noqa: N818
     """Exception that output parsers should raise to signify a parsing error.

     This exists to differentiate parsing errors from other code or execution errors

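The `# noqa: N818` markers suppress the lint rule that exception names should end in "Error" rather than renaming the classes, which would break imports. A sketch of the intended catch pattern for parser errors (parse_int is a hypothetical helper, not a library function):

    from langchain_core.exceptions import OutputParserException

    def parse_int(text: str) -> int:
        try:
            return int(text.strip())
        except ValueError as err:
            # Signal a parsing failure that retry/agent logic can handle separately
            raise OutputParserException(f"Could not parse an int from: {text!r}") from err

    try:
        parse_int("not a number")
    except OutputParserException as exc:
        print(f"parser failed: {exc}")
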
@@ -1,15 +0,0 @@
-from langchain_core.graph_vectorstores.base import (
-    GraphVectorStore,
-    GraphVectorStoreRetriever,
-    Node,
-)
-from langchain_core.graph_vectorstores.links import (
-    Link,
-)
-
-__all__ = [
-    "GraphVectorStore",
-    "GraphVectorStoreRetriever",
-    "Node",
-    "Link",
-]
@@ -1,712 +0,0 @@
-from __future__ import annotations
-
-from abc import abstractmethod
-from typing import (
-    Any,
-    AsyncIterable,
-    ClassVar,
-    Collection,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-)
-
-from pydantic import Field
-
-from langchain_core._api import beta
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
-from langchain_core.documents import Document
-from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
-from langchain_core.load import Serializable
-from langchain_core.runnables import run_in_executor
-from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
-
-
-def _has_next(iterator: Iterator) -> bool:
-    """Checks if the iterator has more elements.
-    Warning: consumes an element from the iterator"""
-    sentinel = object()
-    return next(iterator, sentinel) is not sentinel
-
-
-@beta()
-class Node(Serializable):
-    """Node in the GraphVectorStore.
-
-    Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
-
-    For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
-    would look like:
-
-    .. code-block:: python
-
-        [
-            Node(
-                id="a",
-                text="some text a",
-                links= [
-                    Link(kind="hyperlink", tag="https://some-url", direction="incoming")
-                ],
-            ),
-            Node(
-                id="b",
-                text="some text b",
-                links= [
-                    Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
-                ],
-            )
-        ]
-    """
-
-    id: Optional[str] = None
-    """Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
-    text: str
-    """Text contained by the node."""
-    metadata: dict = Field(default_factory=dict)
-    """Metadata for the node."""
-    links: List[Link] = Field(default_factory=list)
-    """Links associated with the node."""
-
-
-def _texts_to_nodes(
-    texts: Iterable[str],
-    metadatas: Optional[Iterable[dict]],
-    ids: Optional[Iterable[str]],
-) -> Iterator[Node]:
-    metadatas_it = iter(metadatas) if metadatas else None
-    ids_it = iter(ids) if ids else None
-    for text in texts:
-        try:
-            _metadata = next(metadatas_it).copy() if metadatas_it else {}
-        except StopIteration as e:
-            raise ValueError("texts iterable longer than metadatas") from e
-        try:
-            _id = next(ids_it) if ids_it else None
-        except StopIteration as e:
-            raise ValueError("texts iterable longer than ids") from e
-
-        links = _metadata.pop(METADATA_LINKS_KEY, [])
-        if not isinstance(links, list):
-            links = list(links)
-        yield Node(
-            id=_id,
-            metadata=_metadata,
-            text=text,
-            links=links,
-        )
-    if ids_it and _has_next(ids_it):
-        raise ValueError("ids iterable longer than texts")
-    if metadatas_it and _has_next(metadatas_it):
-        raise ValueError("metadatas iterable longer than texts")
-
-
-def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
-    for doc in documents:
-        metadata = doc.metadata.copy()
-        links = metadata.pop(METADATA_LINKS_KEY, [])
-        if not isinstance(links, list):
-            links = list(links)
-        yield Node(
-            id=doc.id,
-            metadata=metadata,
-            text=doc.page_content,
-            links=links,
-        )
-
-
-@beta()
-def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
-    """Convert nodes to documents.
-
-    Args:
-        nodes: The nodes to convert to documents.
-    Returns:
-        The documents generated from the nodes.
-    """
-    for node in nodes:
-        metadata = node.metadata.copy()
-        metadata[METADATA_LINKS_KEY] = [
-            # Convert the core `Link` (from the node) back to the local `Link`.
-            Link(kind=link.kind, direction=link.direction, tag=link.tag)
-            for link in node.links
-        ]
-
-        yield Document(
-            id=node.id,
-            page_content=node.text,
-            metadata=metadata,
-        )
-
-
-@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
-class GraphVectorStore(VectorStore):
-    """A hybrid vector-and-graph graph store.
-
-    Document chunks support vector-similarity search as well as edges linking
-    chunks based on structural and semantic properties.
-
-    .. versionadded:: 0.2.14
-    """
-
-    @abstractmethod
-    def add_nodes(
-        self,
-        nodes: Iterable[Node],
-        **kwargs: Any,
-    ) -> Iterable[str]:
-        """Add nodes to the graph store.
-
-        Args:
-            nodes: the nodes to add.
-        """
-
-    async def aadd_nodes(
-        self,
-        nodes: Iterable[Node],
-        **kwargs: Any,
-    ) -> AsyncIterable[str]:
-        """Add nodes to the graph store.
-
-        Args:
-            nodes: the nodes to add.
-        """
-        iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
-        done = object()
-        while True:
-            doc = await run_in_executor(None, next, iterator, done)
-            if doc is done:
-                break
-            yield doc  # type: ignore[misc]
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[Iterable[dict]] = None,
-        *,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        The Links present in the metadata field `links` will be extracted to create
-        the `Node` links.
-
-        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
-        function call would look like:
-
-        .. code-block:: python
-
-            store.add_texts(
-                ids=["a", "b"],
-                texts=["some text a", "some text b"],
-                metadatas=[
-                    {
-                        "links": [
-                            Link.incoming(kind="hyperlink", tag="https://some-url")
-                        ]
-                    },
-                    {
-                        "links": [
-                            Link.outgoing(kind="hyperlink", tag="https://some-url")
-                        ]
-                    },
-                ],
-            )
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-                The metadata key `links` shall be an iterable of
-                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
-            **kwargs: vectorstore specific parameters.
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        nodes = _texts_to_nodes(texts, metadatas, ids)
-        return list(self.add_nodes(nodes, **kwargs))
-
-    async def aadd_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[Iterable[dict]] = None,
-        *,
-        ids: Optional[Iterable[str]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        The Links present in the metadata field `links` will be extracted to create
-        the `Node` links.
-
-        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
-        function call would look like:
-
-        .. code-block:: python
-
-            await store.aadd_texts(
-                ids=["a", "b"],
-                texts=["some text a", "some text b"],
-                metadatas=[
-                    {
-                        "links": [
-                            Link.incoming(kind="hyperlink", tag="https://some-url")
-                        ]
-                    },
-                    {
-                        "links": [
-                            Link.outgoing(kind="hyperlink", tag="https://some-url")
-                        ]
-                    },
-                ],
-            )
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-                The metadata key `links` shall be an iterable of
-                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
-            **kwargs: vectorstore specific parameters.
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        nodes = _texts_to_nodes(texts, metadatas, ids)
-        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
-
-    def add_documents(
-        self,
-        documents: Iterable[Document],
-        **kwargs: Any,
-    ) -> List[str]:
-        """Run more documents through the embeddings and add to the vectorstore.
-
-        The Links present in the document metadata field `links` will be extracted to
-        create the `Node` links.
-
-        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
-        function call would look like:
-
-        .. code-block:: python
-
-            store.add_documents(
-                [
-                    Document(
-                        id="a",
-                        page_content="some text a",
-                        metadata={
-                            "links": [
-                                Link.incoming(kind="hyperlink", tag="http://some-url")
-                            ]
-                        }
-                    ),
-                    Document(
-                        id="b",
-                        page_content="some text b",
-                        metadata={
-                            "links": [
-                                Link.outgoing(kind="hyperlink", tag="http://some-url")
-                            ]
-                        }
-                    ),
-                ]
-
-            )
-
-        Args:
-            documents: Documents to add to the vectorstore.
-                The document's metadata key `links` shall be an iterable of
-                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
-
-        Returns:
-            List of IDs of the added texts.
-        """
-        nodes = _documents_to_nodes(documents)
-        return list(self.add_nodes(nodes, **kwargs))
-
-    async def aadd_documents(
-        self,
-        documents: Iterable[Document],
-        **kwargs: Any,
-    ) -> List[str]:
-        """Run more documents through the embeddings and add to the vectorstore.
-
-        The Links present in the document metadata field `links` will be extracted to
-        create the `Node` links.
-
-        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
-        function call would look like:
-
-        .. code-block:: python
-
-            store.add_documents(
-                [
-                    Document(
-                        id="a",
-                        page_content="some text a",
-                        metadata={
-                            "links": [
-                                Link.incoming(kind="hyperlink", tag="http://some-url")
-                            ]
-                        }
-                    ),
-                    Document(
-                        id="b",
-                        page_content="some text b",
-                        metadata={
-                            "links": [
-                                Link.outgoing(kind="hyperlink", tag="http://some-url")
-                            ]
-                        }
-                    ),
-                ]
-
-            )
-
-        Args:
-            documents: Documents to add to the vectorstore.
-                The document's metadata key `links` shall be an iterable of
-                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
-
-        Returns:
-            List of IDs of the added texts.
-        """
-        nodes = _documents_to_nodes(documents)
-        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
-
-    @abstractmethod
-    def traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 1,
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        """Retrieve documents from traversing this graph store.
-
-        First, `k` nodes are retrieved using a search for each `query` string.
-        Then, additional nodes are discovered up to the given `depth` from those
-        starting nodes.
-
-        Args:
-            query: The query string.
-            k: The number of Documents to return from the initial search.
-                Defaults to 4. Applies to each of the query strings.
-            depth: The maximum depth of edges to traverse. Defaults to 1.
-        Returns:
-            Retrieved documents.
-        """
-
-    async def atraversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 1,
-        **kwargs: Any,
-    ) -> AsyncIterable[Document]:
-        """Retrieve documents from traversing this graph store.
-
-        First, `k` nodes are retrieved using a search for each `query` string.
-        Then, additional nodes are discovered up to the given `depth` from those
-        starting nodes.
-
-        Args:
-            query: The query string.
-            k: The number of Documents to return from the initial search.
-                Defaults to 4. Applies to each of the query strings.
-            depth: The maximum depth of edges to traverse. Defaults to 1.
-        Returns:
-            Retrieved documents.
-        """
-        iterator = iter(
-            await run_in_executor(
-                None, self.traversal_search, query, k=k, depth=depth, **kwargs
-            )
-        )
-        done = object()
-        while True:
-            doc = await run_in_executor(None, next, iterator, done)
-            if doc is done:
-                break
-            yield doc  # type: ignore[misc]
-
-    @abstractmethod
-    def mmr_traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 2,
-        fetch_k: int = 100,
-        adjacent_k: int = 10,
-        lambda_mult: float = 0.5,
-        score_threshold: float = float("-inf"),
-        **kwargs: Any,
-    ) -> Iterable[Document]:
-        """Retrieve documents from this graph store using MMR-traversal.
-
-        This strategy first retrieves the top `fetch_k` results by similarity to
-        the question. It then selects the top `k` results based on
-        maximum-marginal relevance using the given `lambda_mult`.
-
-        At each step, it considers the (remaining) documents from `fetch_k` as
-        well as any documents connected by edges to a selected document
-        retrieved based on similarity (a "root").
-
-        Args:
-            query: The query string to search for.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch via similarity.
-                Defaults to 100.
-            adjacent_k: Number of adjacent Documents to fetch.
-                Defaults to 10.
-            depth: Maximum depth of a node (number of edges) from a node
-                retrieved via similarity. Defaults to 2.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                of diversity among the results with 0 corresponding to maximum
-                diversity and 1 to minimum diversity. Defaults to 0.5.
-            score_threshold: Only documents with a score greater than or equal
-                this threshold will be chosen. Defaults to negative infinity.
-        """
-
-    async def ammr_traversal_search(
-        self,
-        query: str,
-        *,
-        k: int = 4,
-        depth: int = 2,
-        fetch_k: int = 100,
-        adjacent_k: int = 10,
-        lambda_mult: float = 0.5,
-        score_threshold: float = float("-inf"),
-        **kwargs: Any,
-    ) -> AsyncIterable[Document]:
-        """Retrieve documents from this graph store using MMR-traversal.
-
-        This strategy first retrieves the top `fetch_k` results by similarity to
-        the question. It then selects the top `k` results based on
-        maximum-marginal relevance using the given `lambda_mult`.
-
-        At each step, it considers the (remaining) documents from `fetch_k` as
-        well as any documents connected by edges to a selected document
-        retrieved based on similarity (a "root").
-
-        Args:
-            query: The query string to search for.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch via similarity.
-                Defaults to 100.
-            adjacent_k: Number of adjacent Documents to fetch.
-                Defaults to 10.
-            depth: Maximum depth of a node (number of edges) from a node
-                retrieved via similarity. Defaults to 2.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                of diversity among the results with 0 corresponding to maximum
-                diversity and 1 to minimum diversity. Defaults to 0.5.
-            score_threshold: Only documents with a score greater than or equal
-                this threshold will be chosen. Defaults to negative infinity.
-        """
-        iterator = iter(
-            await run_in_executor(
-                None,
-                self.mmr_traversal_search,
-                query,
-                k=k,
-                fetch_k=fetch_k,
-                adjacent_k=adjacent_k,
-                depth=depth,
-                lambda_mult=lambda_mult,
-                score_threshold=score_threshold,
-                **kwargs,
-            )
-        )
-        done = object()
-        while True:
-            doc = await run_in_executor(None, next, iterator, done)
-            if doc is done:
-                break
-            yield doc  # type: ignore[misc]
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        return list(self.traversal_search(query, k=k, depth=0))
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> List[Document]:
-        return list(
-            self.mmr_traversal_search(
-                query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
-            )
-        )
-
-    async def asimilarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
-
-    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
-        if search_type == "similarity":
-            return self.similarity_search(query, **kwargs)
-        elif search_type == "similarity_score_threshold":
-            docs_and_similarities = self.similarity_search_with_relevance_scores(
-                query, **kwargs
-            )
-            return [doc for doc, _ in docs_and_similarities]
-        elif search_type == "mmr":
-            return self.max_marginal_relevance_search(query, **kwargs)
-        elif search_type == "traversal":
-            return list(self.traversal_search(query, **kwargs))
-        elif search_type == "mmr_traversal":
-            return list(self.mmr_traversal_search(query, **kwargs))
-        else:
-            raise ValueError(
-                f"search_type of {search_type} not allowed. Expected "
-                "search_type to be 'similarity', 'similarity_score_threshold', "
-                "'mmr' or 'traversal'."
-            )
-
-    async def asearch(
-        self, query: str, search_type: str, **kwargs: Any
-    ) -> List[Document]:
-        if search_type == "similarity":
-            return await self.asimilarity_search(query, **kwargs)
-        elif search_type == "similarity_score_threshold":
-            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
-                query, **kwargs
-            )
-            return [doc for doc, _ in docs_and_similarities]
-        elif search_type == "mmr":
-            return await self.amax_marginal_relevance_search(query, **kwargs)
-        elif search_type == "traversal":
-            return [doc async for doc in self.atraversal_search(query, **kwargs)]
-        else:
-            raise ValueError(
-                f"search_type of {search_type} not allowed. Expected "
-                "search_type to be 'similarity', 'similarity_score_threshold', "
-                "'mmr' or 'traversal'."
-            )
-
-    def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
-        """Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
-
-        Args:
-            **kwargs: Keyword arguments to pass to the search function.
-                Can include:
-
-                - search_type (Optional[str]): Defines the type of search that
-                    the Retriever should perform.
-                    Can be ``traversal`` (default), ``similarity``, ``mmr``, or
-                    ``similarity_score_threshold``.
-                - search_kwargs (Optional[Dict]): Keyword arguments to pass to the
-                    search function. Can include things like:
-
-                    - k(int): Amount of documents to return (Default: 4).
-                    - depth(int): The maximum depth of edges to traverse (Default: 1).
-                    - score_threshold(float): Minimum relevance threshold
-                        for similarity_score_threshold.
-                    - fetch_k(int): Amount of documents to pass to MMR algorithm
-                        (Default: 20).
-                    - lambda_mult(float): Diversity of results returned by MMR;
-                        1 for minimum diversity and 0 for maximum. (Default: 0.5).
-        Returns:
-            Retriever for this GraphVectorStore.
-
-        Examples:
-
-        .. code-block:: python
-
-            # Retrieve documents traversing edges
-            docsearch.as_retriever(
-                search_type="traversal",
-                search_kwargs={'k': 6, 'depth': 3}
-            )
-
-            # Retrieve more documents with higher diversity
-            # Useful if your dataset has many similar documents
-            docsearch.as_retriever(
-                search_type="mmr",
-                search_kwargs={'k': 6, 'lambda_mult': 0.25}
-            )
-
-            # Fetch more documents for the MMR algorithm to consider
-            # But only return the top 5
-            docsearch.as_retriever(
-                search_type="mmr",
-                search_kwargs={'k': 5, 'fetch_k': 50}
-            )
-
-            # Only retrieve documents that have a relevance score
-            # Above a certain threshold
-            docsearch.as_retriever(
-                search_type="similarity_score_threshold",
-                search_kwargs={'score_threshold': 0.8}
-            )
-
-            # Only get the single most similar document from the dataset
-            docsearch.as_retriever(search_kwargs={'k': 1})
-
-        """
-        return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
-
-
-class GraphVectorStoreRetriever(VectorStoreRetriever):
-    """Retriever class for GraphVectorStore."""
-
-    vectorstore: GraphVectorStore
-    """GraphVectorStore to use for retrieval."""
-    search_type: str = "traversal"
-    """Type of search to perform. Defaults to "traversal"."""
-    allowed_search_types: ClassVar[Collection[str]] = (
-        "similarity",
-        "similarity_score_threshold",
-        "mmr",
-        "traversal",
-        "mmr_traversal",
-    )
-
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        if self.search_type == "traversal":
-            return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
-        elif self.search_type == "mmr_traversal":
-            return list(
-                self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
-            )
-        else:
-            return super()._get_relevant_documents(query, run_manager=run_manager)
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        if self.search_type == "traversal":
-            return [
-                doc
-                async for doc in self.vectorstore.atraversal_search(
-                    query, **self.search_kwargs
-                )
-            ]
-        elif self.search_type == "mmr_traversal":
-            return [
-                doc
-                async for doc in self.vectorstore.ammr_traversal_search(
-                    query, **self.search_kwargs
-                )
-            ]
-        else:
-            return await super()._aget_relevant_documents(
-                query, run_manager=run_manager
-            )
@@ -1,101 +0,0 @@
-from dataclasses import dataclass
-from typing import Iterable, List, Literal, Union
-
-from langchain_core._api import beta
-from langchain_core.documents import Document
-
-
-@beta()
-@dataclass(frozen=True)
-class Link:
-    """A link to/from a tag of a given tag.
-
-    Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
-    """
-
-    kind: str
-    """The kind of link. Allows different extractors to use the same tag name without
-    creating collisions between extractors. For example “keyword” vs “url”."""
-    direction: Literal["in", "out", "bidir"]
-    """The direction of the link."""
-    tag: str
-    """The tag of the link."""
-
-    @staticmethod
-    def incoming(kind: str, tag: str) -> "Link":
-        """Create an incoming link."""
-        return Link(kind=kind, direction="in", tag=tag)
-
-    @staticmethod
-    def outgoing(kind: str, tag: str) -> "Link":
-        """Create an outgoing link."""
-        return Link(kind=kind, direction="out", tag=tag)
-
-    @staticmethod
-    def bidir(kind: str, tag: str) -> "Link":
-        """Create a bidirectional link."""
-        return Link(kind=kind, direction="bidir", tag=tag)
-
-
-METADATA_LINKS_KEY = "links"
-
-
-@beta()
-def get_links(doc: Document) -> List[Link]:
-    """Get the links from a document.
-
-    Args:
-        doc: The document to get the link tags from.
-    Returns:
-        The set of link tags from the document.
-    """
-
-    links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
-    if not isinstance(links, list):
-        # Convert to a list and remember that.
-        links = list(links)
-        doc.metadata[METADATA_LINKS_KEY] = links
-    return links
-
-
-@beta()
-def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
-    """Add links to the given metadata.
-
-    Args:
-        doc: The document to add the links to.
-        *links: The links to add to the document.
-    """
-    links_in_metadata = get_links(doc)
-    for link in links:
-        if isinstance(link, Iterable):
-            links_in_metadata.extend(link)
-        else:
-            links_in_metadata.append(link)
-
-
-@beta()
-def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
-    """Return a document with the given links added.
-
-    Args:
-        doc: The document to add the links to.
-        *links: The links to add to the document.
-
-    Returns:
-        A document with a shallow-copy of the metadata with the links added.
-    """
-    new_links = set(get_links(doc))
-    for link in links:
-        if isinstance(link, Iterable):
-            new_links.update(link)
-        else:
-            new_links.add(link)
-
-    return Document(
-        page_content=doc.page_content,
-        metadata={
-            **doc.metadata,
-            METADATA_LINKS_KEY: list(new_links),
-        },
-    )
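For reference, the links API deleted above was used roughly like this; since the graph_vectorstores package is removed in this release, the import only resolves on older langchain-core versions:

    from langchain_core.documents import Document
    from langchain_core.graph_vectorstores.links import Link, add_links, get_links

    doc = Document(page_content="some text a")
    add_links(doc, Link.incoming(kind="hyperlink", tag="https://some-url"))
    print(get_links(doc))  # links are stored under doc.metadata["links"]
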
@@ -5,20 +5,13 @@ from __future__ import annotations
 import hashlib
 import json
 import uuid
+from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
 from itertools import islice
 from typing import (
     Any,
-    AsyncIterable,
-    AsyncIterator,
     Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
     Literal,
     Optional,
-    Sequence,
-    Set,
     TypedDict,
     TypeVar,
     Union,
@@ -71,7 +64,7 @@ class _HashedDocument(Document):

     @model_validator(mode="before")
     @classmethod
-    def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
+    def calculate_hashes(cls, values: dict[str, Any]) -> Any:
         """Root validator to calculate content and metadata hash."""
         content = values.get("page_content", "")
         metadata = values.get("metadata", {})
@@ -125,7 +118,7 @@ class _HashedDocument(Document):
     )


-def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
+def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
     """Utility batching function."""
     it = iter(iterable)
     while True:
@@ -135,9 +128,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
         yield chunk


-async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
+async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
     """Utility batching function."""
-    batch: List[T] = []
+    batch: list[T] = []
     async for element in iterable:
         if len(batch) < size:
             batch.append(element)
@@ -171,7 +164,7 @@ def _deduplicate_in_order(
     hashed_documents: Iterable[_HashedDocument],
 ) -> Iterator[_HashedDocument]:
     """Deduplicate a list of hashed documents while preserving order."""
-    seen: Set[str] = set()
+    seen: set[str] = set()

     for hashed_doc in hashed_documents:
         if hashed_doc.hash_ not in seen:
@@ -349,7 +342,7 @@ def index(
     uids = []
     docs_to_index = []
     uids_to_refresh = []
-    seen_docs: Set[str] = set()
+    seen_docs: set[str] = set()
     for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
         if doc_exists:
             if force_update:
@@ -589,7 +582,7 @@ async def aindex(
     uids: list[str] = []
    docs_to_index: list[Document] = []
     uids_to_refresh = []
-    seen_docs: Set[str] = set()
+    seen_docs: set[str] = set()
     for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
         if doc_exists:
             if force_update:

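The private _batch helper above chunks an iterable into fixed-size lists via islice. A standalone sketch of the same pattern (batched is a hypothetical name, not a library export):

    from collections.abc import Iterable, Iterator
    from itertools import islice

    def batched(size: int, iterable: Iterable[int]) -> Iterator[list[int]]:
        # Emit lists of up to `size` items until the iterable is exhausted
        it = iter(iterable)
        while True:
            chunk = list(islice(it, size))
            if not chunk:
                return
            yield chunk

    print(list(batched(3, range(8))))  # [[0, 1, 2], [3, 4, 5], [6, 7]]
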
@@ -3,7 +3,8 @@ from __future__ import annotations
 import abc
 import time
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Sequence, TypedDict
+from collections.abc import Sequence
+from typing import Any, Optional, TypedDict

 from langchain_core._api import beta
 from langchain_core.documents import Document
@@ -144,7 +145,7 @@ class RecordManager(ABC):
         """

     @abstractmethod
-    def exists(self, keys: Sequence[str]) -> List[bool]:
+    def exists(self, keys: Sequence[str]) -> list[bool]:
         """Check if the provided keys exist in the database.

         Args:
@@ -155,7 +156,7 @@ class RecordManager(ABC):
         """

     @abstractmethod
-    async def aexists(self, keys: Sequence[str]) -> List[bool]:
+    async def aexists(self, keys: Sequence[str]) -> list[bool]:
         """Asynchronously check if the provided keys exist in the database.

         Args:
@@ -173,7 +174,7 @@ class RecordManager(ABC):
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
         limit: Optional[int] = None,
-    ) -> List[str]:
+    ) -> list[str]:
         """List records in the database based on the provided filters.

         Args:
@@ -194,7 +195,7 @@ class RecordManager(ABC):
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
         limit: Optional[int] = None,
-    ) -> List[str]:
+    ) -> list[str]:
         """Asynchronously list records in the database based on the provided filters.

         Args:
@@ -241,7 +242,7 @@ class InMemoryRecordManager(RecordManager):
         super().__init__(namespace)
         # Each key points to a dictionary
         # of {'group_id': group_id, 'updated_at': timestamp}
-        self.records: Dict[str, _Record] = {}
+        self.records: dict[str, _Record] = {}
         self.namespace = namespace

     def create_schema(self) -> None:
@@ -325,7 +326,7 @@ class InMemoryRecordManager(RecordManager):
         """
         self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

-    def exists(self, keys: Sequence[str]) -> List[bool]:
+    def exists(self, keys: Sequence[str]) -> list[bool]:
         """Check if the provided keys exist in the database.

         Args:
@@ -336,7 +337,7 @@ class InMemoryRecordManager(RecordManager):
         """
         return [key in self.records for key in keys]

-    async def aexists(self, keys: Sequence[str]) -> List[bool]:
+    async def aexists(self, keys: Sequence[str]) -> list[bool]:
         """Async check if the provided keys exist in the database.

         Args:
@@ -354,7 +355,7 @@ class InMemoryRecordManager(RecordManager):
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
         limit: Optional[int] = None,
-    ) -> List[str]:
+    ) -> list[str]:
         """List records in the database based on the provided filters.

         Args:
@@ -390,7 +391,7 @@ class InMemoryRecordManager(RecordManager):
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
         limit: Optional[int] = None,
-    ) -> List[str]:
+    ) -> list[str]:
         """Async list records in the database based on the provided filters.

         Args:
@@ -449,9 +450,9 @@ class UpsertResponse(TypedDict):
     indexed to avoid this issue.
     """

-    succeeded: List[str]
+    succeeded: list[str]
     """The IDs that were successfully indexed."""
-    failed: List[str]
+    failed: list[str]
     """The IDs that failed to index."""


@@ -562,7 +563,7 @@ class DocumentIndex(BaseRetriever):
         )

     @abc.abstractmethod
-    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
+    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
         """Delete by IDs or other criteria.

         Calling delete without any input parameters should raise a ValueError!
@@ -579,7 +580,7 @@ class DocumentIndex(BaseRetriever):
         """

     async def adelete(
-        self, ids: Optional[List[str]] = None, **kwargs: Any
+        self, ids: Optional[list[str]] = None, **kwargs: Any
     ) -> DeleteResponse:
         """Delete by IDs or other criteria. Async variant.

@@ -607,7 +608,7 @@ class DocumentIndex(BaseRetriever):
         ids: Sequence[str],
         /,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get documents by id.

         Fewer documents may be returned than requested if some IDs are not found or
@@ -633,7 +634,7 @@ class DocumentIndex(BaseRetriever):
         ids: Sequence[str],
         /,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get documents by id.

         Fewer documents may be returned than requested if some IDs are not found or

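A sketch of the in-memory record manager's exists/list contract shown above, assuming the public exports from langchain_core.indexing:

    from langchain_core.indexing import InMemoryRecordManager

    manager = InMemoryRecordManager(namespace="demo")
    manager.create_schema()
    manager.update(["doc-1", "doc-2"])            # record keys with an updated_at timestamp
    print(manager.exists(["doc-1", "doc-3"]))     # [True, False]
    print(manager.list_keys())                    # ["doc-1", "doc-2"]
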
@@ -1,5 +1,6 @@
 import uuid
-from typing import Any, Dict, List, Optional, Sequence, cast
+from collections.abc import Sequence
+from typing import Any, Optional, cast

 from pydantic import Field

@@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex):
     .. versionadded:: 0.2.29
     """

-    store: Dict[str, Document] = Field(default_factory=dict)
+    store: dict[str, Document] = Field(default_factory=dict)
     top_k: int = 4

     def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
@@ -43,7 +44,7 @@ class InMemoryDocumentIndex(DocumentIndex):

         return UpsertResponse(succeeded=ok_ids, failed=[])

-    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
+    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
         """Delete by ID."""
         if ids is None:
             raise ValueError("IDs must be provided for deletion")
@@ -59,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex):
             succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
         )

-    def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]:
+    def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
         """Get by ids."""
         found_documents = []

@@ -71,7 +72,7 @@ class InMemoryDocumentIndex(DocumentIndex):

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
+    ) -> list[Document]:
         counts_by_doc = []

         for document in self.store.values():

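A sketch of the document index in use, assuming InMemoryDocumentIndex is importable from langchain_core.indexing as in recent releases; note that get() silently skips unknown ids, per the docstring above:

    from langchain_core.documents import Document
    from langchain_core.indexing import InMemoryDocumentIndex

    doc_index = InMemoryDocumentIndex(top_k=2)
    doc_index.upsert([
        Document(id="1", page_content="apples and pears"),
        Document(id="2", page_content="a basket of apples"),
    ])
    print([d.id for d in doc_index.get(["1", "missing"])])  # ["1"]
    print(doc_index.invoke("apples"))                       # retriever interface: top_k matches
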
@@ -1,25 +1,20 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from functools import lru_cache
+from collections.abc import Mapping, Sequence
+from functools import cache
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Dict,
-    List,
     Literal,
-    Mapping,
     Optional,
-    Sequence,
-    Set,
-    Type,
     TypeVar,
     Union,
 )

 from pydantic import BaseModel, ConfigDict, Field, field_validator
-from typing_extensions import TypeAlias, TypedDict
+from typing_extensions import TypeAlias, TypedDict, override

 from langchain_core._api import deprecated
 from langchain_core.messages import (
@@ -51,11 +46,11 @@ class LangSmithParams(TypedDict, total=False):
     """Temperature for generation."""
     ls_max_tokens: Optional[int]
     """Max tokens for generation."""
-    ls_stop: Optional[List[str]]
+    ls_stop: Optional[list[str]]
     """Stop words for generation."""


-@lru_cache(maxsize=None)  # Cache the tokenizer
+@cache  # Cache the tokenizer
 def get_tokenizer() -> Any:
     """Get a GPT-2 tokenizer instance.

@@ -74,7 +69,7 @@ def get_tokenizer() -> Any:
     return GPT2TokenizerFast.from_pretrained("gpt2")


-def _get_token_ids_default_method(text: str) -> List[int]:
+def _get_token_ids_default_method(text: str) -> list[int]:
     """Encode the text into token IDs."""
     # get the cached tokenizer
     tokenizer = get_tokenizer()
@@ -117,11 +112,11 @@ class BaseLanguageModel(
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
     """Callbacks to add to the run trace."""
-    tags: Optional[List[str]] = Field(default=None, exclude=True)
+    tags: Optional[list[str]] = Field(default=None, exclude=True)
     """Tags to add to the run trace."""
-    metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
+    metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
     """Metadata to add to the run trace."""
-    custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
+    custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
         default=None, exclude=True
     )
     """Optional encoder to use for counting tokens."""
@@ -148,6 +143,7 @@ class BaseLanguageModel(
         return verbose

     @property
+    @override
     def InputType(self) -> TypeAlias:
         """Get the input type for this runnable."""
         from langchain_core.prompt_values import (
@@ -161,14 +157,14 @@ class BaseLanguageModel(
         return Union[
             str,
             Union[StringPromptValue, ChatPromptValueConcrete],
-            List[AnyMessage],
+            list[AnyMessage],
         ]

     @abstractmethod
     def generate_prompt(
         self,
-        prompts: List[PromptValue],
-        stop: Optional[List[str]] = None,
+        prompts: list[PromptValue],
+        stop: Optional[list[str]] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -202,8 +198,8 @@ class BaseLanguageModel(
     @abstractmethod
     async def agenerate_prompt(
         self,
-        prompts: List[PromptValue],
-        stop: Optional[List[str]] = None,
+        prompts: list[PromptValue],
+        stop: Optional[list[str]] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -235,8 +231,8 @@ class BaseLanguageModel(
         """

     def with_structured_output(
-        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
-    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        self, schema: Union[dict, type[BaseModel]], **kwargs: Any
+    ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
         """Not implemented on this class."""
         # Implement this on child class if there is a way of steering the model to
         # generate responses that match a given schema.
@@ -267,7 +263,7 @@ class BaseLanguageModel(
     @abstractmethod
     def predict_messages(
         self,
-        messages: List[BaseMessage],
+        messages: list[BaseMessage],
         *,
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
@@ -313,7 +309,7 @@ class BaseLanguageModel(
     @abstractmethod
     async def apredict_messages(
         self,
-        messages: List[BaseMessage],
+        messages: list[BaseMessage],
         *,
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
@@ -339,7 +335,7 @@ class BaseLanguageModel(
         """Get the identifying parameters."""
         return self.lc_attributes

-    def get_token_ids(self, text: str) -> List[int]:
+    def get_token_ids(self, text: str) -> list[int]:
         """Return the ordered ids of the tokens in a text.

         Args:
@@ -367,7 +363,7 @@ class BaseLanguageModel(
         """
         return len(self.get_token_ids(text))

-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
         """Get the number of tokens in the messages.

         Useful for checking if an input fits in a model's context window.
@@ -381,7 +377,7 @@ class BaseLanguageModel(
         return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])

     @classmethod
-    def _all_required_field_names(cls) -> Set:
+    def _all_required_field_names(cls) -> set:
         """DEPRECATED: Kept for backwards compatibility.

         Use get_pydantic_field_names.

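The decorator swap above is behavior-preserving: functools.cache (Python 3.9+) is defined as lru_cache(maxsize=None), i.e. unbounded memoization with less ceremony. A quick sketch of the equivalence:

    from functools import cache, lru_cache

    @cache  # unbounded memoization, same as @lru_cache(maxsize=None)
    def slow_square(n: int) -> int:
        return n * n

    @lru_cache(maxsize=None)
    def slow_square_old(n: int) -> int:
        return n * n

    assert slow_square(4) == slow_square_old(4) == 16
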
@@ -3,23 +3,19 @@ from __future__ import annotations
import asyncio
import inspect
import json
import typing
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
@@ -30,6 +26,7 @@ from pydantic import (
Field,
model_validator,
)
from typing_extensions import override

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@@ -223,7 +220,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.

Args:
@@ -255,6 +252,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Runnable methods ---

@property
@override
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return AnyMessage
@@ -277,7 +275,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -300,7 +298,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -356,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
@@ -426,7 +424,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
@@ -499,12 +497,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

# --- Custom methods ---

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
return {}

def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
@@ -513,7 +511,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def _get_ls_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@@ -550,7 +548,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

return ls_params

def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -567,12 +565,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
messages: list[list[BaseMessage]],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -658,12 +656,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
messages: list[list[BaseMessage]],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -777,8 +775,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -787,8 +785,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -799,8 +797,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -839,7 +837,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
chunks: list[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@@ -876,8 +874,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -916,7 +914,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
chunks: list[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@@ -954,8 +952,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -963,8 +961,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -980,8 +978,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -989,8 +987,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@@ -1017,8 +1015,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@@ -1032,8 +1030,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

async def _call_async(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@@ -1048,7 +1046,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
self, message: str, stop: Optional[list[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)

@@ -1069,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1099,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1115,7 +1113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""

def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -1123,18 +1121,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]],  # noqa: UP006
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()

def with_structured_output(
self,
schema: Union[Dict, Type],
schema: Union[typing.Dict, type],  # noqa: UP006
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]:  # noqa: UP006
"""Model wrapper that returns outputs formatted to match the given schema.

Args:
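For context, a typical call against a provider that implements this method might look like the following sketch; the `ChatOpenAI` import and the model name are illustrative, not part of this diff:

from pydantic import BaseModel

from langchain_openai import ChatOpenAI  # illustrative provider


class Joke(BaseModel):
    setup: str
    punchline: str


llm = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model name
structured_llm = llm.with_structured_output(Joke)
result = structured_llm.invoke("Tell me a joke about typing modules")
# result is a Joke instance (or a dict when a dict schema is passed)
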
@@ -1281,8 +1279,8 @@ class SimpleChatModel(BaseChatModel):

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -1294,8 +1292,8 @@ class SimpleChatModel(BaseChatModel):
@abstractmethod
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1303,8 +1301,8 @@ class SimpleChatModel(BaseChatModel):

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

@@ -1,6 +1,7 @@
import asyncio
import time
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@@ -14,7 +15,7 @@ from langchain_core.runnables import RunnableConfig
class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""

responses: List[str]
responses: list[str]
"""List of responses to return in order."""
# This parameter should be removed from FakeListLLM since
# it's only used by sub-classes.
@@ -37,7 +38,7 @@ class FakeListLLM(LLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -52,7 +53,7 @@ class FakeListLLM(LLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -90,7 +91,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
result = self.invoke(input, config)
@@ -110,7 +111,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
result = await self.ainvoke(input, config)

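A small sketch of how these fakes are typically used in tests, assuming the export from langchain_core.language_models:

from langchain_core.language_models import FakeListLLM

llm = FakeListLLM(responses=["first canned reply", "second canned reply"])
assert llm.invoke("anything") == "first canned reply"
assert llm.invoke("anything else") == "second canned reply"
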
@@ -3,7 +3,8 @@
import asyncio
import re
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@@ -17,7 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes."""

responses: List[BaseMessage]
responses: list[BaseMessage]
"""List of responses to **cycle** through in order."""
sleep: Optional[float] = None
"""Sleep time in seconds between responses."""
@@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel):

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -51,7 +52,7 @@ class FakeListChatModelError(Exception):
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""

responses: List[str]
responses: list[str]
"""List of responses to **cycle** through in order."""
sleep: Optional[float] = None
i: int = 0
@@ -65,8 +66,8 @@ class FakeListChatModel(SimpleChatModel):

def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -80,8 +81,8 @@ class FakeListChatModel(SimpleChatModel):

def _stream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -103,8 +104,8 @@ class FakeListChatModel(SimpleChatModel):

async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@@ -124,7 +125,7 @@ class FakeListChatModel(SimpleChatModel):
yield ChatGenerationChunk(message=AIMessageChunk(content=c))

@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}


@@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel):

def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -142,8 +143,8 @@ class FakeChatModel(SimpleChatModel):

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -157,7 +158,7 @@ class FakeChatModel(SimpleChatModel):
return "fake-chat-model"

@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
return {"key": "fake"}


@@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel):

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -202,8 +203,8 @@ class GenericFakeChatModel(BaseChatModel):

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -231,7 +232,7 @@ class GenericFakeChatModel(BaseChatModel):
# Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output.
assert isinstance(content, str)
content_chunks = cast(List[str], re.split(r"(\s)", content))
content_chunks = cast(list[str], re.split(r"(\s)", content))

for token in content_chunks:
chunk = ChatGenerationChunk(
@@ -249,7 +250,7 @@ class GenericFakeChatModel(BaseChatModel):
for fkey, fvalue in value.items():
if isinstance(fvalue, str):
# Break function call by `,`
fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
@@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel):

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

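A sketch of the streaming behavior described by the `re.split(r"(\s)", content)` logic above, assuming the langchain_core export:

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
chunks = [chunk.content for chunk in model.stream("hi")]
# The fake splits on whitespace with a capture group, so the
# whitespace itself is preserved as its own chunk:
assert chunks == ["hello", " ", "world"]
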
@@ -10,18 +10,12 @@ import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from pathlib import Path
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -37,6 +31,7 @@ from tenacity import (
stop_after_attempt,
wait_exponential,
)
from typing_extensions import override

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
@@ -76,7 +71,7 @@ def _log_error_once(msg: str) -> None:


def create_base_retry_decorator(
error_types: List[Type[BaseException]],
error_types: list[type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
@@ -153,10 +148,10 @@ def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:


def get_prompts(
params: Dict[str, Any],
prompts: List[str],
params: dict[str, Any],
prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached.

Args:
@@ -189,10 +184,10 @@ def get_prompts(


async def aget_prompts(
params: Dict[str, Any],
prompts: List[str],
params: dict[str, Any],
prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached. Async version.

Args:
@@ -225,11 +220,11 @@ async def aget_prompts(

def update_cache(
cache: Union[BaseCache, bool, None],
existing_prompts: Dict[int, List],
existing_prompts: dict[int, list],
llm_string: str,
missing_prompt_idxs: List[int],
missing_prompt_idxs: list[int],
new_results: LLMResult,
prompts: List[str],
prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output.

@@ -259,11 +254,11 @@ def update_cache(

async def aupdate_cache(
cache: Union[BaseCache, bool, None],
existing_prompts: Dict[int, List],
existing_prompts: dict[int, list],
llm_string: str,
missing_prompt_idxs: List[int],
missing_prompt_idxs: list[int],
new_results: LLMResult,
prompts: List[str],
prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version.

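These helpers back the optional LLM cache; a sketch of the behavior they enable, using the simplest built-in cache (a minimal sketch, not part of this diff):

from langchain_core.caches import InMemoryCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListLLM

set_llm_cache(InMemoryCache())
llm = FakeListLLM(responses=["cached answer"], cache=True)

first = llm.invoke("same prompt")   # miss: get_prompts finds nothing, update_cache stores the result
second = llm.invoke("same prompt")  # hit: served from the cache, no new generation
assert first == second
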
@@ -306,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
@@ -324,7 +319,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods ---

@property
def OutputType(self) -> Type[str]:
@override
def OutputType(self) -> type[str]:
"""Get the input type for this runnable."""
return str

@@ -343,7 +339,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):

def _get_ls_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@@ -383,7 +379,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@@ -407,7 +403,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@@ -425,12 +421,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):

def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
if not inputs:
return []

@@ -450,7 +446,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
return cast(list[str], [e for _ in inputs])
else:
raise e
else:
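batch resolves to generate under the hood, returning one string per input in order; a sketch of that contract using the fake LLM from earlier in this diff:

from langchain_core.language_models import FakeListLLM

llm = FakeListLLM(responses=["one", "two"])
outputs = llm.batch(["prompt a", "prompt b"])
assert outputs == ["one", "two"]  # one string per input, in order
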
@@ -472,12 +468,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):

async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
if not inputs:
return []
config = get_config_list(config, len(inputs))
@@ -496,7 +492,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
return cast(list[str], [e for _ in inputs])
else:
raise e
else:
@@ -521,7 +517,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
@@ -583,7 +579,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if (
@@ -649,8 +645,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@abstractmethod
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -658,8 +654,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):

async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -676,7 +672,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@@ -704,7 +700,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
@@ -747,9 +743,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):

def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@@ -757,9 +753,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):

async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@@ -769,9 +765,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):

def _generate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[CallbackManagerForLLMRun],
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@@ -802,14 +798,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):

def generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
tags: Optional[Union[list[str], list[list[str]]]] = None,
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
run_name: Optional[Union[str, list[str]]] = None,
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to a model and return generations.
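Unlike invoke, generate returns the full LLMResult, so per-prompt generations and LLM output metadata stay accessible; a sketch of reading it back:

from langchain_core.language_models import FakeListLLM

llm = FakeListLLM(responses=["alpha", "beta"])
result = llm.generate(["p1", "p2"])
texts = [gen[0].text for gen in result.generations]  # first generation per prompt
assert texts == ["alpha", "beta"]
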
@@ -885,13 +881,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
list[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
CallbackManager.configure(
@@ -912,9 +908,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
cast(list[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
cast(dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
@@ -987,7 +983,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):

@staticmethod
def _get_run_ids_list(
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
@@ -1002,9 +998,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):

async def _agenerate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[AsyncCallbackManagerForLLMRun],
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@@ -1044,14 +1040,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):

async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
tags: Optional[Union[list[str], list[list[str]]]] = None,
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
run_name: Optional[Union[str, list[str]]] = None,
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@@ -1118,13 +1114,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
list[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
AsyncCallbackManager.configure(
@@ -1145,9 +1141,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
cast(list[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
cast(dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
@@ -1239,11 +1235,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def __call__(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input.
@@ -1287,11 +1283,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _call_async(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
@@ -1318,7 +1314,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1344,7 +1340,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1367,7 +1363,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _llm_type(self) -> str:
"""Return type of llm."""

def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -1443,7 +1439,7 @@ class LLM(BaseLLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1467,7 +1463,7 @@ class LLM(BaseLLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1500,8 +1496,8 @@ class LLM(BaseLLM):

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -1520,8 +1516,8 @@ class LLM(BaseLLM):

async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:

@@ -1,7 +1,7 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from langchain_core._api import beta
from langchain_core.load.mapping import (
@@ -19,6 +19,17 @@ DEFAULT_NAMESPACES = [
"langchain_anthropic",
"langchain_groq",
"langchain_google_genai",
"langchain_aws",
"langchain_openai",
"langchain_google_vertexai",
"langchain_mistralai",
"langchain_fireworks",
]
# Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed.
# Load by path is not allowed.
DISALLOW_LOAD_FROM_PATH = [
"langchain_community",
"langchain",
]

ALL_SERIALIZABLE_MAPPINGS = {
@@ -34,11 +45,11 @@ class Reviver:

def __init__(
self,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[
Dict[Tuple[str, ...], Tuple[str, ...]]
dict[tuple[str, ...], tuple[str, ...]]
] = None,
) -> None:
"""Initialize the reviver.
@@ -73,7 +84,7 @@ class Reviver:
else ALL_SERIALIZABLE_MAPPINGS
)

def __call__(self, value: Dict[str, Any]) -> Any:
def __call__(self, value: dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
@@ -103,40 +114,31 @@ class Reviver:
and value.get("id", None) is not None
):
[*namespace, name] = value["id"]
mapping_key = tuple(value["id"])

if namespace[0] not in self.valid_namespaces:
raise ValueError(f"Invalid namespace: {value}")

# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1 and namespace[0] == "langchain":
# The root namespace ["langchain"] is not a valid identifier.
elif namespace == ["langchain"]:
raise ValueError(f"Invalid namespace: {value}")

# If namespace is in known namespaces, try to use mapping
key = tuple(namespace + [name])
if namespace[0] in DEFAULT_NAMESPACES:
# Get the importable path
if key not in self.import_mappings:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = self.import_mappings[key]
# Has explicit import path.
elif mapping_key in self.import_mappings:
import_path = self.import_mappings[mapping_key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
import_dir, name = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
# Otherwise, load by path
elif namespace[0] in DISALLOW_LOAD_FROM_PATH:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{mapping_key}."
)
# Otherwise, treat namespace as path.
else:
if key in self.additional_import_mappings:
import_path = self.import_mappings[key]
mod = importlib.import_module(".".join(import_path[:-1]))
name = import_path[-1]
else:
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
mod = importlib.import_module(".".join(namespace))

cls = getattr(mod, name)

# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
@@ -154,10 +156,10 @@ class Reviver:
def loads(
text: str,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
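A sketch of the round trip these functions support (dumps is the companion serializer exported alongside loads from langchain_core.load):

from langchain_core.load import dumps, loads
from langchain_core.messages import AIMessage

msg = AIMessage(content="serialize me")
text = dumps(msg)      # JSON string with "lc", "type", and "id" fields
revived = loads(text)  # the Reviver maps the "id" path back to AIMessage
assert revived == msg
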
@@ -190,10 +192,10 @@ def loads(
def load(
obj: Any,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.

@@ -18,11 +18,9 @@ The mapping allows us to deserialize an AIMessage created with an older
version of LangChain where the code was in a different location.
"""

from typing import Dict, Tuple

# First value is the value that it is serialized as
# Second value is the path to load it from
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "messages", "AIMessage"): (
"langchain_core",
"messages",
@@ -535,7 +533,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {

# Needed for backwards compatibility for old versions of LangChain where things
# Were in different place
_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
_OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "AIMessage"): (
"langchain_core",
"messages",
@@ -583,7 +581,7 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {

# Needed for backwards compatibility for a few versions where we serialized
# with langchain_core paths.
OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "ai", "AIMessage"): (
"langchain_core",
"messages",
@@ -937,7 +935,7 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
),
}

_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
_JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "AIMessage"): (
"langchain_core",
"messages",
@@ -1028,4 +1026,9 @@ _JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"image",
"ImagePromptTemplate",
),
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
"langchain_aws",
"chat_models",
"ChatBedrock",
),
}

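Each entry keys a serialized id to the import path where the class lives today; a sketch of a lookup, with the expected value taken from the full mapping in the source:

from langchain_core.load.mapping import SERIALIZABLE_MAPPING

# Old serialized id -> current import path.
path = SERIALIZABLE_MAPPING[("langchain", "schema", "messages", "AIMessage")]
assert path == ("langchain_core", "messages", "ai", "AIMessage")
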
@@ -1,8 +1,6 @@
from abc import ABC
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TypedDict,
@@ -25,9 +23,9 @@ class BaseSerialized(TypedDict):
"""

lc: int
id: List[str]
id: list[str]
name: NotRequired[str]
graph: NotRequired[Dict[str, Any]]
graph: NotRequired[dict[str, Any]]


class SerializedConstructor(BaseSerialized):
@@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized):
"""

type: Literal["constructor"]
kwargs: Dict[str, Any]
kwargs: dict[str, Any]


class SerializedSecret(BaseSerialized):
@@ -125,7 +123,7 @@ class Serializable(BaseModel, ABC):
return False

@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.

For example, if the class is `langchain.llms.openai.OpenAI`, then the
@@ -134,7 +132,7 @@ class Serializable(BaseModel, ABC):
return cls.__module__.split(".")

@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids.

For example,
@@ -143,7 +141,7 @@ class Serializable(BaseModel, ABC):
return dict()

@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""List of attribute names that should be included in the serialized kwargs.

These attributes must be accepted by the constructor.
@@ -152,7 +150,7 @@ class Serializable(BaseModel, ABC):
return {}

@classmethod
def lc_id(cls) -> List[str]:
def lc_id(cls) -> list[str]:
"""A unique identifier for this class for serialization purposes.

The unique identifier is a list of strings that describes the path
@@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:


def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
root: dict[Any, Any], secrets_map: dict[str, str]
) -> dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
@@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
_id: list[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]

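A sketch of what a serializable class provides; the class and env var names here are illustrative, not from this diff:

from pydantic import SecretStr

from langchain_core.load.serializable import Serializable


class MyClient(Serializable):  # hypothetical example class
    api_key: SecretStr
    endpoint: str = "https://example.invalid"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> dict[str, str]:
        # Constructor arg -> secret id used to rehydrate it on load.
        return {"api_key": "MY_CLIENT_API_KEY"}
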
@@ -11,7 +11,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any

from pydantic import ConfigDict

@@ -55,11 +55,11 @@ class BaseMemory(Serializable, ABC):

@property
@abstractmethod
def memory_variables(self) -> List[str]:
def memory_variables(self) -> list[str]:
"""The string keys this memory class will add to chain inputs."""

@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return key-value pairs given the text input to the chain.

Args:
@@ -69,7 +69,7 @@ class BaseMemory(Serializable, ABC):
A dictionary of key-value pairs.
"""

async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Async return key-value pairs given the text input to the chain.

Args:
@@ -81,7 +81,7 @@ class BaseMemory(Serializable, ABC):
return await run_in_executor(None, self.load_memory_variables, inputs)

@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
"""Save the context of this chain run to memory.

Args:
@@ -90,7 +90,7 @@ class BaseMemory(Serializable, ABC):
"""

async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
self, inputs: dict[str, Any], outputs: dict[str, str]
) -> None:
"""Async save the context of this chain run to memory.


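A minimal concrete implementation of this interface, for illustration only (the class name is hypothetical):

from typing import Any

from langchain_core.memory import BaseMemory


class DictMemory(BaseMemory):  # hypothetical example class
    store: dict[str, Any] = {}

    @property
    def memory_variables(self) -> list[str]:
        return list(self.store)

    def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
        return dict(self.store)

    def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
        self.store.update(outputs)

    def clear(self) -> None:
        self.store.clear()
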
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union

from pydantic import model_validator
from typing_extensions import Self, TypedDict
@@ -69,9 +69,9 @@ class AIMessage(BaseMessage):
At the moment, this is ignored by most models. Usage is discouraged.
"""

tool_calls: List[ToolCall] = []
tool_calls: list[ToolCall] = []
"""If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message."""
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts.
@@ -83,7 +83,7 @@ class AIMessage(BaseMessage):
"""The type of the message (used for deserialization). Defaults to "ai"."""

def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.

@@ -94,7 +94,7 @@ class AIMessage(BaseMessage):
super().__init__(content=content, **kwargs)

@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.

Returns:
@@ -104,7 +104,7 @@ class AIMessage(BaseMessage):
return ["langchain", "schema", "messages"]

@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
@@ -137,7 +137,7 @@ class AIMessage(BaseMessage):

# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: List = []
updated: list = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
@@ -178,7 +178,7 @@ class AIMessage(BaseMessage):
base = super().pretty_repr(html=html)
lines = []

def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
lines = [
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
f" Call ID: {tc.get('id')}",
@@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"""The type of the message (used for deserialization).
Defaults to "AIMessageChunk"."""

tool_call_chunks: List[ToolCallChunk] = []
tool_call_chunks: list[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""

@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.

Returns:
@@ -232,7 +232,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
return ["langchain", "schema", "messages"]

@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,

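A sketch of constructing an AIMessage carrying a tool call; the validator shown in this hunk fills in the "type" key on each dict (the tool name and arguments below are illustrative):

from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
)
assert msg.tool_calls[0]["name"] == "get_weather"
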
Some files were not shown because too many files have changed in this diff.