mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-09 02:33:34 +00:00
Compare commits
74 Commits
isaac/tool
...
langchain-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c453b76579 | ||
|
|
f087ab43fd | ||
|
|
409f35363b | ||
|
|
e8236e58f2 | ||
|
|
eef18dec44 | ||
|
|
311f861547 | ||
|
|
c77c28e631 | ||
|
|
7d49ee9741 | ||
|
|
28dd6564db | ||
|
|
f91bdd12d2 | ||
|
|
4d3d62c249 | ||
|
|
60dc19da30 | ||
|
|
55b641b761 | ||
|
|
37b72023fe | ||
|
|
3fc0ea510e | ||
|
|
a8561bc303 | ||
|
|
4e0a6ebe7d | ||
|
|
fd21ffe293 | ||
|
|
7835c0651f | ||
|
|
85caaa773f | ||
|
|
8fb643a6e8 | ||
|
|
03b9aca55d | ||
|
|
acbb4e4701 | ||
|
|
e0c36afc3e | ||
|
|
9909354cd0 | ||
|
|
84b831356c | ||
|
|
a47b332841 | ||
|
|
0f07cf61da | ||
|
|
d158401e73 | ||
|
|
de58942618 | ||
|
|
df38d5250f | ||
|
|
b246052184 | ||
|
|
52729ac0be | ||
|
|
f62d454f36 | ||
|
|
6fe2536c5a | ||
|
|
418b170f94 | ||
|
|
c3b3f46cb8 | ||
|
|
e2245fac82 | ||
|
|
1a8e9023de | ||
|
|
1a62f9850f | ||
|
|
6ed50e78c9 | ||
|
|
5ced41bf50 | ||
|
|
c6bdd6f482 | ||
|
|
3a99467ccb | ||
|
|
2ef4c9466f | ||
|
|
194adc485c | ||
|
|
97b05d70e6 | ||
|
|
e1d113ea84 | ||
|
|
7c05f71e0f | ||
|
|
145a49cca2 | ||
|
|
5fc44989bf | ||
|
|
f4a65236ee | ||
|
|
06cde06a20 | ||
|
|
3e51fdc840 | ||
|
|
0a177ec2cc | ||
|
|
6758894af1 | ||
|
|
6ba3c715b7 | ||
|
|
d8952b8e8c | ||
|
|
31f61d4d7d | ||
|
|
99abd254fb | ||
|
|
3bcd641bc1 | ||
|
|
0bd98c99b3 | ||
|
|
8a2f2fc30b | ||
|
|
724a53711b | ||
|
|
c6a78132d6 | ||
|
|
a319a0ff1d | ||
|
|
63c3cc1f1f | ||
|
|
0154c586d3 | ||
|
|
c2588b334f | ||
|
|
8b985a42e9 | ||
|
|
5b4206acd8 | ||
|
|
0592c29e9b | ||
|
|
88891477eb | ||
|
|
88bc15d69b |
26
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
26
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -96,25 +96,21 @@ body:
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us.
|
||||
Please share your system info with us. Do NOT skip this step and please don't trim
|
||||
the output. Most users don't include enough information here and it makes it harder
|
||||
for us to help you.
|
||||
|
||||
"pip freeze | grep langchain"
|
||||
platform (windows / linux / mac)
|
||||
python version
|
||||
|
||||
OR if you're on a recent version of langchain-core you can paste the output of:
|
||||
Run the following command in your terminal and paste the output here:
|
||||
|
||||
python -m langchain_core.sys_info
|
||||
|
||||
or if you have an existing python interpreter running:
|
||||
|
||||
from langchain_core import sys_info
|
||||
sys_info.print_sys_info()
|
||||
|
||||
alternatively, put the entire output of `pip freeze` here.
|
||||
placeholder: |
|
||||
"pip freeze | grep langchain"
|
||||
platform
|
||||
python version
|
||||
|
||||
Alternatively, if you're on a recent version of langchain-core you can paste the output of:
|
||||
|
||||
python -m langchain_core.sys_info
|
||||
|
||||
These will only surface LangChain packages, don't forget to include any other relevant
|
||||
packages you're using (if you're not sure what's relevant, you can paste the entire output of `pip freeze`).
|
||||
validations:
|
||||
required: true
|
||||
|
||||
12
.github/scripts/get_min_versions.py
vendored
12
.github/scripts/get_min_versions.py
vendored
@@ -21,7 +21,14 @@ MIN_VERSION_LIBS = [
|
||||
"SQLAlchemy",
|
||||
]
|
||||
|
||||
SKIP_IF_PULL_REQUEST = ["langchain-core"]
|
||||
# some libs only get checked on release because of simultaneous changes in
|
||||
# multiple libs
|
||||
SKIP_IF_PULL_REQUEST = [
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
"langchain",
|
||||
"langchain-community",
|
||||
]
|
||||
|
||||
|
||||
def get_min_version(version: str) -> str:
|
||||
@@ -70,7 +77,7 @@ def get_min_version_from_toml(
|
||||
for lib in set(MIN_VERSION_LIBS + (include or [])):
|
||||
if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST:
|
||||
# some libs only get checked on release because of simultaneous
|
||||
# changes
|
||||
# changes in multiple libs
|
||||
continue
|
||||
# Check if the lib is present in the dependencies
|
||||
if lib in dependencies:
|
||||
@@ -88,7 +95,6 @@ def get_min_version_from_toml(
|
||||
if check_python_version(python_version, vs["python"])
|
||||
][0]["version"]
|
||||
|
||||
|
||||
# Use parse_version to get the minimum supported version from version_string
|
||||
min_version = get_min_version(version_string)
|
||||
|
||||
|
||||
2
.github/workflows/_release.yml
vendored
2
.github/workflows/_release.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
path: langchain
|
||||
sparse-checkout: | # this only grabs files for relevant dir
|
||||
${{ inputs.working-directory }}
|
||||
ref: master # this scopes to just master branch
|
||||
ref: ${{ github.ref }} # this scopes to just ref'd branch
|
||||
fetch-depth: 0 # this fetches entire commit history
|
||||
- name: Check Tags
|
||||
id: check-tags
|
||||
|
||||
2
.github/workflows/_test.yml
vendored
2
.github/workflows/_test.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
||||
env:
|
||||
MIN_VERSIONS: ${{ steps.min-version.outputs.min-versions }}
|
||||
run: |
|
||||
poetry run pip install --force-reinstall $MIN_VERSIONS --editable .
|
||||
poetry run pip install $MIN_VERSIONS
|
||||
make tests
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
:member-order: groupwise
|
||||
:show-inheritance: True
|
||||
:special-members: __call__
|
||||
:exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, with_listeners, with_alisteners, with_config, with_fallbacks, with_types, with_retry, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, bind, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema
|
||||
:exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema, predict, apredict, predict_messages, apredict_messages, generate, generate_prompt, agenerate, agenerate_prompt, call_as_llm
|
||||
|
||||
.. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>`. 🏃
|
||||
|
||||
|
||||
@@ -206,7 +206,7 @@
|
||||
" ) -> List[Document]:\n",
|
||||
" \"\"\"Get docs, adding score information.\"\"\"\n",
|
||||
" docs, scores = zip(\n",
|
||||
" *vectorstore.similarity_search_with_score(query, **search_kwargs)\n",
|
||||
" *self.vectorstore.similarity_search_with_score(query, **search_kwargs)\n",
|
||||
" )\n",
|
||||
" for doc, score in zip(docs, scores):\n",
|
||||
" doc.metadata[\"score\"] = score\n",
|
||||
|
||||
@@ -15,43 +15,15 @@
|
||||
"\n",
|
||||
"Make sure you have the integration packages installed for any model providers you want to support. E.g. you should have `langchain-openai` installed to init an OpenAI model.\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
":::info Requires ``langchain >= 0.2.8``\n",
|
||||
"\n",
|
||||
"This functionality was added in ``langchain-core == 0.2.8``. Please make sure your package is up to date.\n",
|
||||
"\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "165b0de6-9ae3-4e3d-aa98-4fc8a97c4a06",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-09-10T20:22:32.858670Z",
|
||||
"iopub.status.busy": "2024-09-10T20:22:32.858278Z",
|
||||
"iopub.status.idle": "2024-09-10T20:22:33.009452Z",
|
||||
"shell.execute_reply": "2024-09-10T20:22:33.007022Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"zsh:1: 0.2.8 not found\r\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain>=0.2.8 langchain-openai langchain-anthropic langchain-google-vertexai"
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
"\n",
|
||||
"This sample demonstrates the use of `Amazon Textract` in combination with LangChain as a DocumentLoader.\n",
|
||||
"\n",
|
||||
"`Textract` supports`PDF`, `TIF`F, `PNG` and `JPEG` format.\n",
|
||||
"`Textract` supports`PDF`, `TIFF`, `PNG` and `JPEG` format.\n",
|
||||
"\n",
|
||||
"`Textract` supports these [document sizes, languages and characters](https://docs.aws.amazon.com/textract/latest/dg/limits-document.html)."
|
||||
]
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"source": [
|
||||
"# Google Speech-to-Text Audio Transcripts\n",
|
||||
"\n",
|
||||
"The `GoogleSpeechToTextLoader` allows to transcribe audio files with the [Google Cloud Speech-to-Text API](https://cloud.google.com/speech-to-text) and loads the transcribed text into documents.\n",
|
||||
"The `SpeechToTextLoader` allows to transcribe audio files with the [Google Cloud Speech-to-Text API](https://cloud.google.com/speech-to-text) and loads the transcribed text into documents.\n",
|
||||
"\n",
|
||||
"To use it, you should have the `google-cloud-speech` python package installed, and a Google Cloud project with the [Speech-to-Text API enabled](https://cloud.google.com/speech-to-text/v2/docs/transcribe-client-libraries#before_you_begin).\n",
|
||||
"\n",
|
||||
@@ -41,7 +41,7 @@
|
||||
"source": [
|
||||
"## Example\n",
|
||||
"\n",
|
||||
"The `GoogleSpeechToTextLoader` must include the `project_id` and `file_path` arguments. Audio files can be specified as a Google Cloud Storage URI (`gs://...`) or a local file path.\n",
|
||||
"The `SpeechToTextLoader` must include the `project_id` and `file_path` arguments. Audio files can be specified as a Google Cloud Storage URI (`gs://...`) or a local file path.\n",
|
||||
"\n",
|
||||
"Only synchronous requests are supported by the loader, which has a [limit of 60 seconds or 10MB](https://cloud.google.com/speech-to-text/v2/docs/sync-recognize#:~:text=60%20seconds%20and/or%2010%20MB) per audio file."
|
||||
]
|
||||
@@ -52,13 +52,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_google_community import GoogleSpeechToTextLoader\n",
|
||||
"from langchain_google_community import SpeechToTextLoader\n",
|
||||
"\n",
|
||||
"project_id = \"<PROJECT_ID>\"\n",
|
||||
"file_path = \"gs://cloud-samples-data/speech/audio.flac\"\n",
|
||||
"# or a local file path: file_path = \"./audio.wav\"\n",
|
||||
"\n",
|
||||
"loader = GoogleSpeechToTextLoader(project_id=project_id, file_path=file_path)\n",
|
||||
"loader = SpeechToTextLoader(project_id=project_id, file_path=file_path)\n",
|
||||
"\n",
|
||||
"docs = loader.load()"
|
||||
]
|
||||
@@ -152,7 +152,7 @@
|
||||
" RecognitionConfig,\n",
|
||||
" RecognitionFeatures,\n",
|
||||
")\n",
|
||||
"from langchain_google_community import GoogleSpeechToTextLoader\n",
|
||||
"from langchain_google_community import SpeechToTextLoader\n",
|
||||
"\n",
|
||||
"project_id = \"<PROJECT_ID>\"\n",
|
||||
"location = \"global\"\n",
|
||||
@@ -171,7 +171,7 @@
|
||||
" ),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"loader = GoogleSpeechToTextLoader(\n",
|
||||
"loader = SpeechToTextLoader(\n",
|
||||
" project_id=project_id,\n",
|
||||
" location=location,\n",
|
||||
" recognizer_id=recognizer_id,\n",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/document_loaders/file_loaders/unstructured/)|\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: |\n",
|
||||
"| [UnstructuredLoader](https://python.langchain.com/api_reference/unstructured/document_loaders/langchain_unstructured.document_loaders.UnstructuredLoader.html) | [langchain_community](https://python.langchain.com/api_reference/unstructured/index.html) | ✅ | ❌ | ✅ | \n",
|
||||
"| [UnstructuredLoader](https://python.langchain.com/api_reference/unstructured/document_loaders/langchain_unstructured.document_loaders.UnstructuredLoader.html) | [langchain_unstructured](https://python.langchain.com/api_reference/unstructured/index.html) | ✅ | ❌ | ✅ | \n",
|
||||
"### Loader features\n",
|
||||
"| Source | Document Lazy Loading | Native Async Support\n",
|
||||
"| :---: | :---: | :---: | \n",
|
||||
@@ -519,6 +519,47 @@
|
||||
"print(\"Length of text in the document:\", len(docs[0].page_content))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3ec3c22d-02cd-498b-921f-b839d1404f32",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading web pages\n",
|
||||
"\n",
|
||||
"`UnstructuredLoader` accepts a `web_url` kwarg when run locally that populates the `url` parameter of the underlying Unstructured [partition](https://docs.unstructured.io/open-source/core-functionality/partitioning). This allows for the parsing of remotely hosted documents, such as HTML web pages.\n",
|
||||
"\n",
|
||||
"Example usage:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bf9a8546-659d-4861-bff2-fdf1ad93ac65",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"page_content='Example Domain' metadata={'category_depth': 0, 'languages': ['eng'], 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'Title', 'element_id': 'fdaa78d856f9d143aeeed85bf23f58f8'}\n",
|
||||
"\n",
|
||||
"page_content='This domain is for use in illustrative examples in documents. You may use this domain in literature without prior coordination or asking for permission.' metadata={'languages': ['eng'], 'parent_id': 'fdaa78d856f9d143aeeed85bf23f58f8', 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'NarrativeText', 'element_id': '3652b8458b0688639f973fe36253c992'}\n",
|
||||
"\n",
|
||||
"page_content='More information...' metadata={'category_depth': 0, 'link_texts': ['More information...'], 'link_urls': ['https://www.iana.org/domains/example'], 'languages': ['eng'], 'filetype': 'text/html', 'url': 'https://www.example.com', 'category': 'Title', 'element_id': '793ab98565d6f6d6f3a6d614e3ace2a9'}\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_unstructured import UnstructuredLoader\n",
|
||||
"\n",
|
||||
"loader = UnstructuredLoader(web_url=\"https://www.example.com\")\n",
|
||||
"docs = loader.load()\n",
|
||||
"\n",
|
||||
"for doc in docs:\n",
|
||||
" print(f\"{doc}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ce01aa40",
|
||||
@@ -546,7 +587,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,129 +6,11 @@
|
||||
"source": [
|
||||
"# SambaNova\n",
|
||||
"\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambaverse](https://sambaverse.sambanova.ai/) and [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) are platforms for running your own open-source models\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with SambaNova models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sambaverse"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
|
||||
" **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"An API key is required to access Sambaverse models. To get a key, create an account at [sambaverse.sambanova.ai](https://sambaverse.sambanova.ai/)\n",
|
||||
"\n",
|
||||
"The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --quiet sseclient-py==1.8.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Register your API key as an environment variable:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sambaverse_api_key = \"<Your sambaverse API key>\"\n",
|
||||
"\n",
|
||||
"# Set the environment variables\n",
|
||||
"os.environ[\"SAMBAVERSE_API_KEY\"] = sambaverse_api_key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Call Sambaverse models directly from LangChain!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.sambanova import Sambaverse\n",
|
||||
"\n",
|
||||
"llm = Sambaverse(\n",
|
||||
" sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
|
||||
" streaming=False,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" \"process_prompt\": False,\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import Sambaverse\n",
|
||||
"\n",
|
||||
"llm = Sambaverse(\n",
|
||||
" sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
|
||||
" streaming=True,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" \"process_prompt\": False,\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# MLflow Deployments for LLMs
|
||||
# MLflow AI Gateway for LLMs
|
||||
|
||||
>[The MLflow Deployments for LLMs](https://www.mlflow.org/docs/latest/llms/deployments/index.html) is a powerful tool designed to streamline the usage and management of various large
|
||||
>[The MLflow AI Gateway for LLMs](https://www.mlflow.org/docs/latest/llms/deployments/index.html) is a powerful tool designed to streamline the usage and management of various large
|
||||
> language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface
|
||||
> that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Install `mlflow` with MLflow Deployments dependencies:
|
||||
Install `mlflow` with MLflow GenAI dependencies:
|
||||
|
||||
```sh
|
||||
pip install 'mlflow[genai]'
|
||||
@@ -39,10 +39,10 @@ endpoints:
|
||||
openai_api_key: $OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Start the deployments server:
|
||||
Start the gateway server:
|
||||
|
||||
```sh
|
||||
mlflow deployments start-server --config-path /path/to/config.yaml
|
||||
mlflow gateway start --config-path /path/to/config.yaml
|
||||
```
|
||||
|
||||
## Example provided by `MLflow`
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
# MLflow AI Gateway
|
||||
|
||||
:::warning
|
||||
|
||||
MLflow AI Gateway has been deprecated. Please use [MLflow Deployments for LLMs](/docs/integrations/providers/mlflow/) instead.
|
||||
|
||||
:::
|
||||
|
||||
>[The MLflow AI Gateway](https://www.mlflow.org/docs/latest/index.html) service is a powerful tool designed to streamline the usage and management of various large
|
||||
> language model (LLM) providers, such as OpenAI and Anthropic, within an organization. It offers a high-level interface
|
||||
> that simplifies the interaction with these services by providing a unified endpoint to handle specific LLM related requests.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Install `mlflow` with MLflow AI Gateway dependencies:
|
||||
|
||||
```sh
|
||||
pip install 'mlflow[gateway]'
|
||||
```
|
||||
|
||||
Set the OpenAI API key as an environment variable:
|
||||
|
||||
```sh
|
||||
export OPENAI_API_KEY=...
|
||||
```
|
||||
|
||||
Create a configuration file:
|
||||
|
||||
```yaml
|
||||
routes:
|
||||
- name: completions
|
||||
route_type: llm/v1/completions
|
||||
model:
|
||||
provider: openai
|
||||
name: text-davinci-003
|
||||
config:
|
||||
openai_api_key: $OPENAI_API_KEY
|
||||
|
||||
- name: embeddings
|
||||
route_type: llm/v1/embeddings
|
||||
model:
|
||||
provider: openai
|
||||
name: text-embedding-ada-002
|
||||
config:
|
||||
openai_api_key: $OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Start the Gateway server:
|
||||
|
||||
```sh
|
||||
mlflow gateway start --config-path /path/to/config.yaml
|
||||
```
|
||||
|
||||
## Example provided by `MLflow`
|
||||
|
||||
>The `mlflow.langchain` module provides an API for logging and loading `LangChain` models.
|
||||
> This module exports multivariate LangChain models in the langchain flavor and univariate LangChain
|
||||
> models in the pyfunc flavor.
|
||||
|
||||
See the [API documentation and examples](https://www.mlflow.org/docs/latest/python_api/mlflow.langchain.html?highlight=langchain#module-mlflow.langchain).
|
||||
|
||||
|
||||
|
||||
## Completions Example
|
||||
|
||||
```python
|
||||
import mlflow
|
||||
from langchain.chains import LLMChain, PromptTemplate
|
||||
from langchain_community.llms import MlflowAIGateway
|
||||
|
||||
gateway = MlflowAIGateway(
|
||||
gateway_uri="http://127.0.0.1:5000",
|
||||
route="completions",
|
||||
params={
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.1,
|
||||
},
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
llm=gateway,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
template="Tell me a {adjective} joke",
|
||||
),
|
||||
)
|
||||
result = llm_chain.run(adjective="funny")
|
||||
print(result)
|
||||
|
||||
with mlflow.start_run():
|
||||
model_info = mlflow.langchain.log_model(chain, "model")
|
||||
|
||||
model = mlflow.pyfunc.load_model(model_info.model_uri)
|
||||
print(model.predict([{"adjective": "funny"}]))
|
||||
```
|
||||
|
||||
## Embeddings Example
|
||||
|
||||
```python
|
||||
from langchain_community.embeddings import MlflowAIGatewayEmbeddings
|
||||
|
||||
embeddings = MlflowAIGatewayEmbeddings(
|
||||
gateway_uri="http://127.0.0.1:5000",
|
||||
route="embeddings",
|
||||
)
|
||||
|
||||
print(embeddings.embed_query("hello"))
|
||||
print(embeddings.embed_documents(["hello"]))
|
||||
```
|
||||
|
||||
## Chat Example
|
||||
|
||||
```python
|
||||
from langchain_community.chat_models import ChatMLflowAIGateway
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
chat = ChatMLflowAIGateway(
|
||||
gateway_uri="http://127.0.0.1:5000",
|
||||
route="chat",
|
||||
params={
|
||||
"temperature": 0.1
|
||||
}
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that translates English to French."
|
||||
),
|
||||
HumanMessage(
|
||||
content="Translate this sentence from English to French: I love programming."
|
||||
),
|
||||
]
|
||||
print(chat(messages))
|
||||
```
|
||||
|
||||
## Databricks MLflow AI Gateway
|
||||
|
||||
Databricks MLflow AI Gateway is in private preview.
|
||||
Please contact a Databricks representative to enroll in the preview.
|
||||
|
||||
```python
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_community.llms import MlflowAIGateway
|
||||
|
||||
gateway = MlflowAIGateway(
|
||||
gateway_uri="databricks",
|
||||
route="completions",
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
llm=gateway,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
template="Tell me a {adjective} joke",
|
||||
),
|
||||
)
|
||||
result = llm_chain.run(adjective="funny")
|
||||
print(result)
|
||||
```
|
||||
File diff suppressed because one or more lines are too long
@@ -400,18 +400,29 @@
|
||||
"def hybrid_query(search_query: str) -> Dict:\n",
|
||||
" vector = embeddings.embed_query(search_query) # same embeddings as for indexing\n",
|
||||
" return {\n",
|
||||
" \"query\": {\n",
|
||||
" \"match\": {\n",
|
||||
" text_field: search_query,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"knn\": {\n",
|
||||
" \"field\": dense_vector_field,\n",
|
||||
" \"query_vector\": vector,\n",
|
||||
" \"k\": 5,\n",
|
||||
" \"num_candidates\": 10,\n",
|
||||
" },\n",
|
||||
" \"rank\": {\"rrf\": {}},\n",
|
||||
" \"retriever\": {\n",
|
||||
" \"rrf\": {\n",
|
||||
" \"retrievers\": [\n",
|
||||
" {\n",
|
||||
" \"standard\": {\n",
|
||||
" \"query\": {\n",
|
||||
" \"match\": {\n",
|
||||
" text_field: search_query,\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"knn\": {\n",
|
||||
" \"field\": dense_vector_field,\n",
|
||||
" \"query_vector\": vector,\n",
|
||||
" \"k\": 5,\n",
|
||||
" \"num_candidates\": 10,\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
@@ -99,7 +99,7 @@
|
||||
"vector_store = Chroma(\n",
|
||||
" collection_name=\"example_collection\",\n",
|
||||
" embedding_function=embeddings,\n",
|
||||
" persist_directory=\"./chroma_langchain_db\", # Where to save data locally, remove if not neccesary\n",
|
||||
" persist_directory=\"./chroma_langchain_db\", # Where to save data locally, remove if not necessary\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -179,7 +179,7 @@
|
||||
"from langchain_core.documents import Document\n",
|
||||
"\n",
|
||||
"document_1 = Document(\n",
|
||||
" page_content=\"I had chocalate chip pancakes and scrambled eggs for breakfast this morning.\",\n",
|
||||
" page_content=\"I had chocolate chip pancakes and scrambled eggs for breakfast this morning.\",\n",
|
||||
" metadata={\"source\": \"tweet\"},\n",
|
||||
" id=1,\n",
|
||||
")\n",
|
||||
@@ -273,7 +273,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"updated_document_1 = Document(\n",
|
||||
" page_content=\"I had chocalate chip pancakes and fried eggs for breakfast this morning.\",\n",
|
||||
" page_content=\"I had chocolate chip pancakes and fried eggs for breakfast this morning.\",\n",
|
||||
" metadata={\"source\": \"tweet\"},\n",
|
||||
" id=1,\n",
|
||||
")\n",
|
||||
@@ -287,7 +287,7 @@
|
||||
"vector_store.update_document(document_id=uuids[0], document=updated_document_1)\n",
|
||||
"# You can also update multiple documents at once\n",
|
||||
"vector_store.update_documents(\n",
|
||||
" ids=uuids[:2], documents=[updated_document_1, updated_document_1]\n",
|
||||
" ids=uuids[:2], documents=[updated_document_1, updated_document_2]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -380,7 +380,7 @@
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation of all `AstraDBVectorStore` features and configurations head to the API reference:https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.clickhouse.Clickhouse.html"
|
||||
"For detailed documentation of all `Clickhouse` features and configurations head to the API reference:https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.clickhouse.Clickhouse.html"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -90,7 +90,7 @@
|
||||
"source": [
|
||||
" </TabItem>\n",
|
||||
" <TabItem value=\"conda\" label=\"Conda\">\n",
|
||||
" <CodeBlock language=\"bash\">conda install langchain langchain_community langchain_chroma -c conda-forge</CodeBlock>\n",
|
||||
" <CodeBlock language=\"bash\">conda install langchain langchain-community langchain-chroma -c conda-forge</CodeBlock>\n",
|
||||
" </TabItem>\n",
|
||||
"</Tabs>\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# LangChain v0.3
|
||||
|
||||
*Last updated: 09.13.24*
|
||||
*Last updated: 09.16.24*
|
||||
|
||||
## What's changed
|
||||
|
||||
@@ -23,7 +23,7 @@ The following features have been added during the development of 0.2.x:
|
||||
|
||||
## How to update your code
|
||||
|
||||
If you're using `langchain` / `langchain-community` / `langchain-core` 0.0 or 0.1, we recommend that you first [upgrade to 0.2](https://python.langchain.com/v0.2/docs/versions/v0_2/). The `langchain-cli` will help you to migrate many imports automatically.
|
||||
If you're using `langchain` / `langchain-community` / `langchain-core` 0.0 or 0.1, we recommend that you first [upgrade to 0.2](https://python.langchain.com/v0.2/docs/versions/v0_2/).
|
||||
|
||||
If you're using `langgraph`, upgrade to `langgraph>=0.2.20,<0.3`. This will work with either 0.2 or 0.3 versions of all the base packages.
|
||||
|
||||
@@ -31,22 +31,27 @@ Here is a complete list of all packages that have been released and what we reco
|
||||
Any package that now requires `langchain-core` 0.3 had a minor version bump.
|
||||
Any package that is now compatible with both `langchain-core` 0.2 and 0.3 had a patch version bump.
|
||||
|
||||
You can use the `langchain-cli` to update deprecated imports automatically.
|
||||
The CLI will handle updating deprecated imports that were introduced in LangChain 0.0.x and LangChain 0.1, as
|
||||
well as updating the `langchain_core.pydantic_v1` and `langchain.pydantic_v1` imports.
|
||||
|
||||
|
||||
### Base packages
|
||||
|
||||
| Package | Latest | Recommended constraint |
|
||||
| -------------------------------------- | ------- | -------------------------- |
|
||||
| langchain | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-community | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-text-splitters | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-core | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-experimental | 0.3.0 | >=0.3,<0.4 |
|
||||
| Package | Latest | Recommended constraint |
|
||||
|--------------------------|--------|------------------------|
|
||||
| langchain | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-community | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-text-splitters | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-core | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-experimental | 0.3.0 | >=0.3,<0.4 |
|
||||
|
||||
### Downstream packages
|
||||
|
||||
| Package | Latest | Recommended constraint |
|
||||
| -------------------------------------- | ------- | -------------------------- |
|
||||
| langgraph | 0.2.20 | >=0.2.20,<0.3 |
|
||||
| langserve | 0.3.0 | >=0.3,<0.4 |
|
||||
| Package | Latest | Recommended constraint |
|
||||
|-----------|--------|------------------------|
|
||||
| langgraph | 0.2.20 | >=0.2.20,<0.3 |
|
||||
| langserve | 0.3.0 | >=0.3,<0.4 |
|
||||
|
||||
### Integration packages
|
||||
|
||||
@@ -59,7 +64,7 @@ Any package that is now compatible with both `langchain-core` 0.2 and 0.3 had a
|
||||
| langchain-azure-dynamic-sessions | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-box | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-chroma | 0.1.4 | >=0.1.4,<0.2 |
|
||||
| langchain-cohere | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-cohere | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-elasticsearch | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-exa | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-fireworks | 0.2.0 | >=0.2,<0.3 |
|
||||
@@ -68,6 +73,7 @@ Any package that is now compatible with both `langchain-core` 0.2 and 0.3 had a
|
||||
| langchain-google-genai | 2.0.0 | >=2,<3 |
|
||||
| langchain-google-vertexai | 2.0.0 | >=2,<3 |
|
||||
| langchain-huggingface | 0.1.0 | >=0.1,<0.2 |
|
||||
| langchain-ibm | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-milvus | 0.1.6 | >=0.1.6,<0.2 |
|
||||
| langchain-mistralai | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-mongodb | 0.2.0 | >=0.2,<0.3 |
|
||||
@@ -77,12 +83,14 @@ Any package that is now compatible with both `langchain-core` 0.2 and 0.3 had a
|
||||
| langchain-pinecone | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-postgres | 0.0.13 | >=0.0.13,<0.1 |
|
||||
| langchain-prompty | 0.1.0 | >=0.1,<0.2 |
|
||||
| langchain-qdrant | 0.1.4 | >=0.1.4,<0.2 |
|
||||
| langchain-redis | 0.1.0 | >=0.1,<0.2 |
|
||||
| langchain-qdrant | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-sema4 | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-together | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-unstructured | 0.1.4 | >=0.1.4,<0.2 |
|
||||
| langchain-upstage | 0.3.0 | >=0.3,<0.4 |
|
||||
| langchain-voyageai | 0.2.0 | >=0.2,<0.3 |
|
||||
| langchain-weaviate | 0.1.0 | >=0.1,<0.2 |
|
||||
| langchain-weaviate | 0.0.3 | >=0.0.3,<0.1 |
|
||||
|
||||
Once you've updated to recent versions of the packages, you may need to address the following issues stemming from the internal switch from Pydantic v1 to Pydantic v2:
|
||||
|
||||
@@ -185,6 +193,8 @@ CustomTool(
|
||||
When sub-classing from LangChain models, users may need to add relevant imports
|
||||
to the file and rebuild the model.
|
||||
|
||||
You can read more about `model_rebuild` [here](https://docs.pydantic.dev/latest/concepts/models/#rebuilding-model-schema).
|
||||
|
||||
```python
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
|
||||
@@ -205,3 +215,57 @@ class FooParser(BaseOutputParser):
|
||||
|
||||
FooParser.model_rebuild()
|
||||
```
|
||||
|
||||
## Migrate using langchain-cli
|
||||
|
||||
The `langchain-cli` can help update deprecated LangChain imports in your code automatically.
|
||||
|
||||
Please note that the `langchain-cli` only handles deprecated LangChain imports and cannot
|
||||
help to upgrade your code from pydantic 1 to pydantic 2.
|
||||
|
||||
For help with the Pydantic 1 to 2 migration itself please refer to the [Pydantic Migration Guidelines](https://docs.pydantic.dev/latest/migration/).
|
||||
|
||||
As of 0.0.31, the `langchain-cli` relies on [gritql](https://about.grit.io/) for applying code mods.
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install -U langchain-cli
|
||||
langchain-cli --version # <-- Make sure the version is at least 0.0.31
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Given that the migration script is not perfect, you should make sure you have a backup of your code first (e.g., using version control like `git`).
|
||||
|
||||
The `langchain-cli` will handle the `langchain_core.pydantic_v1` deprecation introduced in LangChain 0.3 as well
|
||||
as older deprecations (e.g.,`from langchain.chat_models import ChatOpenAI` which should be `from langchain_openai import ChatOpenAI`),
|
||||
|
||||
You will need to run the migration script **twice** as it only applies one import replacement per run.
|
||||
|
||||
For example, say that your code is still using the old import `from langchain.chat_models import ChatOpenAI`:
|
||||
|
||||
After the first run, you’ll get: `from langchain_community.chat_models import ChatOpenAI`
|
||||
After the second run, you’ll get: `from langchain_openai import ChatOpenAI`
|
||||
|
||||
```bash
|
||||
# Run a first time
|
||||
# Will replace from langchain.chat_models import ChatOpenAI
|
||||
langchain-cli migrate --help [path to code] # Help
|
||||
langchain-cli migrate [path to code] # Apply
|
||||
|
||||
# Run a second time to apply more import replacements
|
||||
langchain-cli migrate --diff [path to code] # Preview
|
||||
langchain-cli migrate [path to code] # Apply
|
||||
```
|
||||
|
||||
### Other options
|
||||
|
||||
```bash
|
||||
# See help menu
|
||||
langchain-cli migrate --help
|
||||
# Preview Changes without applying
|
||||
langchain-cli migrate --diff [path to code]
|
||||
# Approve changes interactively
|
||||
langchain-cli migrate --interactive [path to code]
|
||||
```
|
||||
|
||||
@@ -168,52 +168,43 @@ const config = {
|
||||
label: "Integrations",
|
||||
},
|
||||
{
|
||||
type: "dropdown",
|
||||
label: "API reference",
|
||||
position: "left",
|
||||
items: [
|
||||
{
|
||||
label: "Latest",
|
||||
to: "https://python.langchain.com/api_reference/reference.html",
|
||||
},
|
||||
{
|
||||
label: "Legacy",
|
||||
href: "https://api.python.langchain.com/"
|
||||
}
|
||||
]
|
||||
label: "API Reference",
|
||||
to: "https://python.langchain.com/api_reference/",
|
||||
},
|
||||
{
|
||||
type: "dropdown",
|
||||
label: "More",
|
||||
position: "left",
|
||||
items: [
|
||||
{
|
||||
type: "doc",
|
||||
docId: "people",
|
||||
label: "People",
|
||||
},
|
||||
{
|
||||
type: "doc",
|
||||
docId: "contributing/index",
|
||||
label: "Contributing",
|
||||
},
|
||||
{
|
||||
label: "Cookbooks",
|
||||
href: "https://github.com/langchain-ai/langchain/blob/master/cookbook/README.md"
|
||||
},
|
||||
{
|
||||
type: "doc",
|
||||
docId: "additional_resources/tutorials",
|
||||
label: "3rd party tutorials"
|
||||
docId: "people",
|
||||
label: "People",
|
||||
},
|
||||
{
|
||||
type: "doc",
|
||||
docId: "additional_resources/youtube",
|
||||
label: "YouTube"
|
||||
type: 'html',
|
||||
value: '<hr class="dropdown-separator" style="margin-top: 0.5rem; margin-bottom: 0.5rem">',
|
||||
},
|
||||
{
|
||||
to: "/docs/additional_resources/arxiv_references",
|
||||
label: "arXiv"
|
||||
href: "https://docs.smith.langchain.com",
|
||||
label: "LangSmith",
|
||||
},
|
||||
{
|
||||
href: "https://langchain-ai.github.io/langgraph/",
|
||||
label: "LangGraph",
|
||||
},
|
||||
{
|
||||
href: "https://smith.langchain.com/hub",
|
||||
label: "LangChain Hub",
|
||||
},
|
||||
{
|
||||
href: "https://js.langchain.com",
|
||||
label: "LangChain JS/TS",
|
||||
},
|
||||
]
|
||||
},
|
||||
@@ -237,30 +228,7 @@ const config = {
|
||||
]
|
||||
},
|
||||
{
|
||||
type: "dropdown",
|
||||
label: "🦜️🔗",
|
||||
position: "right",
|
||||
items: [
|
||||
{
|
||||
href: "https://smith.langchain.com",
|
||||
label: "LangSmith",
|
||||
},
|
||||
{
|
||||
href: "https://docs.smith.langchain.com/",
|
||||
label: "LangSmith Docs",
|
||||
},
|
||||
{
|
||||
href: "https://smith.langchain.com/hub",
|
||||
label: "LangChain Hub",
|
||||
},
|
||||
{
|
||||
href: "https://js.langchain.com",
|
||||
label: "JS/TS Docs",
|
||||
},
|
||||
]
|
||||
},
|
||||
{
|
||||
href: "https://chat.langchain.com",
|
||||
to: "https://chat.langchain.com",
|
||||
label: "💬",
|
||||
position: "right",
|
||||
},
|
||||
|
||||
@@ -38,6 +38,9 @@
|
||||
--ifm-menu-link-padding-horizontal: 0.5rem;
|
||||
--ifm-menu-link-padding-vertical: 0.5rem;
|
||||
--doc-sidebar-width: 275px !important;
|
||||
|
||||
/* Code block syntax highlighting */
|
||||
--docusaurus-highlighted-code-line-bg: rgb(176, 227, 199);
|
||||
}
|
||||
|
||||
/* For readability concerns, you should choose a lighter palette in dark mode. */
|
||||
@@ -49,6 +52,9 @@
|
||||
--ifm-color-primary-light: #29d5b0;
|
||||
--ifm-color-primary-lighter: #32d8b4;
|
||||
--ifm-color-primary-lightest: #4fddbf;
|
||||
|
||||
/* Code block syntax highlighting */
|
||||
--docusaurus-highlighted-code-line-bg: rgb(14, 73, 60);
|
||||
}
|
||||
|
||||
nav, h1, h2, h3, h4 {
|
||||
|
||||
@@ -354,7 +354,7 @@ const FEATURE_TABLES = {
|
||||
},
|
||||
{
|
||||
name: "Nomic",
|
||||
link: "cohere",
|
||||
link: "nomic",
|
||||
package: "langchain-nomic",
|
||||
apiLink: "https://python.langchain.com/api_reference/nomic/embeddings/langchain_nomic.embeddings.NomicEmbeddings.html"
|
||||
},
|
||||
@@ -886,7 +886,7 @@ const FEATURE_TABLES = {
|
||||
apiLink: "https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.html_bs.BSHTMLLoader.html"
|
||||
},
|
||||
{
|
||||
name: "UnstrucutredXMLLoader",
|
||||
name: "UnstructuredXMLLoader",
|
||||
link: "xml",
|
||||
source: "XML files",
|
||||
apiLink: "https://python.langchain.com/api_reference/community/document_loaders/langchain_community.document_loaders.xml.UnstructuredXMLLoader.html"
|
||||
|
||||
@@ -26,6 +26,22 @@
|
||||
}
|
||||
],
|
||||
"redirects": [
|
||||
{
|
||||
"source": "/v0.3/docs/:path(.*/?)*",
|
||||
"destination": "/docs/:path*"
|
||||
},
|
||||
{
|
||||
"source": "/docs/modules/agents/tools/custom_tools(/?)",
|
||||
"destination": "/docs/how_to/custom_tools/"
|
||||
},
|
||||
{
|
||||
"source": "/docs/expression_language(/?)",
|
||||
"destination": "/docs/concepts/#langchain-expression-language-lcel"
|
||||
},
|
||||
{
|
||||
"source": "/docs/expression_language/interface(/?)",
|
||||
"destination": "/docs/concepts/#runnable-interface"
|
||||
},
|
||||
{
|
||||
"source": "/docs/versions/overview(/?)",
|
||||
"destination": "/docs/versions/v0_2/overview/"
|
||||
@@ -61,6 +77,10 @@
|
||||
{
|
||||
"source": "/v0.2/docs/templates/:path(.*/?)*",
|
||||
"destination": "https://github.com/langchain-ai/langchain/tree/master/templates/:path*"
|
||||
},
|
||||
{
|
||||
"source": "/docs/integrations/providers/mlflow_ai_gateway(/?)",
|
||||
"destination": "/docs/integrations/providers/mlflow/"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import rich
|
||||
import typer
|
||||
from gritql import run
|
||||
from typer import Option
|
||||
|
||||
|
||||
def get_gritdir_path() -> Path:
|
||||
@@ -15,6 +16,17 @@ def get_gritdir_path() -> Path:
|
||||
|
||||
def migrate(
|
||||
ctx: typer.Context,
|
||||
# Using diff instead of dry-run for backwards compatibility with the old CLI
|
||||
diff: bool = Option(
|
||||
False,
|
||||
"--diff",
|
||||
help="Show the changes that would be made without applying them.",
|
||||
),
|
||||
interactive: bool = Option(
|
||||
False,
|
||||
"--interactive",
|
||||
help="Prompt for confirmation before making each change",
|
||||
),
|
||||
) -> None:
|
||||
"""Migrate langchain to the most recent version.
|
||||
|
||||
@@ -47,9 +59,15 @@ def migrate(
|
||||
rich.print("-" * 10)
|
||||
rich.print()
|
||||
|
||||
args = list(ctx.args)
|
||||
if interactive:
|
||||
args.append("--interactive")
|
||||
if diff:
|
||||
args.append("--dry-run")
|
||||
|
||||
final_code = run.apply_pattern(
|
||||
"langchain_all_migrations()",
|
||||
ctx.args,
|
||||
args,
|
||||
grit_dir=get_gritdir_path(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-cli"
|
||||
version = "0.0.30"
|
||||
version = "0.0.31"
|
||||
description = "CLI for interacting with LangChain"
|
||||
authors = ["Erick Friis <erick@langchain.dev>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -15,7 +15,7 @@ LangChain Community contains third-party integrations that implement the base in
|
||||
|
||||
For full documentation see the [API reference](https://api.python.langchain.com/en/stable/community_api_reference.html).
|
||||
|
||||

|
||||

|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -301,7 +301,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
)
|
||||
|
||||
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
|
||||
@@ -437,7 +437,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name") or self.get_name()
|
||||
)
|
||||
|
||||
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
|
||||
|
||||
@@ -8,6 +8,18 @@ from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
# OpenAI o1-preview input
|
||||
"o1-preview": 0.015,
|
||||
"o1-preview-2024-09-12": 0.015,
|
||||
# OpenAI o1-preview output
|
||||
"o1-preview-completion": 0.06,
|
||||
"o1-preview-2024-09-12-completion": 0.06,
|
||||
# OpenAI o1-mini input
|
||||
"o1-mini": 0.003,
|
||||
"o1-mini-2024-09-12": 0.003,
|
||||
# OpenAI o1-mini output
|
||||
"o1-mini-completion": 0.012,
|
||||
"o1-mini-2024-09-12-completion": 0.012,
|
||||
# GPT-4o-mini input
|
||||
"gpt-4o-mini": 0.00015,
|
||||
"gpt-4o-mini-2024-07-18": 0.00015,
|
||||
@@ -153,6 +165,7 @@ def standardize_model_name(
|
||||
model_name.startswith("gpt-4")
|
||||
or model_name.startswith("gpt-3.5")
|
||||
or model_name.startswith("gpt-35")
|
||||
or model_name.startswith("o1-")
|
||||
or ("finetuned" in model_name and "legacy" not in model_name)
|
||||
):
|
||||
return model_name + "-completion"
|
||||
|
||||
@@ -53,13 +53,15 @@ class LLMThoughtLabeler:
|
||||
labeling logic.
|
||||
"""
|
||||
|
||||
def get_initial_label(self) -> str:
|
||||
@staticmethod
|
||||
def get_initial_label() -> str:
|
||||
"""Return the markdown label for a new LLMThought that doesn't have
|
||||
an associated tool yet.
|
||||
"""
|
||||
return f"{THINKING_EMOJI} **Thinking...**"
|
||||
|
||||
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
|
||||
@staticmethod
|
||||
def get_tool_label(tool: ToolRecord, is_complete: bool) -> str:
|
||||
"""Return the label for an LLMThought that has an associated
|
||||
tool.
|
||||
|
||||
@@ -91,13 +93,15 @@ class LLMThoughtLabeler:
|
||||
label = f"{emoji} **{name}:** {input}"
|
||||
return label
|
||||
|
||||
def get_history_label(self) -> str:
|
||||
@staticmethod
|
||||
def get_history_label() -> str:
|
||||
"""Return a markdown label for the special 'history' container
|
||||
that contains overflow thoughts.
|
||||
"""
|
||||
return f"{HISTORY_EMOJI} **History**"
|
||||
|
||||
def get_final_agent_thought_label(self) -> str:
|
||||
@staticmethod
|
||||
def get_final_agent_thought_label() -> str:
|
||||
"""Return the markdown label for the agent's final thought -
|
||||
the "Now I have the answer" thought, that doesn't involve
|
||||
a tool.
|
||||
|
||||
@@ -359,6 +359,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from langchain_community.document_loaders.pebblo import (
|
||||
PebbloSafeLoader,
|
||||
PebbloTextLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.polars_dataframe import (
|
||||
PolarsDataFrameLoader,
|
||||
@@ -650,6 +651,7 @@ _module_lookup = {
|
||||
"PDFPlumberLoader": "langchain_community.document_loaders.pdf",
|
||||
"PagedPDFSplitter": "langchain_community.document_loaders.pdf",
|
||||
"PebbloSafeLoader": "langchain_community.document_loaders.pebblo",
|
||||
"PebbloTextLoader": "langchain_community.document_loaders.pebblo",
|
||||
"PlaywrightURLLoader": "langchain_community.document_loaders.url_playwright",
|
||||
"PolarsDataFrameLoader": "langchain_community.document_loaders.polars_dataframe",
|
||||
"PsychicLoader": "langchain_community.document_loaders.psychic",
|
||||
@@ -855,6 +857,7 @@ __all__ = [
|
||||
"PDFPlumberLoader",
|
||||
"PagedPDFSplitter",
|
||||
"PebbloSafeLoader",
|
||||
"PebbloTextLoader",
|
||||
"PlaywrightURLLoader",
|
||||
"PolarsDataFrameLoader",
|
||||
"PsychicLoader",
|
||||
|
||||
@@ -20,13 +20,37 @@ class MongodbLoader(BaseLoader):
|
||||
*,
|
||||
filter_criteria: Optional[Dict] = None,
|
||||
field_names: Optional[Sequence[str]] = None,
|
||||
metadata_names: Optional[Sequence[str]] = None,
|
||||
include_db_collection_in_metadata: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MongoDB loader with necessary database connection
|
||||
details and configurations.
|
||||
|
||||
Args:
|
||||
connection_string (str): MongoDB connection URI.
|
||||
db_name (str):Name of the database to connect to.
|
||||
collection_name (str): Name of the collection to fetch documents from.
|
||||
filter_criteria (Optional[Dict]): MongoDB filter criteria for querying
|
||||
documents.
|
||||
field_names (Optional[Sequence[str]]): List of field names to retrieve
|
||||
from documents.
|
||||
metadata_names (Optional[Sequence[str]]): Additional metadata fields to
|
||||
extract from documents.
|
||||
include_db_collection_in_metadata (bool): Flag to include database and
|
||||
collection names in metadata.
|
||||
|
||||
Raises:
|
||||
ImportError: If the motor library is not installed.
|
||||
ValueError: If any necessary argument is missing.
|
||||
"""
|
||||
try:
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import from motor, please install with `pip install motor`."
|
||||
) from e
|
||||
|
||||
if not connection_string:
|
||||
raise ValueError("connection_string must be provided.")
|
||||
|
||||
@@ -39,8 +63,10 @@ class MongodbLoader(BaseLoader):
|
||||
self.client = AsyncIOMotorClient(connection_string)
|
||||
self.db_name = db_name
|
||||
self.collection_name = collection_name
|
||||
self.field_names = field_names
|
||||
self.field_names = field_names or []
|
||||
self.filter_criteria = filter_criteria or {}
|
||||
self.metadata_names = metadata_names or []
|
||||
self.include_db_collection_in_metadata = include_db_collection_in_metadata
|
||||
|
||||
self.db = self.client.get_database(db_name)
|
||||
self.collection = self.db.get_collection(collection_name)
|
||||
@@ -60,36 +86,24 @@ class MongodbLoader(BaseLoader):
|
||||
return asyncio.run(self.aload())
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load data into Document objects."""
|
||||
"""Asynchronously loads data into Document objects."""
|
||||
result = []
|
||||
total_docs = await self.collection.count_documents(self.filter_criteria)
|
||||
|
||||
# Construct the projection dictionary if field_names are specified
|
||||
projection = (
|
||||
{field: 1 for field in self.field_names} if self.field_names else None
|
||||
)
|
||||
projection = self._construct_projection()
|
||||
|
||||
async for doc in self.collection.find(self.filter_criteria, projection):
|
||||
metadata = {
|
||||
"database": self.db_name,
|
||||
"collection": self.collection_name,
|
||||
}
|
||||
metadata = self._extract_fields(doc, self.metadata_names, default="")
|
||||
|
||||
# Optionally add database and collection names to metadata
|
||||
if self.include_db_collection_in_metadata:
|
||||
metadata.update(
|
||||
{"database": self.db_name, "collection": self.collection_name}
|
||||
)
|
||||
|
||||
# Extract text content from filtered fields or use the entire document
|
||||
if self.field_names is not None:
|
||||
fields = {}
|
||||
for name in self.field_names:
|
||||
# Split the field names to handle nested fields
|
||||
keys = name.split(".")
|
||||
value = doc
|
||||
for key in keys:
|
||||
if key in value:
|
||||
value = value[key]
|
||||
else:
|
||||
value = ""
|
||||
break
|
||||
fields[name] = value
|
||||
|
||||
fields = self._extract_fields(doc, self.field_names, default="")
|
||||
texts = [str(value) for value in fields.values()]
|
||||
text = " ".join(texts)
|
||||
else:
|
||||
@@ -104,3 +118,29 @@ class MongodbLoader(BaseLoader):
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _construct_projection(self) -> Optional[Dict]:
|
||||
"""Constructs the projection dictionary for MongoDB query based
|
||||
on the specified field names and metadata names."""
|
||||
field_names = list(self.field_names) or []
|
||||
metadata_names = list(self.metadata_names) or []
|
||||
all_fields = field_names + metadata_names
|
||||
return {field: 1 for field in all_fields} if all_fields else None
|
||||
|
||||
def _extract_fields(
|
||||
self,
|
||||
document: Dict,
|
||||
fields: Sequence[str],
|
||||
default: str = "",
|
||||
) -> Dict:
|
||||
"""Extracts and returns values for specified fields from a document."""
|
||||
extracted = {}
|
||||
for field in fields or []:
|
||||
value = document
|
||||
for key in field.split("."):
|
||||
value = value.get(key, default)
|
||||
if value == default:
|
||||
break
|
||||
new_field_name = field.replace(".", "_")
|
||||
extracted[new_field_name] = value
|
||||
return extracted
|
||||
|
||||
@@ -267,6 +267,7 @@ class PyMuPDFParser(BaseBlobParser):
|
||||
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]: # type: ignore[valid-type]
|
||||
"""Lazily parse the blob."""
|
||||
|
||||
import fitz
|
||||
|
||||
with blob.as_bytes_io() as file_path: # type: ignore[attr-defined]
|
||||
@@ -277,25 +278,49 @@ class PyMuPDFParser(BaseBlobParser):
|
||||
|
||||
yield from [
|
||||
Document(
|
||||
page_content=page.get_text(**self.text_kwargs)
|
||||
+ self._extract_images_from_page(doc, page),
|
||||
metadata=dict(
|
||||
{
|
||||
"source": blob.source, # type: ignore[attr-defined]
|
||||
"file_path": blob.source, # type: ignore[attr-defined]
|
||||
"page": page.number,
|
||||
"total_pages": len(doc),
|
||||
},
|
||||
**{
|
||||
k: doc.metadata[k]
|
||||
for k in doc.metadata
|
||||
if type(doc.metadata[k]) in [str, int]
|
||||
},
|
||||
),
|
||||
page_content=self._get_page_content(doc, page, blob),
|
||||
metadata=self._extract_metadata(doc, page, blob),
|
||||
)
|
||||
for page in doc
|
||||
]
|
||||
|
||||
def _get_page_content(
|
||||
self, doc: fitz.fitz.Document, page: fitz.fitz.Page, blob: Blob
|
||||
) -> str:
|
||||
"""
|
||||
Get the text of the page using PyMuPDF and RapidOCR and issue a warning
|
||||
if it is empty.
|
||||
"""
|
||||
content = page.get_text(**self.text_kwargs) + self._extract_images_from_page(
|
||||
doc, page
|
||||
)
|
||||
|
||||
if not content:
|
||||
warnings.warn(
|
||||
f"Warning: Empty content on page "
|
||||
f"{page.number} of document {blob.source}"
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
def _extract_metadata(
|
||||
self, doc: fitz.fitz.Document, page: fitz.fitz.Page, blob: Blob
|
||||
) -> dict:
|
||||
"""Extract metadata from the document and page."""
|
||||
return dict(
|
||||
{
|
||||
"source": blob.source, # type: ignore[attr-defined]
|
||||
"file_path": blob.source, # type: ignore[attr-defined]
|
||||
"page": page.number,
|
||||
"total_pages": len(doc),
|
||||
},
|
||||
**{
|
||||
k: doc.metadata[k]
|
||||
for k in doc.metadata
|
||||
if isinstance(doc.metadata[k], (str, int))
|
||||
},
|
||||
)
|
||||
|
||||
def _extract_images_from_page(
|
||||
self, doc: fitz.fitz.Document, page: fitz.fitz.Page
|
||||
) -> str:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from importlib.metadata import version
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -271,3 +271,67 @@ class PebbloSafeLoader(BaseLoader):
|
||||
doc_metadata["pb_checksum"] = classified_docs.get(doc.pb_id, {}).get(
|
||||
"pb_checksum", None
|
||||
)
|
||||
|
||||
|
||||
class PebbloTextLoader(BaseLoader):
|
||||
"""
|
||||
Loader for text data.
|
||||
|
||||
Since PebbloSafeLoader is a wrapper around document loaders, this loader is
|
||||
used to load text data directly into Documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
*,
|
||||
source: Optional[str] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
texts: Iterable of text data.
|
||||
source: Source of the text data.
|
||||
Optional. Defaults to None.
|
||||
ids: List of unique identifiers for each text.
|
||||
Optional. Defaults to None.
|
||||
metadata: Metadata for all texts.
|
||||
Optional. Defaults to None.
|
||||
metadatas: List of metadata for each text.
|
||||
Optional. Defaults to None.
|
||||
"""
|
||||
self.texts = texts
|
||||
self.source = source
|
||||
self.ids = ids
|
||||
self.metadata = metadata
|
||||
self.metadatas = metadatas
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""
|
||||
Lazy load text data into Documents.
|
||||
|
||||
Returns:
|
||||
Iterator of Documents
|
||||
"""
|
||||
for i, text in enumerate(self.texts):
|
||||
_id = None
|
||||
metadata = self.metadata or {}
|
||||
if self.metadatas and i < len(self.metadatas) and self.metadatas[i]:
|
||||
metadata.update(self.metadatas[i])
|
||||
if self.ids and i < len(self.ids):
|
||||
_id = self.ids[i]
|
||||
yield Document(id=_id, page_content=text, metadata=metadata)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Load text data into Documents.
|
||||
|
||||
Returns:
|
||||
List of Documents
|
||||
"""
|
||||
documents = []
|
||||
for doc in self.lazy_load():
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
@@ -227,7 +227,7 @@ class RecursiveUrlLoader(BaseLoader):
|
||||
"https://docs.python.org/3.9/",
|
||||
prevent_outside=True,
|
||||
base_url="https://docs.python.org",
|
||||
link_regex=r'<a\s+(?:[^>]*?\s+)?href="([^"]*(?=index)[^"]*)"',
|
||||
link_regex=r'<a\\s+(?:[^>]*?\\s+)?href="([^"]*(?=index)[^"]*)"',
|
||||
exclude_dirs=['https://docs.python.org/3.9/faq']
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
@@ -132,6 +132,7 @@ class BeautifulSoupTransformer(BaseDocumentTransformer):
|
||||
Args:
|
||||
html_content: The original HTML content string.
|
||||
tags: A list of tags to be extracted from the HTML.
|
||||
remove_comments: If set to True, the comments will be removed.
|
||||
|
||||
Returns:
|
||||
A string combining the content of the extracted tags.
|
||||
@@ -184,6 +185,7 @@ def get_navigable_strings(
|
||||
|
||||
Args:
|
||||
element: A BeautifulSoup element.
|
||||
remove_comments: If set to True, the comments will be removed.
|
||||
|
||||
Returns:
|
||||
A generator of strings.
|
||||
|
||||
@@ -213,7 +213,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
try:
|
||||
if params.get("select_expert"):
|
||||
embedding = response.json()["predictions"][0]
|
||||
embedding = response.json()["predictions"]
|
||||
else:
|
||||
embedding = response.json()["predictions"]
|
||||
embeddings.extend(embedding)
|
||||
@@ -299,7 +299,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
try:
|
||||
if params.get("select_expert"):
|
||||
embedding = response.json()["predictions"][0][0]
|
||||
embedding = response.json()["predictions"][0]
|
||||
else:
|
||||
embedding = response.json()["predictions"][0]
|
||||
except KeyError:
|
||||
|
||||
@@ -1,7 +1,708 @@
|
||||
from langchain_core.graph_vectorstores.base import (
|
||||
GraphVectorStore,
|
||||
GraphVectorStoreRetriever,
|
||||
Node,
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Optional,
|
||||
)
|
||||
|
||||
__all__ = ["GraphVectorStore", "GraphVectorStoreRetriever", "Node"]
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
|
||||
|
||||
def _has_next(iterator: Iterator) -> bool:
|
||||
"""Checks if the iterator has more elements.
|
||||
Warning: consumes an element from the iterator"""
|
||||
sentinel = object()
|
||||
return next(iterator, sentinel) is not sentinel
|
||||
|
||||
|
||||
@beta()
|
||||
class Node(Serializable):
|
||||
"""Node in the GraphVectorStore.
|
||||
|
||||
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
|
||||
|
||||
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
|
||||
would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[
|
||||
Node(
|
||||
id="a",
|
||||
text="some text a",
|
||||
links= [
|
||||
Link(kind="hyperlink", tag="https://some-url", direction="incoming")
|
||||
],
|
||||
),
|
||||
Node(
|
||||
id="b",
|
||||
text="some text b",
|
||||
links= [
|
||||
Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
|
||||
],
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
"""Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
|
||||
text: str
|
||||
"""Text contained by the node."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Metadata for the node."""
|
||||
links: list[Link] = Field(default_factory=list)
|
||||
"""Links associated with the node."""
|
||||
|
||||
|
||||
def _texts_to_nodes(
    texts: Iterable[str],
    metadatas: Optional[Iterable[dict]],
    ids: Optional[Iterable[str]],
) -> Iterator[Node]:
    """Zip *texts* with optional *metadatas*/*ids*, yielding one Node per text.

    The ``links`` entry is popped from each metadata dict (a copy — the
    caller's dict is not mutated) and attached to the node. A ValueError is
    raised when the iterables have mismatched lengths; the trailing length
    checks only run once the generator is fully consumed.
    """
    metadata_iter = iter(metadatas) if metadatas else None
    id_iter = iter(ids) if ids else None

    for content in texts:
        if metadata_iter is None:
            node_metadata = {}
        else:
            try:
                node_metadata = next(metadata_iter).copy()
            except StopIteration as e:
                raise ValueError("texts iterable longer than metadatas") from e

        if id_iter is None:
            node_id = None
        else:
            try:
                node_id = next(id_iter)
            except StopIteration as e:
                raise ValueError("texts iterable longer than ids") from e

        raw_links = node_metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(raw_links, list):
            raw_links = list(raw_links)
        yield Node(
            id=node_id,
            metadata=node_metadata,
            text=content,
            links=raw_links,
        )

    if id_iter is not None and _has_next(id_iter):
        raise ValueError("ids iterable longer than texts")
    if metadata_iter is not None and _has_next(metadata_iter):
        raise ValueError("metadatas iterable longer than texts")
|
||||
|
||||
|
||||
def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
    """Yield a Node for each Document, moving ``metadata['links']`` onto the node.

    The document's metadata is copied, so the input documents are not mutated.
    """
    for document in documents:
        remaining_metadata = document.metadata.copy()
        link_entries = remaining_metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(link_entries, list):
            link_entries = list(link_entries)
        yield Node(
            id=document.id,
            text=document.page_content,
            metadata=remaining_metadata,
            links=link_entries,
        )
|
||||
|
||||
|
||||
@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
    """Convert nodes to documents.

    Args:
        nodes: The nodes to convert to documents.
    Returns:
        The documents generated from the nodes.
    """
    for node in nodes:
        doc_metadata = dict(node.metadata)
        # Re-wrap each link as the local `Link` class so the document
        # metadata round-trips cleanly through `_documents_to_nodes`.
        doc_metadata[METADATA_LINKS_KEY] = [
            Link(kind=ln.kind, direction=ln.direction, tag=ln.tag)
            for ln in node.links
        ]
        yield Document(
            id=node.id,
            page_content=node.text,
            metadata=doc_metadata,
        )
|
||||
|
||||
|
||||
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
class GraphVectorStore(VectorStore):
    """A hybrid vector-and-graph graph store.

    Document chunks support vector-similarity search as well as edges linking
    chunks based on structural and semantic properties.

    .. versionadded:: 0.2.14
    """

    @abstractmethod
    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> Iterable[str]:
        """Add nodes to the graph store.

        Args:
            nodes: the nodes to add.

        Returns:
            The IDs of the added nodes.
        """

    async def aadd_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> AsyncIterable[str]:
        """Add nodes to the graph store.

        Default async implementation: runs the sync ``add_nodes`` in an
        executor and re-yields its results one at a time.

        Args:
            nodes: the nodes to add.

        Yields:
            The IDs of the added nodes.
        """
        iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
        done = object()  # sentinel marking iterator exhaustion
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        The Links present in the metadata field `links` will be extracted to create
        the `Node` links.

        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
        function call would look like:

        .. code-block:: python

            store.add_texts(
                ids=["a", "b"],
                texts=["some text a", "some text b"],
                metadatas=[
                    {
                        "links": [
                            Link.incoming(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                    {
                        "links": [
                            Link.outgoing(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                ],
            )

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                The metadata key `links` shall be an iterable of
                :py:class:`~langchain_community.graph_vectorstores.links.Link`.
            ids: Optional iterable of IDs to assign to the texts.
            **kwargs: vectorstore specific parameters.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return list(self.add_nodes(nodes, **kwargs))

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        The Links present in the metadata field `links` will be extracted to create
        the `Node` links.

        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
        function call would look like:

        .. code-block:: python

            await store.aadd_texts(
                ids=["a", "b"],
                texts=["some text a", "some text b"],
                metadatas=[
                    {
                        "links": [
                            Link.incoming(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                    {
                        "links": [
                            Link.outgoing(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                ],
            )

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                The metadata key `links` shall be an iterable of
                :py:class:`~langchain_community.graph_vectorstores.links.Link`.
            ids: Optional iterable of IDs to assign to the texts.
            **kwargs: vectorstore specific parameters.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]

    def add_documents(
        self,
        documents: Iterable[Document],
        **kwargs: Any,
    ) -> list[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        The Links present in the document metadata field `links` will be extracted to
        create the `Node` links.

        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
        function call would look like:

        .. code-block:: python

            store.add_documents(
                [
                    Document(
                        id="a",
                        page_content="some text a",
                        metadata={
                            "links": [
                                Link.incoming(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                    Document(
                        id="b",
                        page_content="some text b",
                        metadata={
                            "links": [
                                Link.outgoing(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                ]

            )

        Args:
            documents: Documents to add to the vectorstore.
                The document's metadata key `links` shall be an iterable of
                :py:class:`~langchain_community.graph_vectorstores.links.Link`.

        Returns:
            List of IDs of the added texts.
        """
        nodes = _documents_to_nodes(documents)
        return list(self.add_nodes(nodes, **kwargs))

    async def aadd_documents(
        self,
        documents: Iterable[Document],
        **kwargs: Any,
    ) -> list[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        The Links present in the document metadata field `links` will be extracted to
        create the `Node` links.

        Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
        function call would look like:

        .. code-block:: python

            store.add_documents(
                [
                    Document(
                        id="a",
                        page_content="some text a",
                        metadata={
                            "links": [
                                Link.incoming(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                    Document(
                        id="b",
                        page_content="some text b",
                        metadata={
                            "links": [
                                Link.outgoing(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                ]

            )

        Args:
            documents: Documents to add to the vectorstore.
                The document's metadata key `links` shall be an iterable of
                :py:class:`~langchain_community.graph_vectorstores.links.Link`.

        Returns:
            List of IDs of the added texts.
        """
        nodes = _documents_to_nodes(documents)
        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]

    @abstractmethod
    def traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Retrieve documents from traversing this graph store.

        First, `k` nodes are retrieved using a search for each `query` string.
        Then, additional nodes are discovered up to the given `depth` from those
        starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial search.
                Defaults to 4. Applies to each of the query strings.
            depth: The maximum depth of edges to traverse. Defaults to 1.
        Returns:
            Retrieved documents.
        """

    async def atraversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> AsyncIterable[Document]:
        """Retrieve documents from traversing this graph store.

        First, `k` nodes are retrieved using a search for each `query` string.
        Then, additional nodes are discovered up to the given `depth` from those
        starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial search.
                Defaults to 4. Applies to each of the query strings.
            depth: The maximum depth of edges to traverse. Defaults to 1.
        Returns:
            Retrieved documents.
        """
        # Default async implementation delegates to the sync search in an
        # executor, yielding results one at a time.
        iterator = iter(
            await run_in_executor(
                None, self.traversal_search, query, k=k, depth=depth, **kwargs
            )
        )
        done = object()  # sentinel marking iterator exhaustion
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    @abstractmethod
    def mmr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Retrieve documents from this graph store using MMR-traversal.

        This strategy first retrieves the top `fetch_k` results by similarity to
        the question. It then selects the top `k` results based on
        maximum-marginal relevance using the given `lambda_mult`.

        At each step, it considers the (remaining) documents from `fetch_k` as
        well as any documents connected by edges to a selected document
        retrieved based on similarity (a "root").

        Args:
            query: The query string to search for.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch via similarity.
                Defaults to 100.
            adjacent_k: Number of adjacent Documents to fetch.
                Defaults to 10.
            depth: Maximum depth of a node (number of edges) from a node
                retrieved via similarity. Defaults to 2.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            score_threshold: Only documents with a score greater than or equal
                this threshold will be chosen. Defaults to negative infinity.
        """

    async def ammr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> AsyncIterable[Document]:
        """Retrieve documents from this graph store using MMR-traversal.

        This strategy first retrieves the top `fetch_k` results by similarity to
        the question. It then selects the top `k` results based on
        maximum-marginal relevance using the given `lambda_mult`.

        At each step, it considers the (remaining) documents from `fetch_k` as
        well as any documents connected by edges to a selected document
        retrieved based on similarity (a "root").

        Args:
            query: The query string to search for.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch via similarity.
                Defaults to 100.
            adjacent_k: Number of adjacent Documents to fetch.
                Defaults to 10.
            depth: Maximum depth of a node (number of edges) from a node
                retrieved via similarity. Defaults to 2.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            score_threshold: Only documents with a score greater than or equal
                this threshold will be chosen. Defaults to negative infinity.
        """
        # Default async implementation delegates to the sync search in an
        # executor, yielding results one at a time.
        iterator = iter(
            await run_in_executor(
                None,
                self.mmr_traversal_search,
                query,
                k=k,
                fetch_k=fetch_k,
                adjacent_k=adjacent_k,
                depth=depth,
                lambda_mult=lambda_mult,
                score_threshold=score_threshold,
                **kwargs,
            )
        )
        done = object()  # sentinel marking iterator exhaustion
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        """Return docs most similar to query.

        Implemented as a traversal search with ``depth=0`` (no edges followed).
        """
        return list(self.traversal_search(query, k=k, depth=0))

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        """Return docs selected using maximal marginal relevance.

        Implemented as an MMR-traversal search with ``depth=0`` (no edges
        followed).
        """
        return list(
            self.mmr_traversal_search(
                query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
            )
        )

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        """Return docs most similar to query (async).

        Implemented as a traversal search with ``depth=0`` (no edges followed).
        """
        return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]

    def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
        """Return docs most similar to query using the specified search type.

        Args:
            query: The query string.
            search_type: One of ``"similarity"``,
                ``"similarity_score_threshold"``, ``"mmr"``, ``"traversal"``,
                or ``"mmr_traversal"``.

        Raises:
            ValueError: If ``search_type`` is not one of the supported types.
        """
        if search_type == "similarity":
            return self.similarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = self.similarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return self.max_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return list(self.traversal_search(query, **kwargs))
        elif search_type == "mmr_traversal":
            return list(self.mmr_traversal_search(query, **kwargs))
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr', 'traversal', or 'mmr_traversal'."
            )

    async def asearch(
        self, query: str, search_type: str, **kwargs: Any
    ) -> list[Document]:
        """Return docs most similar to query using the specified search type (async).

        Args:
            query: The query string.
            search_type: One of ``"similarity"``,
                ``"similarity_score_threshold"``, ``"mmr"``, ``"traversal"``,
                or ``"mmr_traversal"``.

        Raises:
            ValueError: If ``search_type`` is not one of the supported types.
        """
        if search_type == "similarity":
            return await self.asimilarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return await self.amax_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return [doc async for doc in self.atraversal_search(query, **kwargs)]
        elif search_type == "mmr_traversal":
            # Fix: this branch was missing, even though the sync `search`
            # supports it and GraphVectorStoreRetriever advertises it.
            return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr', 'traversal', or 'mmr_traversal'."
            )

    def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
        """Return GraphVectorStoreRetriever initialized from this GraphVectorStore.

        Args:
            **kwargs: Keyword arguments to pass to the search function.
                Can include:

                - search_type (Optional[str]): Defines the type of search that
                    the Retriever should perform.
                    Can be ``traversal`` (default), ``similarity``, ``mmr``, or
                    ``similarity_score_threshold``.
                - search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                    search function. Can include things like:

                    - k(int): Amount of documents to return (Default: 4).
                    - depth(int): The maximum depth of edges to traverse (Default: 1).
                    - score_threshold(float): Minimum relevance threshold
                        for similarity_score_threshold.
                    - fetch_k(int): Amount of documents to pass to MMR algorithm
                        (Default: 20).
                    - lambda_mult(float): Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5).
        Returns:
            Retriever for this GraphVectorStore.

        Examples:

        .. code-block:: python

            # Retrieve documents traversing edges
            docsearch.as_retriever(
                search_type="traversal",
                search_kwargs={'k': 6, 'depth': 3}
            )

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

        """
        return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
|
||||
|
||||
class GraphVectorStoreRetriever(VectorStoreRetriever):
    """Retriever class for GraphVectorStore.

    Adds the graph-specific ``traversal`` and ``mmr_traversal`` search types
    on top of the standard VectorStoreRetriever behavior.
    """

    vectorstore: GraphVectorStore
    """GraphVectorStore to use for retrieval."""
    search_type: str = "traversal"
    """Type of search to perform. Defaults to "traversal"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "traversal",
        "mmr_traversal",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        # Graph-specific search types are handled here; all other types fall
        # through to the base VectorStoreRetriever implementation.
        if self.search_type == "traversal":
            return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
        elif self.search_type == "mmr_traversal":
            return list(
                self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
            )
        else:
            return super()._get_relevant_documents(query, run_manager=run_manager)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        # Async counterpart of _get_relevant_documents: graph-specific search
        # types use the async graph-store methods; everything else delegates
        # to the base class.
        if self.search_type == "traversal":
            return [
                doc
                async for doc in self.vectorstore.atraversal_search(
                    query, **self.search_kwargs
                )
            ]
        elif self.search_type == "mmr_traversal":
            return [
                doc
                async for doc in self.vectorstore.ammr_traversal_search(
                    query, **self.search_kwargs
                )
            ]
        else:
            return await super()._aget_relevant_documents(
                query, run_manager=run_manager
            )
|
||||
|
||||
@@ -12,12 +12,12 @@ from typing import (
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.graph_vectorstores.base import (
|
||||
|
||||
from langchain_community.graph_vectorstores.base import (
|
||||
GraphVectorStore,
|
||||
Node,
|
||||
nodes_to_documents,
|
||||
)
|
||||
|
||||
from langchain_community.utilities.cassandra import SetupMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
# TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
|
||||
GLiNERInput = Union[str, Document]
|
||||
@@ -34,7 +34,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
|
||||
.. seealso::
|
||||
|
||||
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
|
||||
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
|
||||
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
|
||||
|
||||
How to link Documents on common named entities
|
||||
==============================================
|
||||
@@ -59,12 +59,12 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
|
||||
|
||||
We can use :meth:`extract_one` on a document to get the links and add the links
|
||||
to the document metadata with
|
||||
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
|
||||
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
|
||||
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
|
||||
from langchain_core.graph_vectorstores.links import add_links
|
||||
from langchain_community.graph_vectorstores.links import add_links
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
|
||||
loader = TextLoader("state_of_the_union.txt")
|
||||
@@ -87,7 +87,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
|
||||
Using LinkExtractorTransformer
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
|
||||
we can simplify the link extraction::
|
||||
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
@@ -113,7 +113,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
|
||||
|
||||
{'source': 'state_of_the_union.txt', 'links': [Link(kind='entity:Person', direction='bidir', tag='President Zelenskyy'), Link(kind='entity:Person', direction='bidir', tag='Vladimir Putin')]}
|
||||
|
||||
The documents with named entity links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
|
||||
The documents with named entity links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
|
||||
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Callable, List, Set
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
@@ -10,6 +9,7 @@ from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor_adapter import (
|
||||
LinkExtractorAdapter,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
# TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
|
||||
HierarchyInput = List[str]
|
||||
|
||||
@@ -6,8 +6,8 @@ from urllib.parse import urldefrag, urljoin, urlparse
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores import Link
|
||||
|
||||
from langchain_community.graph_vectorstores import Link
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
)
|
||||
@@ -77,7 +77,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
|
||||
.. seealso::
|
||||
|
||||
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
|
||||
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
|
||||
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
|
||||
|
||||
How to link Documents on hyperlinks in HTML
|
||||
===========================================
|
||||
@@ -103,7 +103,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
|
||||
|
||||
We can use :meth:`extract_one` on a document to get the links and add the links
|
||||
to the document metadata with
|
||||
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
|
||||
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
|
||||
|
||||
from langchain_community.document_loaders import AsyncHtmlLoader
|
||||
from langchain_community.graph_vectorstores.extractors import (
|
||||
@@ -148,7 +148,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
|
||||
|
||||
from langchain_community.document_loaders import AsyncHtmlLoader
|
||||
from langchain_community.graph_vectorstores.extractors import HtmlLinkExtractor
|
||||
from langchain_core.graph_vectorstores.links import add_links
|
||||
from langchain_community.graph_vectorstores.links import add_links
|
||||
|
||||
loader = AsyncHtmlLoader(
|
||||
[
|
||||
@@ -176,7 +176,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
|
||||
Using LinkExtractorTransformer
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
|
||||
we can simplify the link extraction::
|
||||
|
||||
from langchain_community.document_loaders import AsyncHtmlLoader
|
||||
@@ -227,7 +227,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
|
||||
|
||||
Found link from https://python.langchain.com/v0.2/docs/integrations/providers/astradb/ to https://docs.datastax.com/en/astra/home/astra.html.
|
||||
|
||||
The documents with URL links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
|
||||
The documents with URL links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
|
||||
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, Optional, Set, Union
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
KeybertInput = Union[str, Document]
|
||||
|
||||
@@ -37,7 +37,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
|
||||
.. seealso::
|
||||
|
||||
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
|
||||
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
|
||||
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
|
||||
|
||||
How to link Documents on common keywords using Keybert
|
||||
======================================================
|
||||
@@ -62,12 +62,12 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
|
||||
|
||||
We can use :meth:`extract_one` on a document to get the links and add the links
|
||||
to the document metadata with
|
||||
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
|
||||
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
|
||||
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
|
||||
from langchain_core.graph_vectorstores.links import add_links
|
||||
from langchain_community.graph_vectorstores.links import add_links
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
|
||||
loader = TextLoader("state_of_the_union.txt")
|
||||
@@ -91,7 +91,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
|
||||
Using LinkExtractorTransformer
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.keybert_link_extractor.LinkExtractorTransformer`,
|
||||
Using the :class:`~langchain_community.graph_vectorstores.extractors.link_extractor_transformer.LinkExtractorTransformer`,
|
||||
we can simplify the link extraction::
|
||||
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
@@ -116,7 +116,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
|
||||
|
||||
{'source': 'state_of_the_union.txt', 'links': [Link(kind='kw', direction='bidir', tag='ukraine'), Link(kind='kw', direction='bidir', tag='ukrainian'), Link(kind='kw', direction='bidir', tag='putin'), Link(kind='kw', direction='bidir', tag='vladimir'), Link(kind='kw', direction='bidir', tag='russia')]}
|
||||
|
||||
The documents with keyword links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
|
||||
The documents with keyword links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
|
||||
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Generic, Iterable, Set, TypeVar
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.graph_vectorstores import Link
|
||||
|
||||
from langchain_community.graph_vectorstores import Link
|
||||
|
||||
InputT = TypeVar("InputT")
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Callable, Iterable, Set, TypeVar
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.graph_vectorstores import Link
|
||||
|
||||
from langchain_community.graph_vectorstores import Link
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
)
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import Any, Sequence
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.documents.transformers import BaseDocumentTransformer
|
||||
from langchain_core.graph_vectorstores.links import copy_with_links
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors.link_extractor import (
|
||||
LinkExtractor,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.links import copy_with_links
|
||||
|
||||
|
||||
@beta()
|
||||
|
||||
@@ -1,8 +1,102 @@
|
||||
from langchain_core.graph_vectorstores.links import (
|
||||
Link,
|
||||
add_links,
|
||||
copy_with_links,
|
||||
get_links,
|
||||
)
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Union
|
||||
|
||||
__all__ = ["Link", "add_links", "get_links", "copy_with_links"]
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@beta()
|
||||
@dataclass(frozen=True)
|
||||
class Link:
|
||||
"""A link to/from a tag of a given tag.
|
||||
|
||||
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
"""The kind of link. Allows different extractors to use the same tag name without
|
||||
creating collisions between extractors. For example “keyword” vs “url”."""
|
||||
direction: Literal["in", "out", "bidir"]
|
||||
"""The direction of the link."""
|
||||
tag: str
|
||||
"""The tag of the link."""
|
||||
|
||||
@staticmethod
|
||||
def incoming(kind: str, tag: str) -> "Link":
|
||||
"""Create an incoming link."""
|
||||
return Link(kind=kind, direction="in", tag=tag)
|
||||
|
||||
@staticmethod
|
||||
def outgoing(kind: str, tag: str) -> "Link":
|
||||
"""Create an outgoing link."""
|
||||
return Link(kind=kind, direction="out", tag=tag)
|
||||
|
||||
@staticmethod
|
||||
def bidir(kind: str, tag: str) -> "Link":
|
||||
"""Create a bidirectional link."""
|
||||
return Link(kind=kind, direction="bidir", tag=tag)
|
||||
|
||||
|
||||
METADATA_LINKS_KEY = "links"
|
||||
|
||||
|
||||
@beta()
|
||||
def get_links(doc: Document) -> list[Link]:
|
||||
"""Get the links from a document.
|
||||
|
||||
Args:
|
||||
doc: The document to get the link tags from.
|
||||
Returns:
|
||||
The set of link tags from the document.
|
||||
"""
|
||||
|
||||
links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
|
||||
if not isinstance(links, list):
|
||||
# Convert to a list and remember that.
|
||||
links = list(links)
|
||||
doc.metadata[METADATA_LINKS_KEY] = links
|
||||
return links
|
||||
|
||||
|
||||
@beta()
|
||||
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
|
||||
"""Add links to the given metadata.
|
||||
|
||||
Args:
|
||||
doc: The document to add the links to.
|
||||
*links: The links to add to the document.
|
||||
"""
|
||||
links_in_metadata = get_links(doc)
|
||||
for link in links:
|
||||
if isinstance(link, Iterable):
|
||||
links_in_metadata.extend(link)
|
||||
else:
|
||||
links_in_metadata.append(link)
|
||||
|
||||
|
||||
@beta()
|
||||
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
|
||||
"""Return a document with the given links added.
|
||||
|
||||
Args:
|
||||
doc: The document to add the links to.
|
||||
*links: The links to add to the document.
|
||||
|
||||
Returns:
|
||||
A document with a shallow-copy of the metadata with the links added.
|
||||
"""
|
||||
new_links = set(get_links(doc))
|
||||
for link in links:
|
||||
if isinstance(link, Iterable):
|
||||
new_links.update(link)
|
||||
else:
|
||||
new_links.add(link)
|
||||
|
||||
return Document(
|
||||
page_content=doc.page_content,
|
||||
metadata={
|
||||
**doc.metadata,
|
||||
METADATA_LINKS_KEY: list(new_links),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -411,7 +411,9 @@ class Neo4jGraph(GraphStore):
|
||||
return self.structured_schema
|
||||
|
||||
def query(
|
||||
self, query: str, params: dict = {}, retry_on_session_expired: bool = True
|
||||
self,
|
||||
query: str,
|
||||
params: dict = {},
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query Neo4j database.
|
||||
|
||||
@@ -423,26 +425,44 @@ class Neo4jGraph(GraphStore):
|
||||
List[Dict[str, Any]]: The list of dictionaries containing the query results.
|
||||
"""
|
||||
from neo4j import Query
|
||||
from neo4j.exceptions import CypherSyntaxError, SessionExpired
|
||||
from neo4j.exceptions import Neo4jError
|
||||
|
||||
with self._driver.session(database=self._database) as session:
|
||||
try:
|
||||
data = session.run(Query(text=query, timeout=self.timeout), params)
|
||||
json_data = [r.data() for r in data]
|
||||
if self.sanitize:
|
||||
json_data = [value_sanitize(el) for el in json_data]
|
||||
return json_data
|
||||
except CypherSyntaxError as e:
|
||||
raise ValueError(f"Generated Cypher Statement is not valid\n{e}")
|
||||
except (
|
||||
SessionExpired
|
||||
) as e: # Session expired is a transient error that can be retried
|
||||
if retry_on_session_expired:
|
||||
return self.query(
|
||||
query, params=params, retry_on_session_expired=False
|
||||
try:
|
||||
data, _, _ = self._driver.execute_query(
|
||||
Query(text=query, timeout=self.timeout),
|
||||
database=self._database,
|
||||
parameters_=params,
|
||||
)
|
||||
json_data = [r.data() for r in data]
|
||||
if self.sanitize:
|
||||
json_data = [value_sanitize(el) for el in json_data]
|
||||
return json_data
|
||||
except Neo4jError as e:
|
||||
if not (
|
||||
(
|
||||
( # isCallInTransactionError
|
||||
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
||||
or e.code
|
||||
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
and "in an implicit transaction" in e.message
|
||||
)
|
||||
or ( # isPeriodicCommitError
|
||||
e.code == "Neo.ClientError.Statement.SemanticError"
|
||||
and (
|
||||
"in an open transaction is not possible" in e.message
|
||||
or "tried to execute in an explicit transaction" in e.message
|
||||
)
|
||||
)
|
||||
):
|
||||
raise
|
||||
# fallback to allow implicit transactions
|
||||
with self._driver.session() as session:
|
||||
data = session.run(Query(text=query, timeout=self.timeout), params)
|
||||
json_data = [r.data() for r in data]
|
||||
if self.sanitize:
|
||||
json_data = [value_sanitize(el) for el in json_data]
|
||||
return json_data
|
||||
|
||||
def refresh_schema(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -510,12 +510,6 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
|
||||
return SagemakerEndpoint
|
||||
|
||||
|
||||
def _import_sambaverse() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
|
||||
return Sambaverse
|
||||
|
||||
|
||||
def _import_sambastudio() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
|
||||
@@ -817,8 +811,6 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_rwkv()
|
||||
elif name == "SagemakerEndpoint":
|
||||
return _import_sagemaker_endpoint()
|
||||
elif name == "Sambaverse":
|
||||
return _import_sambaverse()
|
||||
elif name == "SambaStudio":
|
||||
return _import_sambastudio()
|
||||
elif name == "SelfHostedPipeline":
|
||||
@@ -954,7 +946,6 @@ __all__ = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
@@ -1051,7 +1042,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"replicate": _import_replicate,
|
||||
"rwkv": _import_rwkv,
|
||||
"sagemaker_endpoint": _import_sagemaker_endpoint,
|
||||
"sambaverse": _import_sambaverse,
|
||||
"sambastudio": _import_sambastudio,
|
||||
"self_hosted": _import_self_hosted,
|
||||
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
|
||||
|
||||
@@ -9,464 +9,6 @@ from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class SVEndpointHandler:
|
||||
"""
|
||||
SambaNova Systems Interface for Sambaverse endpoint.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
|
||||
API_BASE_PATH: str = "/api/predict"
|
||||
|
||||
def __init__(self, host_url: str):
|
||||
"""
|
||||
Initialize the SVEndpointHandler.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
self.host_url = host_url
|
||||
self.http_session = requests.Session()
|
||||
|
||||
@staticmethod
|
||||
def _process_response(response: requests.Response) -> Dict:
|
||||
"""
|
||||
Processes the API response and returns the resulting dict.
|
||||
|
||||
All resulting dicts, regardless of success or failure, will contain the
|
||||
`status_code` key with the API response status code.
|
||||
|
||||
If the API returned an error, the resulting dict will contain the key
|
||||
`detail` with the error message.
|
||||
|
||||
If the API call was successful, the resulting dict will contain the key
|
||||
`data` with the response data.
|
||||
|
||||
:param requests.Response response: the response object to process
|
||||
:return: the response dict
|
||||
:type: dict
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
try:
|
||||
lines_result = response.text.strip().split("\n")
|
||||
text_result = lines_result[-1]
|
||||
if response.status_code == 200 and json.loads(text_result).get("error"):
|
||||
completion = ""
|
||||
for line in lines_result[:-1]:
|
||||
completion += json.loads(line)["result"]["responses"][0][
|
||||
"stream_token"
|
||||
]
|
||||
text_result = lines_result[-2]
|
||||
result = json.loads(text_result)
|
||||
result["result"]["responses"][0]["completion"] = completion
|
||||
else:
|
||||
result = json.loads(text_result)
|
||||
except Exception as e:
|
||||
result["detail"] = str(e)
|
||||
if "status_code" not in result:
|
||||
result["status_code"] = response.status_code
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _process_streaming_response(
|
||||
response: requests.Response,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Process the streaming response"""
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
chunk = json.loads(line)
|
||||
if "status_code" not in chunk:
|
||||
chunk["status_code"] = response.status_code
|
||||
if chunk["status_code"] == 200 and chunk.get("error"):
|
||||
chunk["result"] = {"responses": [{"stream_token": ""}]}
|
||||
return chunk
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||
|
||||
def _get_full_url(self) -> str:
|
||||
"""
|
||||
Return the full API URL for a given path.
|
||||
:returns: the full API URL for the sub-path
|
||||
:type: str
|
||||
"""
|
||||
return f"{self.host_url}{self.API_BASE_PATH}"
|
||||
|
||||
def nlp_predict(
|
||||
self,
|
||||
key: str,
|
||||
sambaverse_model_name: Optional[str],
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
stream: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
headers={
|
||||
"key": key,
|
||||
"Content-Type": "application/json",
|
||||
"modelName": sambaverse_model_name,
|
||||
},
|
||||
json=data,
|
||||
)
|
||||
return SVEndpointHandler._process_response(response)
|
||||
|
||||
def nlp_predict_stream(
|
||||
self,
|
||||
key: str,
|
||||
sambaverse_model_name: Optional[str],
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
) -> Iterator[Dict]:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
# Streaming output
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
headers={
|
||||
"key": key,
|
||||
"Content-Type": "application/json",
|
||||
"modelName": sambaverse_model_name,
|
||||
},
|
||||
json=data,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in SVEndpointHandler._process_streaming_response(response):
|
||||
yield chunk
|
||||
|
||||
|
||||
class Sambaverse(LLM):
|
||||
"""
|
||||
Sambaverse large language models.
|
||||
|
||||
To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
|
||||
set with your API key.
|
||||
|
||||
get one in https://sambaverse.sambanova.ai
|
||||
read extra documentation in https://docs.sambanova.ai/sambaverse/latest/index.html
|
||||
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
Sambaverse(
|
||||
sambaverse_url="https://sambaverse.sambanova.ai",
|
||||
sambaverse_api_key="your-sambaverse-api-key",
|
||||
sambaverse_model_name="Meta/llama-2-7b-chat-hf",
|
||||
streaming: = False
|
||||
model_kwargs={
|
||||
"select_expert": "llama-2-7b-chat-hf",
|
||||
"do_sample": False,
|
||||
"max_tokens_to_generate": 100,
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_k": 50,
|
||||
"process_prompt": False
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
sambaverse_url: str = ""
|
||||
"""Sambaverse url to use"""
|
||||
|
||||
sambaverse_api_key: str = ""
|
||||
"""sambaverse api key"""
|
||||
|
||||
sambaverse_model_name: Optional[str] = None
|
||||
"""sambaverse expert model to use"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
streaming: Optional[bool] = False
|
||||
"""Streaming flag to get streamed response."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["sambaverse_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambaverse_url",
|
||||
"SAMBAVERSE_URL",
|
||||
default="https://sambaverse.sambanova.ai",
|
||||
)
|
||||
values["sambaverse_api_key"] = get_from_dict_or_env(
|
||||
values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
|
||||
)
|
||||
values["sambaverse_model_name"] = get_from_dict_or_env(
|
||||
values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_kwargs": self.model_kwargs}}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "Sambaverse LLM"
|
||||
|
||||
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
|
||||
"""
|
||||
Get the tuning parameters to use when calling the LLM.
|
||||
|
||||
Args:
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
|
||||
Returns:
|
||||
The tuning parameters as a JSON string.
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
|
||||
_stop_sequences = stop or _kwarg_stop_sequences
|
||||
if not _kwarg_stop_sequences:
|
||||
_model_kwargs["stop_sequences"] = ",".join(
|
||||
f'"{x}"' for x in _stop_sequences
|
||||
)
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (_model_kwargs.items())
|
||||
}
|
||||
_model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
||||
tuning_params = json.dumps(tuning_params_dict)
|
||||
return tuning_params
|
||||
|
||||
def _handle_nlp_predict(
|
||||
self,
|
||||
sdk: SVEndpointHandler,
|
||||
prompt: Union[List[str], str],
|
||||
tuning_params: str,
|
||||
) -> str:
|
||||
"""
|
||||
Perform an NLP prediction using the Sambaverse endpoint handler.
|
||||
|
||||
Args:
|
||||
sdk: The SVEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
response = sdk.nlp_predict(
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
error = response.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}."
|
||||
f"{response}."
|
||||
)
|
||||
return response["result"]["responses"][0]["completion"]
|
||||
|
||||
def _handle_completion_requests(
|
||||
self, prompt: Union[List[str], str], stop: Optional[List[str]]
|
||||
) -> str:
|
||||
"""
|
||||
Perform a prediction using the Sambaverse endpoint handler.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for the prediction.
|
||||
stop: stop sequences.
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
ss_endpoint = SVEndpointHandler(self.sambaverse_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
||||
|
||||
def _handle_nlp_predict_stream(
|
||||
self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
|
||||
Args:
|
||||
sdk: The SVEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
|
||||
Returns:
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
for chunk in sdk.nlp_predict_stream(
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
):
|
||||
if chunk["status_code"] != 200:
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
)
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
generated_chunk = GenerationChunk(text=text)
|
||||
yield generated_chunk
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Stream the Sambaverse's LLM on the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
ss_endpoint = SVEndpointHandler(self.sambaverse_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in self._handle_nlp_predict_stream(
|
||||
ss_endpoint, prompt, tuning_params
|
||||
):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
def _handle_stream_request(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]],
|
||||
run_manager: Optional[CallbackManagerForLLMRun],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to generate from.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
"""
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the LLM on the given input.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to generate from.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
"""
|
||||
try:
|
||||
if self.streaming:
|
||||
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
|
||||
return self._handle_completion_requests(prompt, stop)
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
|
||||
class SSEndpointHandler:
|
||||
"""
|
||||
SambaNova Systems Interface for SambaStudio model endpoints.
|
||||
@@ -975,7 +517,7 @@ class SambaStudio(LLM):
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
to the sambastudio model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, create_model
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from databricks.sdk.service.catalog import FunctionInfo
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@@ -121,7 +120,7 @@ def _get_tool_name(function: "FunctionInfo") -> str:
|
||||
return tool_name
|
||||
|
||||
|
||||
def _get_default_workspace_client() -> "WorkspaceClient":
|
||||
def _get_default_workspace_client() -> Any:
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
except ImportError as e:
|
||||
@@ -137,7 +136,7 @@ class UCFunctionToolkit(BaseToolkit):
|
||||
description="The ID of a Databricks SQL Warehouse to execute functions."
|
||||
)
|
||||
|
||||
workspace_client: "WorkspaceClient" = Field(
|
||||
workspace_client: Any = Field(
|
||||
default_factory=_get_default_workspace_client,
|
||||
description="Databricks workspace client.",
|
||||
)
|
||||
|
||||
@@ -443,6 +443,12 @@ class AzureSearch(VectorStore):
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
# when `keys` are not passed in and there is `ids` in kwargs, use those instead
|
||||
# base class expects `ids` passed in rather than `keys`
|
||||
# https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
|
||||
if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
|
||||
keys = kwargs["ids"]
|
||||
|
||||
return self.add_embeddings(zip(texts, embeddings), metadatas, keys=keys)
|
||||
|
||||
async def aadd_texts(
|
||||
@@ -467,6 +473,12 @@ class AzureSearch(VectorStore):
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
# when `keys` are not passed in and there is `ids` in kwargs, use those instead
|
||||
# base class expects `ids` passed in rather than `keys`
|
||||
# https://github.com/langchain-ai/langchain/blob/4cdaca67dc51dba887289f56c6fead3c1a52f97d/libs/core/langchain_core/vectorstores/base.py#L65
|
||||
if (not keys) and ("ids" in kwargs) and (len(kwargs["ids"]) == len(embeddings)):
|
||||
keys = kwargs["ids"]
|
||||
|
||||
return await self.aadd_embeddings(zip(texts, embeddings), metadatas, keys=keys)
|
||||
|
||||
def add_embeddings(
|
||||
@@ -483,9 +495,13 @@ class AzureSearch(VectorStore):
|
||||
data = []
|
||||
for i, (text, embedding) in enumerate(text_embeddings):
|
||||
# Use provided key otherwise use default key
|
||||
key = keys[i] if keys else str(uuid.uuid4())
|
||||
# Encoding key for Azure Search valid characters
|
||||
key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
|
||||
if keys:
|
||||
key = keys[i]
|
||||
else:
|
||||
key = str(uuid.uuid4())
|
||||
# Encoding key for Azure Search valid characters
|
||||
key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
|
||||
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
# Add data to index
|
||||
# Additional metadata to fields mapping
|
||||
|
||||
@@ -595,11 +595,8 @@ class Neo4jVector(VectorStore):
|
||||
query: str,
|
||||
*,
|
||||
params: Optional[dict] = None,
|
||||
retry_on_session_expired: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
This method sends a Cypher query to the connected Neo4j database
|
||||
and returns the results as a list of dictionaries.
|
||||
"""Query Neo4j database with retries and exponential backoff.
|
||||
|
||||
Args:
|
||||
query (str): The Cypher query to execute.
|
||||
@@ -608,24 +605,38 @@ class Neo4jVector(VectorStore):
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of dictionaries containing the query results.
|
||||
"""
|
||||
from neo4j.exceptions import CypherSyntaxError, SessionExpired
|
||||
from neo4j import Query
|
||||
from neo4j.exceptions import Neo4jError
|
||||
|
||||
params = params or {}
|
||||
with self._driver.session(database=self._database) as session:
|
||||
try:
|
||||
data = session.run(query, params)
|
||||
return [r.data() for r in data]
|
||||
except CypherSyntaxError as e:
|
||||
raise ValueError(f"Cypher Statement is not valid\n{e}")
|
||||
except (
|
||||
SessionExpired
|
||||
) as e: # Session expired is a transient error that can be retried
|
||||
if retry_on_session_expired:
|
||||
return self.query(
|
||||
query, params=params, retry_on_session_expired=False
|
||||
try:
|
||||
data, _, _ = self._driver.execute_query(
|
||||
query, database=self._database, parameters_=params
|
||||
)
|
||||
return [r.data() for r in data]
|
||||
except Neo4jError as e:
|
||||
if not (
|
||||
(
|
||||
( # isCallInTransactionError
|
||||
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
||||
or e.code
|
||||
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
and "in an implicit transaction" in e.message
|
||||
)
|
||||
or ( # isPeriodicCommitError
|
||||
e.code == "Neo.ClientError.Statement.SemanticError"
|
||||
and (
|
||||
"in an open transaction is not possible" in e.message
|
||||
or "tried to execute in an explicit transaction" in e.message
|
||||
)
|
||||
)
|
||||
):
|
||||
raise
|
||||
# Fallback to allow implicit transactions
|
||||
with self._driver.session() as session:
|
||||
data = session.run(Query(text=query), params)
|
||||
return [r.data() for r in data]
|
||||
|
||||
def verify_version(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -144,7 +144,7 @@ class TencentVectorDB(VectorStore):
|
||||
|
||||
In order to use this you need to have a database instance.
|
||||
See the following documentation for details:
|
||||
https://cloud.tencent.com/document/product/1709/94951
|
||||
https://cloud.tencent.com/document/product/1709/104489
|
||||
"""
|
||||
|
||||
field_id: str = "id"
|
||||
|
||||
@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
|
||||
# PRs that increase the current count will not be accepted.
|
||||
# PRs that decrease update the code in the repository
|
||||
# and allow decreasing the count of are welcome!
|
||||
current_count=129
|
||||
current_count=128
|
||||
|
||||
if [ "$count" -gt "$current_count" ]; then
|
||||
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
@@ -11,7 +10,6 @@ from langchain_community.document_loaders import (
|
||||
PDFMinerPDFasHTMLLoader,
|
||||
PyMuPDFLoader,
|
||||
PyPDFium2Loader,
|
||||
PyPDFLoader,
|
||||
UnstructuredPDFLoader,
|
||||
)
|
||||
|
||||
@@ -86,37 +84,6 @@ def test_pdfminer_pdf_as_html_loader() -> None:
|
||||
assert len(docs) == 1
|
||||
|
||||
|
||||
def test_pypdf_loader() -> None:
|
||||
"""Test PyPDFLoader."""
|
||||
file_path = Path(__file__).parent.parent / "examples/hello.pdf"
|
||||
loader = PyPDFLoader(str(file_path))
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
file_path = Path(__file__).parent.parent / "examples/layout-parser-paper.pdf"
|
||||
loader = PyPDFLoader(str(file_path))
|
||||
|
||||
docs = loader.load()
|
||||
assert len(docs) == 16
|
||||
|
||||
|
||||
def test_pypdf_loader_with_layout() -> None:
|
||||
"""Test PyPDFLoader with layout mode."""
|
||||
file_path = Path(__file__).parent.parent / "examples/layout-parser-paper.pdf"
|
||||
loader = PyPDFLoader(str(file_path), extraction_mode="layout")
|
||||
|
||||
docs = loader.load()
|
||||
first_page = docs[0].page_content
|
||||
|
||||
expected = (
|
||||
Path(__file__).parent.parent / "examples/layout-parser-paper-page-1.txt"
|
||||
).read_text(encoding="utf-8")
|
||||
cleaned_first_page = re.sub(r"\x00", "", first_page)
|
||||
cleaned_expected = re.sub(r"\x00", "", expected)
|
||||
assert cleaned_first_page == cleaned_expected
|
||||
|
||||
|
||||
def test_pypdfium2_loader() -> None:
|
||||
"""Test PyPDFium2Loader."""
|
||||
file_path = Path(__file__).parent.parent / "examples/hello.pdf"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
PAGE_1 = """
|
||||
Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃ'tjɐnu
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
PAGE_1 = """
|
||||
Supervised learning is the machine learning task of learning a function that
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Type
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
|
||||
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
|
||||
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
|
||||
CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
|
||||
|
||||
|
||||
@@ -1,28 +1,17 @@
|
||||
"""Test sambanova API wrapper.
|
||||
|
||||
In order to run this test, you need to have an sambaverse api key,
|
||||
and a sambaverse base url, project id, endpoint id, and api key.
|
||||
You'll then need to set SAMBAVERSE_API_KEY, SAMBASTUDIO_BASE_URL,
|
||||
In order to run this test, you need to have a sambastudio base url,
|
||||
project id, endpoint id, and api key.
|
||||
You'll then need to set SAMBASTUDIO_BASE_URL, SAMBASTUDIO_BASE_URI
|
||||
SAMBASTUDIO_PROJECT_ID, SAMBASTUDIO_ENDPOINT_ID, and SAMBASTUDIO_API_KEY
|
||||
environment variables.
|
||||
"""
|
||||
|
||||
from langchain_community.llms.sambanova import SambaStudio, Sambaverse
|
||||
|
||||
|
||||
def test_sambaverse_call() -> None:
|
||||
"""Test simple non-streaming call to sambaverse."""
|
||||
llm = Sambaverse(
|
||||
sambaverse_model_name="Meta/llama-2-7b-chat-hf",
|
||||
model_kwargs={"select_expert": "llama-2-7b-chat-hf"},
|
||||
)
|
||||
output = llm.invoke("What is LangChain")
|
||||
assert output
|
||||
assert isinstance(output, str)
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
|
||||
|
||||
def test_sambastudio_call() -> None:
|
||||
"""Test simple non-streaming call to sambaverse."""
|
||||
"""Test simple non-streaming call to sambastudio."""
|
||||
llm = SambaStudio()
|
||||
output = llm.invoke("What is LangChain")
|
||||
assert output
|
||||
|
||||
@@ -121,4 +121,4 @@ def test_callback_manager_configure_context_vars(
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
wait_for_all_tracers()
|
||||
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
|
||||
assert LangChainTracer._persist_run_single.call_count == 4 # type: ignore
|
||||
|
||||
@@ -55,6 +55,7 @@ EXPECTED_ALL = [
|
||||
"DedocFileLoader",
|
||||
"DedocPDFLoader",
|
||||
"PebbloSafeLoader",
|
||||
"PebbloTextLoader",
|
||||
"DiffbotLoader",
|
||||
"DirectoryLoader",
|
||||
"DiscordChatLoader",
|
||||
|
||||
@@ -12,6 +12,7 @@ def raw_docs() -> List[Dict]:
|
||||
return [
|
||||
{"_id": "1", "address": {"building": "1", "room": "1"}},
|
||||
{"_id": "2", "address": {"building": "2", "room": "2"}},
|
||||
{"_id": "3", "address": {"building": "3", "room": "2"}},
|
||||
]
|
||||
|
||||
|
||||
@@ -19,18 +20,23 @@ def raw_docs() -> List[Dict]:
|
||||
def expected_documents() -> List[Document]:
|
||||
return [
|
||||
Document(
|
||||
page_content="{'_id': '1', 'address': {'building': '1', 'room': '1'}}",
|
||||
page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}",
|
||||
metadata={"database": "sample_restaurants", "collection": "restaurants"},
|
||||
),
|
||||
Document(
|
||||
page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}",
|
||||
page_content="{'_id': '3', 'address': {'building': '3', 'room': '2'}}",
|
||||
metadata={"database": "sample_restaurants", "collection": "restaurants"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("motor")
|
||||
async def test_load_mocked(expected_documents: List[Document]) -> None:
|
||||
async def test_load_mocked_with_filters(expected_documents: List[Document]) -> None:
|
||||
filter_criteria = {"address.room": {"$eq": "2"}}
|
||||
field_names = ["address.building", "address.room"]
|
||||
metadata_names = ["_id"]
|
||||
include_db_collection_in_metadata = True
|
||||
|
||||
mock_async_load = AsyncMock()
|
||||
mock_async_load.return_value = expected_documents
|
||||
|
||||
@@ -51,7 +57,13 @@ async def test_load_mocked(expected_documents: List[Document]) -> None:
|
||||
new=mock_async_load,
|
||||
):
|
||||
loader = MongodbLoader(
|
||||
"mongodb://localhost:27017", "test_db", "test_collection"
|
||||
"mongodb://localhost:27017",
|
||||
"test_db",
|
||||
"test_collection",
|
||||
filter_criteria=filter_criteria,
|
||||
field_names=field_names,
|
||||
metadata_names=metadata_names,
|
||||
include_db_collection_in_metadata=include_db_collection_in_metadata,
|
||||
)
|
||||
loader.collection = mock_collection
|
||||
documents = await loader.aload()
|
||||
|
||||
62
libs/community/tests/unit_tests/document_loaders/test_pdf.py
Normal file
62
libs/community/tests/unit_tests/document_loaders/test_pdf.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
path_to_simple_pdf = (
|
||||
Path(__file__).parent.parent.parent / "integration_tests/examples/hello.pdf"
|
||||
)
|
||||
path_to_layout_pdf = (
|
||||
Path(__file__).parent.parent
|
||||
/ "document_loaders/sample_documents/layout-parser-paper.pdf"
|
||||
)
|
||||
path_to_layout_pdf_txt = (
|
||||
Path(__file__).parent.parent.parent
|
||||
/ "integration_tests/examples/layout-parser-paper-page-1.txt"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("pypdf")
|
||||
def test_pypdf_loader() -> None:
|
||||
"""Test PyPDFLoader."""
|
||||
loader = PyPDFLoader(str(path_to_simple_pdf))
|
||||
docs = loader.load()
|
||||
|
||||
assert len(docs) == 1
|
||||
|
||||
loader = PyPDFLoader(str(path_to_layout_pdf))
|
||||
|
||||
docs = loader.load()
|
||||
assert len(docs) == 16
|
||||
for page, doc in enumerate(docs):
|
||||
assert doc.metadata["page"] == page
|
||||
assert doc.metadata["source"].endswith("layout-parser-paper.pdf")
|
||||
assert len(doc.page_content) > 10
|
||||
|
||||
first_page = docs[0].page_content
|
||||
for expected in ["LayoutParser", "A Unified Toolkit"]:
|
||||
assert expected in first_page
|
||||
|
||||
|
||||
@pytest.mark.requires("pypdf")
|
||||
def test_pypdf_loader_with_layout() -> None:
|
||||
"""Test PyPDFLoader with layout mode."""
|
||||
loader = PyPDFLoader(str(path_to_layout_pdf), extraction_mode="layout")
|
||||
|
||||
docs = loader.load()
|
||||
assert len(docs) == 16
|
||||
for page, doc in enumerate(docs):
|
||||
assert doc.metadata["page"] == page
|
||||
assert doc.metadata["source"].endswith("layout-parser-paper.pdf")
|
||||
assert len(doc.page_content) > 10
|
||||
|
||||
first_page = docs[0].page_content
|
||||
for expected in ["LayoutParser", "A Unified Toolkit"]:
|
||||
assert expected in first_page
|
||||
|
||||
expected = path_to_layout_pdf_txt.read_text(encoding="utf-8")
|
||||
cleaned_first_page = re.sub(r"\x00", "", first_page)
|
||||
cleaned_expected = re.sub(r"\x00", "", expected)
|
||||
assert cleaned_first_page == cleaned_expected
|
||||
@@ -25,6 +25,11 @@ def test_pebblo_import() -> None:
|
||||
from langchain_community.document_loaders import PebbloSafeLoader # noqa: F401
|
||||
|
||||
|
||||
def test_pebblo_text_loader_import() -> None:
|
||||
"""Test that the Pebblo text loader can be imported."""
|
||||
from langchain_community.document_loaders import PebbloTextLoader # noqa: F401
|
||||
|
||||
|
||||
def test_empty_filebased_loader(mocker: MockerFixture) -> None:
|
||||
"""Test basic file based csv loader."""
|
||||
# Setup
|
||||
@@ -146,3 +151,42 @@ def test_pebblo_safe_loader_api_key() -> None:
|
||||
# Assert
|
||||
assert loader.pb_client.api_key == api_key
|
||||
assert loader.pb_client.classifier_location == "local"
|
||||
|
||||
|
||||
def test_pebblo_text_loader(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test loading in-memory text with PebbloTextLoader and PebbloSafeLoader.
|
||||
"""
|
||||
# Setup
|
||||
from langchain_community.document_loaders import PebbloSafeLoader, PebbloTextLoader
|
||||
|
||||
mocker.patch.multiple(
|
||||
"requests",
|
||||
get=MockResponse(json_data={"data": ""}, status_code=200),
|
||||
post=MockResponse(json_data={"data": ""}, status_code=200),
|
||||
)
|
||||
|
||||
text = "This is a test text."
|
||||
source = "fake_source"
|
||||
expected_docs = [
|
||||
Document(
|
||||
metadata={
|
||||
"full_path": source,
|
||||
"pb_checksum": None,
|
||||
},
|
||||
page_content=text,
|
||||
),
|
||||
]
|
||||
|
||||
# Exercise
|
||||
texts = [text]
|
||||
loader = PebbloSafeLoader(
|
||||
PebbloTextLoader(texts, source=source),
|
||||
"dummy_app_name",
|
||||
"dummy_owner",
|
||||
"dummy_description",
|
||||
)
|
||||
result = loader.load()
|
||||
|
||||
# Assert
|
||||
assert result == expected_docs
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors import HierarchyLinkExtractor
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
PATH_1 = ["Root", "H1", "h2"]
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from langchain_core.graph_vectorstores import Link
|
||||
|
||||
from langchain_community.graph_vectorstores import Link
|
||||
from langchain_community.graph_vectorstores.extractors import (
|
||||
HtmlInput,
|
||||
HtmlLinkExtractor,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Set
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import Link, get_links
|
||||
|
||||
from langchain_community.graph_vectorstores.extractors import (
|
||||
LinkExtractor,
|
||||
LinkExtractorTransformer,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.links import Link, get_links
|
||||
|
||||
TEXT1 = "Text1"
|
||||
TEXT2 = "Text2"
|
||||
|
||||
@@ -77,7 +77,6 @@ EXPECT_ALL = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.base import (
|
||||
|
||||
from langchain_community.graph_vectorstores.base import (
|
||||
Node,
|
||||
_documents_to_nodes,
|
||||
_texts_to_nodes,
|
||||
)
|
||||
from langchain_core.graph_vectorstores.links import Link
|
||||
from langchain_community.graph_vectorstores.links import Link
|
||||
|
||||
|
||||
def test_texts_to_nodes() -> None:
|
||||
@@ -190,3 +190,40 @@ def test_additional_search_options() -> None:
|
||||
)
|
||||
assert vector_store.client is not None
|
||||
assert vector_store.client._api_version == "test"
|
||||
|
||||
|
||||
@pytest.mark.requires("azure.search.documents")
|
||||
def test_ids_used_correctly() -> None:
|
||||
"""Check whether vector store uses the document ids when provided with them."""
|
||||
from azure.search.documents import SearchClient
|
||||
from azure.search.documents.indexes import SearchIndexClient
|
||||
from langchain_core.documents import Document
|
||||
|
||||
class Response:
|
||||
def __init__(self) -> None:
|
||||
self.succeeded: bool = True
|
||||
|
||||
def mock_upload_documents(self, documents: List[object]) -> List[Response]: # type: ignore[no-untyped-def]
|
||||
# assume all documents uploaded successfuly
|
||||
response = [Response() for _ in documents]
|
||||
return response
|
||||
|
||||
documents = [
|
||||
Document(
|
||||
page_content="page zero Lorem Ipsum",
|
||||
metadata={"source": "document.pdf", "page": 0, "id": "ID-document-1"},
|
||||
),
|
||||
Document(
|
||||
page_content="page one Lorem Ipsum",
|
||||
metadata={"source": "document.pdf", "page": 1, "id": "ID-document-2"},
|
||||
),
|
||||
]
|
||||
ids_provided = [i.metadata.get("id") for i in documents]
|
||||
|
||||
with patch.object(
|
||||
SearchClient, "upload_documents", mock_upload_documents
|
||||
), patch.object(SearchIndexClient, "get_index", mock_default_index):
|
||||
vector_store = create_vector_store()
|
||||
ids_used_at_upload = vector_store.add_documents(documents, ids=ids_provided)
|
||||
assert len(ids_provided) == len(ids_used_at_upload)
|
||||
assert ids_provided == ids_used_at_upload
|
||||
|
||||
@@ -53,7 +53,7 @@ LangChain Core compiles LCEL sequences to an _optimized execution plan_, with au
|
||||
|
||||
For more check out the [LCEL docs](https://python.langchain.com/docs/expression_language/).
|
||||
|
||||

|
||||

|
||||
|
||||
For more advanced use cases, also check out [LangGraph](https://github.com/langchain-ai/langgraph), which is a graph-based runner for cyclic and recursive LLM workflows.
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, TypeVar, Union, cast
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
|
||||
@@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], Type])
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], type])
|
||||
|
||||
|
||||
def beta(
|
||||
@@ -154,7 +155,7 @@ def beta(
|
||||
_name = _name or obj.fget.__qualname__
|
||||
old_doc = obj.__doc__
|
||||
|
||||
class _beta_property(property):
|
||||
class _BetaProperty(property):
|
||||
"""A beta property."""
|
||||
|
||||
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
|
||||
@@ -185,7 +186,7 @@ def beta(
|
||||
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
|
||||
"""Finalize the property."""
|
||||
return _beta_property(
|
||||
return _BetaProperty(
|
||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||
)
|
||||
|
||||
|
||||
@@ -14,11 +14,10 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
|
||||
|
||||
# Last Any should be FieldInfoV1 but this leads to circular imports
|
||||
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
|
||||
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
|
||||
|
||||
|
||||
def _validate_deprecation_params(
|
||||
@@ -262,10 +261,10 @@ def deprecated(
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
wrapped = None
|
||||
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
|
||||
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
|
||||
old_doc = obj.__doc__
|
||||
|
||||
class _deprecated_property(property):
|
||||
class _DeprecatedProperty(property):
|
||||
"""A deprecated property."""
|
||||
|
||||
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
|
||||
@@ -298,13 +297,13 @@ def deprecated(
|
||||
"""Finalize the property."""
|
||||
return cast(
|
||||
T,
|
||||
_deprecated_property(
|
||||
_DeprecatedProperty(
|
||||
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
_name = _name or cast(Union[Type, Callable], obj).__qualname__
|
||||
_name = _name or cast(Union[type, Callable], obj).__qualname__
|
||||
if not _obj_type:
|
||||
# edge case: when a function is within another function
|
||||
# within a test, this will call it a "method" not a "function"
|
||||
@@ -333,9 +332,26 @@ def deprecated(
|
||||
old_doc = ""
|
||||
|
||||
# Modify the docstring to include a deprecation notice.
|
||||
if (
|
||||
_alternative
|
||||
and _alternative.split(".")[-1].lower() == _alternative.split(".")[-1]
|
||||
):
|
||||
_alternative = f":meth:`~{_alternative}`"
|
||||
elif _alternative:
|
||||
_alternative = f":class:`~{_alternative}`"
|
||||
|
||||
if (
|
||||
_alternative_import
|
||||
and _alternative_import.split(".")[-1].lower()
|
||||
== _alternative_import.split(".")[-1]
|
||||
):
|
||||
_alternative_import = f":meth:`~{_alternative_import}`"
|
||||
elif _alternative_import:
|
||||
_alternative_import = f":class:`~{_alternative_import}`"
|
||||
|
||||
components = [
|
||||
_message,
|
||||
f"Use ``{_alternative}`` instead." if _alternative else "",
|
||||
f"Use {_alternative} instead." if _alternative else "",
|
||||
f"Use ``{_alternative_import}`` instead." if _alternative_import else "",
|
||||
_addendum,
|
||||
]
|
||||
|
||||
@@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, List, Literal, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
@@ -71,7 +72,7 @@ class AgentAction(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "agent"]."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
@@ -145,7 +146,7 @@ class AgentFinish(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
|
||||
Values = Dict[Union[asyncio.Event, threading.Event], Any]
|
||||
Values = dict[Union[asyncio.Event, threading.Event], Any]
|
||||
CONTEXT_CONFIG_PREFIX = "__context__/"
|
||||
CONTEXT_CONFIG_SUFFIX_GET = "/get"
|
||||
CONTEXT_CONFIG_SUFFIX_SET = "/set"
|
||||
@@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:
|
||||
|
||||
def _config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
setter: Callable,
|
||||
getter: Callable,
|
||||
event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
|
||||
event_cls: Union[type[threading.Event], type[asyncio.Event]],
|
||||
) -> RunnableConfig:
|
||||
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
|
||||
return config
|
||||
@@ -99,10 +93,10 @@ def _config_with_context(
|
||||
}
|
||||
|
||||
values: Values = {}
|
||||
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
event_cls
|
||||
)
|
||||
context_funcs: Dict[str, Callable[[], Any]] = {}
|
||||
context_funcs: dict[str, Callable[[], Any]] = {}
|
||||
for key, group in grouped_by_key.items():
|
||||
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
@@ -129,7 +123,7 @@ def _config_with_context(
|
||||
|
||||
def aconfig_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronously patch a runnable config with context getters and setters.
|
||||
|
||||
@@ -145,7 +139,7 @@ def aconfig_with_context(
|
||||
|
||||
def config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Patch a runnable config with context getters and setters.
|
||||
|
||||
@@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):
|
||||
|
||||
prefix: str = ""
|
||||
|
||||
key: Union[str, List[str]]
|
||||
key: Union[str, list[str]]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextGet({_print_keys(self.key)})"
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
keys = self.key if isinstance(self.key, list) else [self.key]
|
||||
return [
|
||||
@@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
@@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
|
||||
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
|
||||
@@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
mapper_config_specs = [
|
||||
s
|
||||
for mapper in self.keys.values()
|
||||
@@ -364,7 +358,7 @@ class Context:
|
||||
return PrefixContext(prefix=scope)
|
||||
|
||||
@staticmethod
|
||||
def getter(key: Union[str, List[str]], /) -> ContextGet:
|
||||
def getter(key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key)
|
||||
|
||||
@staticmethod
|
||||
@@ -385,7 +379,7 @@ class PrefixContext:
|
||||
def __init__(self, prefix: str = ""):
|
||||
self.prefix = prefix
|
||||
|
||||
def getter(self, key: Union[str, List[str]], /) -> ContextGet:
|
||||
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key, prefix=self.prefix)
|
||||
|
||||
def setter(
|
||||
|
||||
@@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.runnables import run_in_executor
|
||||
@@ -157,7 +158,7 @@ class InMemoryCache(BaseCache):
|
||||
Raises:
|
||||
ValueError: If maxsize is less than or equal to 0.
|
||||
"""
|
||||
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
if maxsize is not None and maxsize <= 0:
|
||||
raise ValueError("maxsize must be greater than 0")
|
||||
self._maxsize = maxsize
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
@@ -118,7 +119,7 @@ class ChainManagerMixin:
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -222,13 +223,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running.
|
||||
@@ -249,13 +250,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running.
|
||||
@@ -280,13 +281,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when the Retriever starts running.
|
||||
@@ -303,13 +304,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chain starts running.
|
||||
@@ -326,14 +327,14 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when the tool starts running.
|
||||
@@ -393,8 +394,8 @@ class RunManagerMixin:
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Override to define a handler for a custom event.
|
||||
@@ -470,13 +471,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
@@ -497,13 +498,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running.
|
||||
@@ -533,7 +534,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled.
|
||||
@@ -554,7 +555,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM ends running.
|
||||
@@ -573,7 +574,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors.
|
||||
@@ -590,13 +591,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when a chain starts running.
|
||||
@@ -613,11 +614,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when a chain ends running.
|
||||
@@ -636,7 +637,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
@@ -651,14 +652,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when the tool starts running.
|
||||
@@ -680,7 +681,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when the tool ends running.
|
||||
@@ -699,7 +700,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors.
|
||||
@@ -718,7 +719,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on an arbitrary text.
|
||||
@@ -754,7 +755,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent action.
|
||||
@@ -773,7 +774,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the agent end.
|
||||
@@ -788,13 +789,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the retriever start.
|
||||
@@ -815,7 +816,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the retriever end.
|
||||
@@ -833,7 +834,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error.
|
||||
@@ -852,8 +853,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Override to define a handler for a custom event.
|
||||
@@ -880,14 +881,14 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager.
|
||||
|
||||
@@ -901,8 +902,8 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Default is None.
|
||||
metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
|
||||
"""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
self.handlers: list[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: list[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||
@@ -1002,7 +1003,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
self, handlers: list[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager.
|
||||
|
||||
@@ -1024,7 +1025,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
|
||||
def add_tags(self, tags: list[str], inherit: bool = True) -> None:
|
||||
"""Add tags to the callback manager.
|
||||
|
||||
Args:
|
||||
@@ -1038,7 +1039,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
|
||||
def remove_tags(self, tags: List[str]) -> None:
|
||||
def remove_tags(self, tags: list[str]) -> None:
|
||||
"""Remove tags from the callback manager.
|
||||
|
||||
Args:
|
||||
@@ -1048,7 +1049,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.tags.remove(tag)
|
||||
self.inheritable_tags.remove(tag)
|
||||
|
||||
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
|
||||
def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
|
||||
"""Add metadata to the callback manager.
|
||||
|
||||
Args:
|
||||
@@ -1059,7 +1060,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
|
||||
def remove_metadata(self, keys: List[str]) -> None:
|
||||
def remove_metadata(self, keys: list[str]) -> None:
|
||||
"""Remove metadata from the callback manager.
|
||||
|
||||
Args:
|
||||
@@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_metadata.pop(key)
|
||||
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, TextIO, cast
|
||||
from typing import Any, Optional, TextIO, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
@@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
self.file.close()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain.
|
||||
|
||||
@@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
file=self.file,
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,21 +5,15 @@ import functools
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import copy_context
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@@ -64,12 +58,12 @@ def trace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[CallbackManagerForChainGroup, None, None]:
|
||||
"""Get a callback manager for a chain group in a context manager.
|
||||
Useful for grouping different calls together as a single run even if
|
||||
@@ -144,12 +138,12 @@ async def atrace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[AsyncCallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
|
||||
"""Get an async callback manager for a chain group in a context manager.
|
||||
Useful for grouping different async calls together as a single run even if
|
||||
@@ -240,7 +234,7 @@ def shielded(func: Func) -> Func:
|
||||
|
||||
|
||||
def handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
@@ -258,10 +252,10 @@ def handle_event(
|
||||
*args: The arguments to pass to the event handler.
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
"""
|
||||
coros: List[Coroutine[Any, Any, Any]] = []
|
||||
coros: list[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
try:
|
||||
message_strings: Optional[List[str]] = None
|
||||
message_strings: Optional[list[str]] = None
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
@@ -318,7 +312,7 @@ def handle_event(
|
||||
_run_coros(coros)
|
||||
|
||||
|
||||
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
|
||||
if hasattr(asyncio, "Runner"):
|
||||
# Python 3.11+
|
||||
# Run the coroutines in a new event loop, taking care to
|
||||
@@ -399,7 +393,7 @@ async def _ahandle_event_for_handler(
|
||||
|
||||
|
||||
async def ahandle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
@@ -446,13 +440,13 @@ class BaseRunManager(RunManagerMixin):
|
||||
self,
|
||||
*,
|
||||
run_id: UUID,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: list[BaseCallbackHandler],
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize the run manager.
|
||||
|
||||
@@ -481,7 +475,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||
def get_noop_manager(cls: type[BRM]) -> BRM:
|
||||
"""Return a manager that doesn't perform any operations.
|
||||
|
||||
Returns:
|
||||
@@ -824,7 +818,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running.
|
||||
|
||||
Args:
|
||||
@@ -929,7 +923,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
|
||||
@shielded
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when a chain ends running.
|
||||
|
||||
@@ -1248,11 +1242,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
) -> list[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@@ -1299,11 +1293,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
) -> list[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@@ -1354,8 +1348,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
inputs: Union[dict[str, Any], Any],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForChainRun:
|
||||
@@ -1398,11 +1392,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
input_str: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForToolRun:
|
||||
"""Run when tool starts running.
|
||||
@@ -1453,7 +1447,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -1541,10 +1535,10 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> CallbackManager:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@@ -1583,8 +1577,8 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
parent_run_manager: CallbackManagerForChainRun,
|
||||
@@ -1681,7 +1675,7 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
manager.add_handler(handler, inherit=True)
|
||||
return manager
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
@@ -1716,11 +1710,11 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
) -> list[AsyncCallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@@ -1779,11 +1773,11 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
) -> list[AsyncCallbackManagerForLLMRun]:
|
||||
"""Async run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@@ -1840,8 +1834,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
inputs: Union[dict[str, Any], Any],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForChainRun:
|
||||
@@ -1886,7 +1880,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
input_str: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -1975,7 +1969,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -2027,10 +2021,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Configure the async callback manager.
|
||||
|
||||
@@ -2069,8 +2063,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
parent_run_manager: AsyncCallbackManagerForChainRun,
|
||||
@@ -2169,7 +2163,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
return manager
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
@@ -2202,14 +2196,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
|
||||
|
||||
|
||||
def _configure(
|
||||
callback_manager_cls: Type[T],
|
||||
callback_manager_cls: type[T],
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> T:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@@ -2354,7 +2348,7 @@ def _configure(
|
||||
and handler_class is not None
|
||||
)
|
||||
if var.get() is not None or create_one:
|
||||
var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
|
||||
var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
|
||||
if handler_class is None:
|
||||
if not any(
|
||||
handler is var_handler # direct pointer comparison
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.utils import print_text
|
||||
@@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
self.color = color
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain.
|
||||
|
||||
@@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
|
||||
@@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming. Only works with LLMs that support streaming."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
@@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
@@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when a chain starts running.
|
||||
|
||||
@@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when a chain ends running.
|
||||
|
||||
Args:
|
||||
@@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
self, serialized: dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when the tool starts running.
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -87,7 +88,7 @@ class BaseChatMessageHistory(ABC):
|
||||
f.write("[]")
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
messages: list[BaseMessage]
|
||||
"""A property or attribute that returns a list of messages.
|
||||
|
||||
In general, getting the messages may involve IO to the underlying
|
||||
@@ -95,7 +96,7 @@ class BaseChatMessageHistory(ABC):
|
||||
latency.
|
||||
"""
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
async def aget_messages(self) -> list[BaseMessage]:
|
||||
"""Async version of getting messages.
|
||||
|
||||
Can over-ride this method to provide an efficient async implementation.
|
||||
@@ -204,10 +205,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
Stores messages in a memory list.
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage] = Field(default_factory=list)
|
||||
messages: list[BaseMessage] = Field(default_factory=list)
|
||||
"""A list of messages stored in memory."""
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
async def aget_messages(self) -> list[BaseMessage]:
|
||||
"""Async version of getting messages.
|
||||
|
||||
Can over-ride this method to provide an efficient async implementation.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
|
||||
@@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
|
||||
An iterator of chat sessions.
|
||||
"""
|
||||
|
||||
def load(self) -> List[ChatSession]:
|
||||
def load(self) -> list[ChatSession]:
|
||||
"""Eagerly load the chat sessions into memory.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""**Chat Sessions** are a collection of messages and function calls."""
|
||||
|
||||
from typing import Sequence, TypedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
@@ -25,17 +26,17 @@ class BaseLoader(ABC): # noqa: B024
|
||||
|
||||
# Sub-classes should not implement this method directly. Instead, they
|
||||
# should implement the lazy load method.
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
async def aload(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Do not override this method. It should be considered to be deprecated!
|
||||
@@ -108,7 +109,7 @@ class BaseBlobParser(ABC):
|
||||
Generator of documents
|
||||
"""
|
||||
|
||||
def parse(self, blob: Blob) -> List[Document]:
|
||||
def parse(self, blob: Blob) -> list[Document]:
|
||||
"""Eagerly parse the blob into a document or documents.
|
||||
|
||||
This is a convenience method for interactive development environment.
|
||||
|
||||
@@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
# Re-export Blob and PathLike for backwards compatibility
|
||||
from langchain_core.documents.base import Blob as Blob
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Callable, Iterator, Optional, Sequence, Union
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from langsmith import Client as LangSmithClient
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
@@ -138,7 +139,7 @@ class Blob(BaseMedia):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
|
||||
def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
@@ -285,7 +286,7 @@ class Document(BaseMedia):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "document"]
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""**Embeddings** interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@@ -35,7 +34,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs.
|
||||
|
||||
Args:
|
||||
@@ -46,7 +45,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text.
|
||||
|
||||
Args:
|
||||
@@ -56,7 +55,7 @@ class Embeddings(ABC):
|
||||
Embedding.
|
||||
"""
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Asynchronous Embed search docs.
|
||||
|
||||
Args:
|
||||
@@ -67,7 +66,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
return await run_in_executor(None, self.embed_documents, texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Asynchronous Embed query text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
# Please do not add additional fake embedding model implementations here.
|
||||
import hashlib
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
||||
size: int
|
||||
"""The size of the embedding vector."""
|
||||
|
||||
def _get_embedding(self) -> List[float]:
|
||||
def _get_embedding(self) -> list[float]:
|
||||
import numpy as np # type: ignore[import-not-found, import-untyped]
|
||||
|
||||
return list(np.random.normal(size=self.size))
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self._get_embedding() for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._get_embedding()
|
||||
|
||||
|
||||
@@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
size: int
|
||||
"""The size of the embedding vector."""
|
||||
|
||||
def _get_embedding(self, seed: int) -> List[float]:
|
||||
def _get_embedding(self, seed: int) -> list[float]:
|
||||
import numpy as np # type: ignore[import-not-found, import-untyped]
|
||||
|
||||
# set the seed for the random generator
|
||||
@@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
"""Get a seed for the random generator, using the hash of the text."""
|
||||
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._get_embedding(seed=self._get_seed(text))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Interface for selecting examples to include in prompts."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
@@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
|
||||
"""Interface for selecting examples to include in prompts."""
|
||||
|
||||
@abstractmethod
|
||||
def add_example(self, example: Dict[str, str]) -> Any:
|
||||
def add_example(self, example: dict[str, str]) -> Any:
|
||||
"""Add new example to store.
|
||||
|
||||
Args:
|
||||
example: A dictionary with keys as input variables
|
||||
and values as their values."""
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> Any:
|
||||
async def aadd_example(self, example: dict[str, str]) -> Any:
|
||||
"""Async add new example to store.
|
||||
|
||||
Args:
|
||||
@@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
|
||||
return await run_in_executor(None, self.add_example, example)
|
||||
|
||||
@abstractmethod
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select which examples to use based on the inputs.
|
||||
|
||||
Args:
|
||||
input_variables: A dictionary with keys as input variables
|
||||
and values as their values."""
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Async select which examples to use based on the inputs.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Select examples based on length."""
|
||||
|
||||
import re
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
@@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
|
||||
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Select examples based on length."""
|
||||
|
||||
examples: List[dict]
|
||||
examples: list[dict]
|
||||
"""A list of the examples that the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
@@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
max_length: int = 2048
|
||||
"""Max length for the prompt, beyond which examples are cut."""
|
||||
|
||||
example_text_lengths: List[int] = Field(default_factory=list) # :meta private:
|
||||
example_text_lengths: list[int] = Field(default_factory=list) # :meta private:
|
||||
"""Length of each example."""
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> None:
|
||||
def add_example(self, example: dict[str, str]) -> None:
|
||||
"""Add new example to list.
|
||||
|
||||
Args:
|
||||
@@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
string_example = self.example_prompt.format(**example)
|
||||
self.example_text_lengths.append(self.get_text_length(string_example))
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> None:
|
||||
async def aadd_example(self, example: dict[str, str]) -> None:
|
||||
"""Async add new example to list.
|
||||
|
||||
Args:
|
||||
@@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
|
||||
return self
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select which examples to use based on the input lengths.
|
||||
|
||||
Args:
|
||||
@@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
i += 1
|
||||
return examples
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Async select which examples to use based on the input lengths.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
def sorted_values(values: Dict[str, str]) -> List[Any]:
|
||||
def sorted_values(values: dict[str, str]) -> list[Any]:
|
||||
"""Return a list of values in dict sorted by key.
|
||||
|
||||
Args:
|
||||
@@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
"""VectorStore that contains information about examples."""
|
||||
k: int = 4
|
||||
"""Number of examples to select."""
|
||||
example_keys: Optional[List[str]] = None
|
||||
example_keys: Optional[list[str]] = None
|
||||
"""Optional keys to filter examples to."""
|
||||
input_keys: Optional[List[str]] = None
|
||||
input_keys: Optional[list[str]] = None
|
||||
"""Optional keys to filter input to. If provided, the search is based on
|
||||
the input variables instead of all variables."""
|
||||
vectorstore_kwargs: Optional[Dict[str, Any]] = None
|
||||
vectorstore_kwargs: Optional[dict[str, Any]] = None
|
||||
"""Extra arguments passed to similarity_search function of the vectorstore."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
|
||||
@staticmethod
|
||||
def _example_to_text(
|
||||
example: Dict[str, str], input_keys: Optional[List[str]]
|
||||
example: dict[str, str], input_keys: Optional[list[str]]
|
||||
) -> str:
|
||||
if input_keys:
|
||||
return " ".join(sorted_values({key: example[key] for key in input_keys}))
|
||||
else:
|
||||
return " ".join(sorted_values(example))
|
||||
|
||||
def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
|
||||
def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
|
||||
# Get the examples from the metadata.
|
||||
# This assumes that examples are stored in metadata.
|
||||
examples = [dict(e.metadata) for e in documents]
|
||||
@@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
|
||||
return examples
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> str:
|
||||
def add_example(self, example: dict[str, str]) -> str:
|
||||
"""Add a new example to vectorstore.
|
||||
|
||||
Args:
|
||||
@@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
)
|
||||
return ids[0]
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> str:
|
||||
async def aadd_example(self, example: dict[str, str]) -> str:
|
||||
"""Async add new example to vectorstore.
|
||||
|
||||
Args:
|
||||
@@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
"""Select examples based on semantic similarity."""
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select examples based on semantic similarity.
|
||||
|
||||
Args:
|
||||
@@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
)
|
||||
return self._documents_to_examples(example_docs)
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Asynchronously select examples based on semantic similarity.
|
||||
|
||||
Args:
|
||||
@@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
@@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
async def afrom_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
@@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
fetch_k: int = 20
|
||||
"""Number of examples to fetch to rerank."""
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select examples based on Max Marginal Relevance.
|
||||
|
||||
Args:
|
||||
@@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
)
|
||||
return self._documents_to_examples(example_docs)
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Asynchronously select examples based on Max Marginal Relevance.
|
||||
|
||||
Args:
|
||||
@@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
fetch_k: int = 20,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> MaxMarginalRelevanceExampleSelector:
|
||||
@@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
async def afrom_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
*,
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
fetch_k: int = 20,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> MaxMarginalRelevanceExampleSelector:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class LangChainException(Exception):
|
||||
class LangChainException(Exception): # noqa: N818
|
||||
"""General LangChain exception."""
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class TracerException(LangChainException):
|
||||
"""Base class for exceptions in tracers module."""
|
||||
|
||||
|
||||
class OutputParserException(ValueError, LangChainException):
|
||||
class OutputParserException(ValueError, LangChainException): # noqa: N818
|
||||
"""Exception that output parsers should raise to signify a parsing error.
|
||||
|
||||
This exists to differentiate parsing errors from other code or execution errors
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from langchain_core.graph_vectorstores.base import (
|
||||
GraphVectorStore,
|
||||
GraphVectorStoreRetriever,
|
||||
Node,
|
||||
)
|
||||
from langchain_core.graph_vectorstores.links import (
|
||||
Link,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GraphVectorStore",
|
||||
"GraphVectorStoreRetriever",
|
||||
"Node",
|
||||
"Link",
|
||||
]
|
||||
@@ -1,712 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
|
||||
|
||||
def _has_next(iterator: Iterator) -> bool:
|
||||
"""Checks if the iterator has more elements.
|
||||
Warning: consumes an element from the iterator"""
|
||||
sentinel = object()
|
||||
return next(iterator, sentinel) is not sentinel
|
||||
|
||||
|
||||
@beta()
|
||||
class Node(Serializable):
|
||||
"""Node in the GraphVectorStore.
|
||||
|
||||
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
|
||||
|
||||
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
|
||||
would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[
|
||||
Node(
|
||||
id="a",
|
||||
text="some text a",
|
||||
links= [
|
||||
Link(kind="hyperlink", tag="https://some-url", direction="incoming")
|
||||
],
|
||||
),
|
||||
Node(
|
||||
id="b",
|
||||
text="some text b",
|
||||
links= [
|
||||
Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
|
||||
],
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
"""Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
|
||||
text: str
|
||||
"""Text contained by the node."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Metadata for the node."""
|
||||
links: List[Link] = Field(default_factory=list)
|
||||
"""Links associated with the node."""
|
||||
|
||||
|
||||
def _texts_to_nodes(
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[Iterable[dict]],
|
||||
ids: Optional[Iterable[str]],
|
||||
) -> Iterator[Node]:
|
||||
metadatas_it = iter(metadatas) if metadatas else None
|
||||
ids_it = iter(ids) if ids else None
|
||||
for text in texts:
|
||||
try:
|
||||
_metadata = next(metadatas_it).copy() if metadatas_it else {}
|
||||
except StopIteration as e:
|
||||
raise ValueError("texts iterable longer than metadatas") from e
|
||||
try:
|
||||
_id = next(ids_it) if ids_it else None
|
||||
except StopIteration as e:
|
||||
raise ValueError("texts iterable longer than ids") from e
|
||||
|
||||
links = _metadata.pop(METADATA_LINKS_KEY, [])
|
||||
if not isinstance(links, list):
|
||||
links = list(links)
|
||||
yield Node(
|
||||
id=_id,
|
||||
metadata=_metadata,
|
||||
text=text,
|
||||
links=links,
|
||||
)
|
||||
if ids_it and _has_next(ids_it):
|
||||
raise ValueError("ids iterable longer than texts")
|
||||
if metadatas_it and _has_next(metadatas_it):
|
||||
raise ValueError("metadatas iterable longer than texts")
|
||||
|
||||
|
||||
def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
|
||||
for doc in documents:
|
||||
metadata = doc.metadata.copy()
|
||||
links = metadata.pop(METADATA_LINKS_KEY, [])
|
||||
if not isinstance(links, list):
|
||||
links = list(links)
|
||||
yield Node(
|
||||
id=doc.id,
|
||||
metadata=metadata,
|
||||
text=doc.page_content,
|
||||
links=links,
|
||||
)
|
||||
|
||||
|
||||
@beta()
|
||||
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
|
||||
"""Convert nodes to documents.
|
||||
|
||||
Args:
|
||||
nodes: The nodes to convert to documents.
|
||||
Returns:
|
||||
The documents generated from the nodes.
|
||||
"""
|
||||
for node in nodes:
|
||||
metadata = node.metadata.copy()
|
||||
metadata[METADATA_LINKS_KEY] = [
|
||||
# Convert the core `Link` (from the node) back to the local `Link`.
|
||||
Link(kind=link.kind, direction=link.direction, tag=link.tag)
|
||||
for link in node.links
|
||||
]
|
||||
|
||||
yield Document(
|
||||
id=node.id,
|
||||
page_content=node.text,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
|
||||
class GraphVectorStore(VectorStore):
|
||||
"""A hybrid vector-and-graph graph store.
|
||||
|
||||
Document chunks support vector-similarity search as well as edges linking
|
||||
chunks based on structural and semantic properties.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_nodes(
|
||||
self,
|
||||
nodes: Iterable[Node],
|
||||
**kwargs: Any,
|
||||
) -> Iterable[str]:
|
||||
"""Add nodes to the graph store.
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
"""
|
||||
|
||||
async def aadd_nodes(
|
||||
self,
|
||||
nodes: Iterable[Node],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[str]:
|
||||
"""Add nodes to the graph store.
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
"""
|
||||
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_texts(
|
||||
ids=["a", "b"],
|
||||
texts=["some text a", "some text b"],
|
||||
metadatas=[
|
||||
{
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
{
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return list(self.add_nodes(nodes, **kwargs))
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
await store.aadd_texts(
|
||||
ids=["a", "b"],
|
||||
texts=["some text a", "some text b"],
|
||||
metadatas=[
|
||||
{
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
{
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="https://some-url")
|
||||
]
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_documents(
|
||||
[
|
||||
Document(
|
||||
id="a",
|
||||
page_content="some text a",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
Document(
|
||||
id="b",
|
||||
page_content="some text b",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
|
||||
|
||||
Returns:
|
||||
List of IDs of the added texts.
|
||||
"""
|
||||
nodes = _documents_to_nodes(documents)
|
||||
return list(self.add_nodes(nodes, **kwargs))
|
||||
|
||||
async def aadd_documents(
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
|
||||
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
|
||||
function call would look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store.add_documents(
|
||||
[
|
||||
Document(
|
||||
id="a",
|
||||
page_content="some text a",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.incoming(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
Document(
|
||||
id="b",
|
||||
page_content="some text b",
|
||||
metadata={
|
||||
"links": [
|
||||
Link.outgoing(kind="hyperlink", tag="http://some-url")
|
||||
]
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
|
||||
|
||||
Returns:
|
||||
List of IDs of the added texts.
|
||||
"""
|
||||
nodes = _documents_to_nodes(documents)
|
||||
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
|
||||
|
||||
@abstractmethod
|
||||
def traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
|
||||
First, `k` nodes are retrieved using a search for each `query` string.
|
||||
Then, additional nodes are discovered up to the given `depth` from those
|
||||
starting nodes.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
"""
|
||||
|
||||
async def atraversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
|
||||
First, `k` nodes are retrieved using a search for each `query` string.
|
||||
Then, additional nodes are discovered up to the given `depth` from those
|
||||
starting nodes.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None, self.traversal_search, query, k=k, depth=depth, **kwargs
|
||||
)
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
||||
@abstractmethod
|
||||
def mmr_traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
|
||||
This strategy first retrieves the top `fetch_k` results by similarity to
|
||||
the question. It then selects the top `k` results based on
|
||||
maximum-marginal relevance using the given `lambda_mult`.
|
||||
|
||||
At each step, it considers the (remaining) documents from `fetch_k` as
|
||||
well as any documents connected by edges to a selected document
|
||||
retrieved based on similarity (a "root").
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
adjacent_k: Number of adjacent Documents to fetch.
|
||||
Defaults to 10.
|
||||
depth: Maximum depth of a node (number of edges) from a node
|
||||
retrieved via similarity. Defaults to 2.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding to maximum
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
"""
|
||||
|
||||
async def ammr_traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
|
||||
This strategy first retrieves the top `fetch_k` results by similarity to
|
||||
the question. It then selects the top `k` results based on
|
||||
maximum-marginal relevance using the given `lambda_mult`.
|
||||
|
||||
At each step, it considers the (remaining) documents from `fetch_k` as
|
||||
well as any documents connected by edges to a selected document
|
||||
retrieved based on similarity (a "root").
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
adjacent_k: Number of adjacent Documents to fetch.
|
||||
Defaults to 10.
|
||||
depth: Maximum depth of a node (number of edges) from a node
|
||||
retrieved via similarity. Defaults to 2.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding to maximum
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None,
|
||||
self.mmr_traversal_search,
|
||||
query,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
adjacent_k=adjacent_k,
|
||||
depth=depth,
|
||||
lambda_mult=lambda_mult,
|
||||
score_threshold=score_threshold,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
doc = await run_in_executor(None, next, iterator, done)
|
||||
if doc is done:
|
||||
break
|
||||
yield doc # type: ignore[misc]
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return list(self.traversal_search(query, k=k, depth=0))
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return list(
|
||||
self.mmr_traversal_search(
|
||||
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
|
||||
)
|
||||
)
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
|
||||
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
|
||||
if search_type == "similarity":
|
||||
return self.similarity_search(query, **kwargs)
|
||||
elif search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = self.similarity_search_with_relevance_scores(
|
||||
query, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_similarities]
|
||||
elif search_type == "mmr":
|
||||
return self.max_marginal_relevance_search(query, **kwargs)
|
||||
elif search_type == "traversal":
|
||||
return list(self.traversal_search(query, **kwargs))
|
||||
elif search_type == "mmr_traversal":
|
||||
return list(self.mmr_traversal_search(query, **kwargs))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"search_type of {search_type} not allowed. Expected "
|
||||
"search_type to be 'similarity', 'similarity_score_threshold', "
|
||||
"'mmr' or 'traversal'."
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, query: str, search_type: str, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
if search_type == "similarity":
|
||||
return await self.asimilarity_search(query, **kwargs)
|
||||
elif search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
|
||||
query, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_similarities]
|
||||
elif search_type == "mmr":
|
||||
return await self.amax_marginal_relevance_search(query, **kwargs)
|
||||
elif search_type == "traversal":
|
||||
return [doc async for doc in self.atraversal_search(query, **kwargs)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"search_type of {search_type} not allowed. Expected "
|
||||
"search_type to be 'similarity', 'similarity_score_threshold', "
|
||||
"'mmr' or 'traversal'."
|
||||
)
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
|
||||
"""Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to pass to the search function.
|
||||
Can include:
|
||||
|
||||
- search_type (Optional[str]): Defines the type of search that
|
||||
the Retriever should perform.
|
||||
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
|
||||
``similarity_score_threshold``.
|
||||
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
|
||||
search function. Can include things like:
|
||||
|
||||
- k(int): Amount of documents to return (Default: 4).
|
||||
- depth(int): The maximum depth of edges to traverse (Default: 1).
|
||||
- score_threshold(float): Minimum relevance threshold
|
||||
for similarity_score_threshold.
|
||||
- fetch_k(int): Amount of documents to pass to MMR algorithm
|
||||
(Default: 20).
|
||||
- lambda_mult(float): Diversity of results returned by MMR;
|
||||
1 for minimum diversity and 0 for maximum. (Default: 0.5).
|
||||
Returns:
|
||||
Retriever for this GraphVectorStore.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Retrieve documents traversing edges
|
||||
docsearch.as_retriever(
|
||||
search_type="traversal",
|
||||
search_kwargs={'k': 6, 'depth': 3}
|
||||
)
|
||||
|
||||
# Retrieve more documents with higher diversity
|
||||
# Useful if your dataset has many similar documents
|
||||
docsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={'k': 6, 'lambda_mult': 0.25}
|
||||
)
|
||||
|
||||
# Fetch more documents for the MMR algorithm to consider
|
||||
# But only return the top 5
|
||||
docsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={'k': 5, 'fetch_k': 50}
|
||||
)
|
||||
|
||||
# Only retrieve documents that have a relevance score
|
||||
# Above a certain threshold
|
||||
docsearch.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={'score_threshold': 0.8}
|
||||
)
|
||||
|
||||
# Only get the single most similar document from the dataset
|
||||
docsearch.as_retriever(search_kwargs={'k': 1})
|
||||
|
||||
"""
|
||||
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
|
||||
|
||||
class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
"""Retriever class for GraphVectorStore."""
|
||||
|
||||
vectorstore: GraphVectorStore
|
||||
"""GraphVectorStore to use for retrieval."""
|
||||
search_type: str = "traversal"
|
||||
"""Type of search to perform. Defaults to "traversal"."""
|
||||
allowed_search_types: ClassVar[Collection[str]] = (
|
||||
"similarity",
|
||||
"similarity_score_threshold",
|
||||
"mmr",
|
||||
"traversal",
|
||||
"mmr_traversal",
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return list(
|
||||
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
|
||||
)
|
||||
else:
|
||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vectorstore.atraversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vectorstore.ammr_traversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
else:
|
||||
return await super()._aget_relevant_documents(
|
||||
query, run_manager=run_manager
|
||||
)
|
||||
@@ -1,101 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Literal, Union
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@beta()
|
||||
@dataclass(frozen=True)
|
||||
class Link:
|
||||
"""A link to/from a tag of a given tag.
|
||||
|
||||
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
"""The kind of link. Allows different extractors to use the same tag name without
|
||||
creating collisions between extractors. For example “keyword” vs “url”."""
|
||||
direction: Literal["in", "out", "bidir"]
|
||||
"""The direction of the link."""
|
||||
tag: str
|
||||
"""The tag of the link."""
|
||||
|
||||
@staticmethod
|
||||
def incoming(kind: str, tag: str) -> "Link":
|
||||
"""Create an incoming link."""
|
||||
return Link(kind=kind, direction="in", tag=tag)
|
||||
|
||||
@staticmethod
|
||||
def outgoing(kind: str, tag: str) -> "Link":
|
||||
"""Create an outgoing link."""
|
||||
return Link(kind=kind, direction="out", tag=tag)
|
||||
|
||||
@staticmethod
|
||||
def bidir(kind: str, tag: str) -> "Link":
|
||||
"""Create a bidirectional link."""
|
||||
return Link(kind=kind, direction="bidir", tag=tag)
|
||||
|
||||
|
||||
METADATA_LINKS_KEY = "links"
|
||||
|
||||
|
||||
@beta()
|
||||
def get_links(doc: Document) -> List[Link]:
|
||||
"""Get the links from a document.
|
||||
|
||||
Args:
|
||||
doc: The document to get the link tags from.
|
||||
Returns:
|
||||
The set of link tags from the document.
|
||||
"""
|
||||
|
||||
links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
|
||||
if not isinstance(links, list):
|
||||
# Convert to a list and remember that.
|
||||
links = list(links)
|
||||
doc.metadata[METADATA_LINKS_KEY] = links
|
||||
return links
|
||||
|
||||
|
||||
@beta()
|
||||
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
|
||||
"""Add links to the given metadata.
|
||||
|
||||
Args:
|
||||
doc: The document to add the links to.
|
||||
*links: The links to add to the document.
|
||||
"""
|
||||
links_in_metadata = get_links(doc)
|
||||
for link in links:
|
||||
if isinstance(link, Iterable):
|
||||
links_in_metadata.extend(link)
|
||||
else:
|
||||
links_in_metadata.append(link)
|
||||
|
||||
|
||||
@beta()
|
||||
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
|
||||
"""Return a document with the given links added.
|
||||
|
||||
Args:
|
||||
doc: The document to add the links to.
|
||||
*links: The links to add to the document.
|
||||
|
||||
Returns:
|
||||
A document with a shallow-copy of the metadata with the links added.
|
||||
"""
|
||||
new_links = set(get_links(doc))
|
||||
for link in links:
|
||||
if isinstance(link, Iterable):
|
||||
new_links.update(link)
|
||||
else:
|
||||
new_links.add(link)
|
||||
|
||||
return Document(
|
||||
page_content=doc.page_content,
|
||||
metadata={
|
||||
**doc.metadata,
|
||||
METADATA_LINKS_KEY: list(new_links),
|
||||
},
|
||||
)
|
||||
@@ -5,20 +5,13 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -71,7 +64,7 @@ class _HashedDocument(Document):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
|
||||
def calculate_hashes(cls, values: dict[str, Any]) -> Any:
|
||||
"""Root validator to calculate content and metadata hash."""
|
||||
content = values.get("page_content", "")
|
||||
metadata = values.get("metadata", {})
|
||||
@@ -125,7 +118,7 @@ class _HashedDocument(Document):
|
||||
)
|
||||
|
||||
|
||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
@@ -135,9 +128,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
batch: List[T] = []
|
||||
batch: list[T] = []
|
||||
async for element in iterable:
|
||||
if len(batch) < size:
|
||||
batch.append(element)
|
||||
@@ -171,7 +164,7 @@ def _deduplicate_in_order(
|
||||
hashed_documents: Iterable[_HashedDocument],
|
||||
) -> Iterator[_HashedDocument]:
|
||||
"""Deduplicate a list of hashed documents while preserving order."""
|
||||
seen: Set[str] = set()
|
||||
seen: set[str] = set()
|
||||
|
||||
for hashed_doc in hashed_documents:
|
||||
if hashed_doc.hash_ not in seen:
|
||||
@@ -349,7 +342,7 @@ def index(
|
||||
uids = []
|
||||
docs_to_index = []
|
||||
uids_to_refresh = []
|
||||
seen_docs: Set[str] = set()
|
||||
seen_docs: set[str] = set()
|
||||
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
|
||||
if doc_exists:
|
||||
if force_update:
|
||||
@@ -589,7 +582,7 @@ async def aindex(
|
||||
uids: list[str] = []
|
||||
docs_to_index: list[Document] = []
|
||||
uids_to_refresh = []
|
||||
seen_docs: Set[str] = set()
|
||||
seen_docs: set[str] = set()
|
||||
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
|
||||
if doc_exists:
|
||||
if force_update:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user