Mirror of https://github.com/hwchase17/langchain.git, synced 2026-01-23 21:31:02 +00:00

Compare commits: wfh/bind_t ... erick/nbco (26 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 7bc4a5964c |  |
|  | 65091ebe50 |  |
|  | 4855964332 |  |
|  | e3132a7efc |  |
|  | 93c7eb4e6b |  |
|  | 7f42811e14 |  |
|  | 6bbf0797f7 |  |
|  | c7b5dbe8ec |  |
|  | 480821da59 |  |
|  | b802dd96f2 |  |
|  | 9d4100f915 |  |
|  | b9975fac89 |  |
|  | 9fb26a2a71 |  |
|  | 1cec0afc62 |  |
|  | ba897fc04c |  |
|  | 74211aa02e |  |
|  | c5c64aa863 |  |
|  | a86065c536 |  |
|  | ff206ae30d |  |
|  | 852b9ca494 |  |
|  | 79ae6c2a9e |  |
|  | bc3ec78a38 |  |
|  | 451c5d1d8c |  |
|  | 1e21a3f7ed |  |
|  | 3449fce273 |  |
|  | 7234335a9a |  |
.github/workflows/_integration_test.yml (vendored, new file, 57 lines)
@@ -0,0 +1,57 @@
name: Integration tests

on:
  workflow_dispatch:
    inputs:
      working-directory:
        required: true
        type: string

env:
  POETRY_VERSION: "1.6.1"

jobs:
  build:
    defaults:
      run:
        working-directory: ${{ inputs.working-directory }}
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.8"
          - "3.11"
    name: Python ${{ matrix.python-version }}
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ inputs.working-directory }}
          cache-key: core

      - name: Install dependencies
        shell: bash
        run: poetry install --with test,test_integration

      - name: Run integration tests
        shell: bash
        env:
          GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
        run: |
          make integration_tests

      - name: Ensure the tests did not create any additional files
        shell: bash
        run: |
          set -eu

          STATUS="$(git status)"
          echo "$STATUS"

          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
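The last step above fails the job whenever the tests dirty the working tree. A rough Python equivalent of that check, assuming it runs inside a git checkout (the subprocess call and the expected git message are taken from the workflow):

```python
import subprocess

# Same check as the workflow step: the tests must not create or modify files.
status = subprocess.run(
    ["git", "status"], capture_output=True, text=True, check=True
).stdout
print(status)
if "nothing to commit, working tree clean" not in status:
    raise SystemExit("integration tests left untracked or modified files behind")
```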
.github/workflows/_release.yml (vendored, 45 changed lines)
@@ -19,6 +19,7 @@ on:
        - libs/experimental
        - libs/community
        - libs/partners/google-genai
        - libs/partners/nvidia-aiplay

env:
  PYTHON_VERSION: "3.10"

@@ -88,6 +89,8 @@ jobs:
      - test-pypi-publish
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # We explicitly *don't* set up caching here. This ensures our tests are
      # maximally sensitive to catching breakage.
      #

@@ -100,12 +103,17 @@ jobs:
      # - Tests pass, because the dependency is present even though it wasn't specified.
      # - The package is published, and it breaks on the missing dependency when
      #   used in the real world.
      - uses: actions/setup-python@v4

      - name: Set up Python + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ env.PYTHON_VERSION }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ inputs.working-directory }}

      - name: Test published package
      - name: Import published package
        shell: bash
        working-directory: ${{ inputs.working-directory }}
        env:
          PKG_NAME: ${{ needs.build.outputs.pkg-name }}
          VERSION: ${{ needs.build.outputs.version }}

@@ -117,9 +125,8 @@ jobs:
          # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION
          # package because VERSION will not have been uploaded to regular PyPI yet.
          #
          # TODO: add more in-depth pre-publish tests after testing that importing works
        run: |
          pip install \
          poetry run pip install \
            --extra-index-url https://test.pypi.org/simple/ \
            "$PKG_NAME==$VERSION"

@@ -127,7 +134,35 @@ jobs:
          # since that's how Python imports packages with dashes in the name.
          IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)"

          python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))"
          poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))"

      - name: Import test dependencies
        run: poetry install --with test,test_integration
        working-directory: ${{ inputs.working-directory }}

      # Overwrite the local version of the package with the test PyPI version.
      - name: Import published package (again)
        working-directory: ${{ inputs.working-directory }}
        shell: bash
        env:
          PKG_NAME: ${{ needs.build.outputs.pkg-name }}
          VERSION: ${{ needs.build.outputs.version }}
        run: |
          poetry run pip install \
            --extra-index-url https://test.pypi.org/simple/ \
            "$PKG_NAME==$VERSION"

      - name: Run unit tests
        run: make tests
        working-directory: ${{ inputs.working-directory }}

      - name: Run integration tests
        if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
        env:
          GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
        run: make integration_tests
        working-directory: ${{ inputs.working-directory }}


  publish:
    needs:
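The release workflow's pre-publish smoke test boils down to converting the package name to its import name and importing it. A rough Python sketch of that check (the dash-to-underscore rule and the `print(dir(...))` call come from the workflow; the package name is just an example from this comparison and must already be installed):

```python
import importlib

pkg_name = "langchain-google-genai"  # example package name from this diff
import_name = pkg_name.replace("-", "_")  # same rule as the workflow's sed s/-/_/g
module = importlib.import_module(import_name)
print(dir(module))  # mirrors: python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))"
```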
@@ -277,7 +277,7 @@
   "source": [
    "%env CMAKE_ARGS=\"-DLLAMA_METAL=on\"\n",
    "%env FORCE_CMAKE=1\n",
    "%pip install -U llama-cpp-python --no-cache-dirclear`"
    "%pip install -U llama-cpp-python --no-cache-dirclear"
   ]
  },
  {
@@ -6,7 +6,7 @@
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: Google Generative AI\n",
    "sidebar_label: Google AI\n",
    "---"
   ]
  },

@@ -15,9 +15,9 @@
   "id": "bb9e152f-a1dc-45df-a50c-60a8d7ecdf69",
   "metadata": {},
   "source": [
    "# ChatGoogleGenerativeAI\n",
    "# Google AI chat models\n",
    "\n",
    "Access Google's `gemini` and `gemini-vision` models, as well as other generative models through `ChatGoogleGenerativeAI` class in the [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) integration package."
    "Access Google AI's `gemini` and `gemini-vision` models, as well as other generative models through `ChatGoogleGenerativeAI` class in the [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) integration package."
   ]
  },
  {

@@ -41,7 +41,7 @@
    "import os\n",
    "\n",
    "if \"GOOGLE_API_KEY\" not in os.environ:\n",
    " os.environ[\"GOOGLE_API_KEY\"] = getpass(\"Provide your Google API Key\")"
    " os.environ[\"GOOGLE_API_KEY\"] = getpass.getpass(\"Provide your Google API Key\")"
   ]
  },
  {

@@ -285,7 +285,7 @@
   "source": [
    "## Gemini Prompting FAQs\n",
    "\n",
    "As of the time this doc was written (2024/12/12), Gemini has some restrictions on the types and structure of prompts it accepts. Specifically:\n",
    "As of the time this doc was written (2023/12/12), Gemini has some restrictions on the types and structure of prompts it accepts. Specifically:\n",
    "\n",
    "1. When providing multimodal (image) inputs, you are restricted to at most 1 message of \"human\" (user) type. You cannot pass multiple messages (though the single human message may have multiple content entries)\n",
    "2. System messages are not accepted.\n",

@@ -295,6 +295,7 @@
  },
  {
   "cell_type": "markdown",
   "id": "92b5aca5",
   "metadata": {},
   "source": []
  }

@@ -315,7 +316,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
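The FAQ text above says multimodal input must arrive as a single human message that can carry several content entries. A hedged sketch of what that looks like with `ChatGoogleGenerativeAI`; the content-list format is assumed from the langchain-google-genai documentation, and the image URL is a placeholder:

```python
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")

# One human message, multiple content entries (text + image), per the FAQ above.
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image_url", "image_url": "https://example.com/picture.png"},
    ]
)
print(llm.invoke([message]).content)
```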
docs/docs/integrations/chat/nv_aiplay.ipynb (new file, 921 lines)
File diff suppressed because one or more lines are too long
@@ -3,15 +3,7 @@
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "5147e458-3b83-449e-9c2f-e7e1972e43fc",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Databricks\n",
|
||||
"\n",
|
||||
@@ -145,15 +137,7 @@
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "94f6540e-40cd-4d9b-95d3-33d36f061dcc",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Wrapping a serving endpoint: Custom model\n",
|
||||
"\n",
|
||||
@@ -173,18 +157,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "7496dc7a-8a1a-4ce6-9648-4f69ed25275b",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -211,18 +184,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "0c86d952-4236-4a5e-bdac-cf4e3ccf3a16",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -242,18 +204,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "5f2507a2-addd-431d-9da5-dc2ae33783f6",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -288,18 +239,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9b54f8ce-ffe5-4c47-a3f0-b4ebde524a6a",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -323,18 +263,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "50f172f5-ea1f-4ceb-8cf1-20289848de7b",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -370,13 +299,6 @@
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "8ea49319-a041-494d-afcd-87bcf00d5efb",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Wrapping a cluster driver proxy app\n",
|
||||
@@ -448,18 +370,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "e3330a01-e738-4170-a176-9954aff56442",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -483,18 +394,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "39c121cf-0e44-4e31-91db-37fcac459677",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -519,18 +419,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "3d3de599-82fd-45e4-8d8b-bacfc49dc9ce",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -554,18 +443,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {
|
||||
"byteLimit": 2048000,
|
||||
"rowLimit": 10000
|
||||
},
|
||||
"inputWidgets": {},
|
||||
"nuid": "853fae8e-8df4-41e6-9d45-7769f883fe80",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@@ -607,15 +485,6 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+notebook": {
|
||||
"dashboards": [],
|
||||
"language": "python",
|
||||
"notebookMetadata": {
|
||||
"pythonIndentUnit": 2
|
||||
},
|
||||
"notebookName": "databricks",
|
||||
"widgets": {}
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
||||
docs/docs/integrations/llms/google_ai.ipynb (new file, 287 lines)
@@ -0,0 +1,287 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7aZWXpbf0Eph",
|
||||
"metadata": {
|
||||
"id": "7aZWXpbf0Eph"
|
||||
},
|
||||
"source": [
|
||||
"# Google AI\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bead5ede-d9cc-44b9-b062-99c90a10cf40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A guide on using [Google Generative AI](https://developers.generativeai.google/) models with Langchain. Note: It's separate from Google Cloud Vertex AI [integration](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "H4AjsqTswBCE",
|
||||
"metadata": {
|
||||
"id": "H4AjsqTswBCE"
|
||||
},
|
||||
"source": [
|
||||
"## Setting up\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "EFHNUieMwJrl",
|
||||
"metadata": {
|
||||
"id": "EFHNUieMwJrl"
|
||||
},
|
||||
"source": [
|
||||
"To use Google Generative AI you must install the `langchain-google-genai` Python package and generate an API key. [Read more details](https://developers.generativeai.google/)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8Qzm6SqKwgak",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install langchain-google-genai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ONb7ZtOwjbo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_google_genai import GoogleGenerativeAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "X3pjCW0i22gm",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"api_key = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "GT50LgFP0j-w",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"**Pros of Python:**\n",
|
||||
"\n",
|
||||
"* **Easy to learn:** Python is a very easy-to-learn programming language, even for beginners. Its syntax is simple and straightforward, and there are a lot of resources available to help you get started.\n",
|
||||
"* **Versatile:** Python can be used for a wide variety of tasks, including web development, data science, and machine learning. It's also a good choice for beginners because it can be used for a variety of projects, so you can learn the basics and then move on to more complex tasks.\n",
|
||||
"* **High-level:** Python is a high-level programming language, which means that it's closer to human language than other programming languages. This makes it easier to read and understand, which can be a big advantage for beginners.\n",
|
||||
"* **Open-source:** Python is an open-source programming language, which means that it's free to use and there are a lot of resources available to help you learn it.\n",
|
||||
"* **Community:** Python has a large and active community of developers, which means that there are a lot of people who can help you if you get stuck.\n",
|
||||
"\n",
|
||||
"**Cons of Python:**\n",
|
||||
"\n",
|
||||
"* **Slow:** Python is a relatively slow programming language compared to some other languages, such as C++. This can be a disadvantage if you're working on computationally intensive tasks.\n",
|
||||
"* **Not as performant:** Python is not as performant as some other programming languages, such as C++ or Java. This can be a disadvantage if you're working on projects that require high performance.\n",
|
||||
"* **Dynamic typing:** Python is a dynamically typed programming language, which means that the type of a variable can change during runtime. This can be a disadvantage if you need to ensure that your code is type-safe.\n",
|
||||
"* **Unmanaged memory:** Python uses a garbage collection system to manage memory. This can be a disadvantage if you need to have more control over memory management.\n",
|
||||
"\n",
|
||||
"Overall, Python is a very good programming language for beginners. It's easy to learn, versatile, and has a large community of developers. However, it's important to be aware of its limitations, such as its slow performance and lack of performance.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = GoogleGenerativeAI(model=\"models/text-bison-001\", google_api_key=api_key)\n",
|
||||
"print(\n",
|
||||
" llm.invoke(\n",
|
||||
" \"What are some of the pros and cons of Python as a programming language?\"\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "TSGdxkJtwl8-",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"**Pros:**\n",
|
||||
"\n",
|
||||
"* **Simplicity and Readability:** Python is known for its simple and easy-to-read syntax, which makes it accessible to beginners and reduces the chance of errors. It uses indentation to define blocks of code, making the code structure clear and visually appealing.\n",
|
||||
"\n",
|
||||
"* **Versatility:** Python is a general-purpose language, meaning it can be used for a wide range of tasks, including web development, data science, machine learning, and desktop applications. This versatility makes it a popular choice for various projects and industries.\n",
|
||||
"\n",
|
||||
"* **Large Community:** Python has a vast and active community of developers, which contributes to its growth and popularity. This community provides extensive documentation, tutorials, and open-source libraries, making it easy for Python developers to find support and resources.\n",
|
||||
"\n",
|
||||
"* **Extensive Libraries:** Python offers a rich collection of libraries and frameworks for various tasks, such as data analysis (NumPy, Pandas), web development (Django, Flask), machine learning (Scikit-learn, TensorFlow), and many more. These libraries provide pre-built functions and modules, allowing developers to quickly and efficiently solve common problems.\n",
|
||||
"\n",
|
||||
"* **Cross-Platform Support:** Python is cross-platform, meaning it can run on various operating systems, including Windows, macOS, and Linux. This allows developers to write code that can be easily shared and used across different platforms.\n",
|
||||
"\n",
|
||||
"**Cons:**\n",
|
||||
"\n",
|
||||
"* **Speed and Performance:** Python is generally slower than compiled languages like C++ or Java due to its interpreted nature. This can be a disadvantage for performance-intensive tasks, such as real-time systems or heavy numerical computations.\n",
|
||||
"\n",
|
||||
"* **Memory Usage:** Python programs tend to consume more memory compared to compiled languages. This is because Python uses a dynamic memory allocation system, which can lead to memory fragmentation and higher memory usage.\n",
|
||||
"\n",
|
||||
"* **Lack of Static Typing:** Python is a dynamically typed language, which means that data types are not explicitly defined for variables. This can make it challenging to detect type errors during development, which can lead to unexpected behavior or errors at runtime.\n",
|
||||
"\n",
|
||||
"* **GIL (Global Interpreter Lock):** Python uses a global interpreter lock (GIL) to ensure that only one thread can execute Python bytecode at a time. This can limit the scalability and parallelism of Python programs, especially in multi-threaded or multiprocessing scenarios.\n",
|
||||
"\n",
|
||||
"* **Package Management:** While Python has a vast ecosystem of libraries and packages, managing dependencies and package versions can be challenging. The Python Package Index (PyPI) is the official repository for Python packages, but it can be difficult to ensure compatibility and avoid conflicts between different versions of packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = GoogleGenerativeAI(model=\"gemini-pro\", google_api_key=api_key)\n",
|
||||
"print(\n",
|
||||
" llm.invoke(\n",
|
||||
" \"What are some of the pros and cons of Python as a programming language?\"\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "OQ_SlL0K1Cw6",
|
||||
"metadata": {
|
||||
"id": "OQ_SlL0K1Cw6"
|
||||
},
|
||||
"source": [
|
||||
"## Using in a chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "Nwc9P5_ry79W",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "35856bf2-aa5e-436b-977a-9e5725b1a595",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"4\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"prompt = PromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"\n",
|
||||
"question = \"How much is 2+2?\"\n",
|
||||
"print(chain.invoke({\"question\": question}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ueAin0xQzCqq",
|
||||
"metadata": {
|
||||
"id": "ueAin0xQzCqq"
|
||||
},
|
||||
"source": [
|
||||
"## Streaming calls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "WftL7x0A0hlF",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In winter's embrace, a silent ballet,\n",
|
||||
"Snowflakes descend, a celestial display.\n",
|
||||
"Whispering secrets, they softly fall,\n",
|
||||
"A blanket of white, covering all.\n",
|
||||
"\n",
|
||||
"With gentle grace, they paint the land,\n",
|
||||
"Transforming the world into a winter wonderland.\n",
|
||||
"Trees stand adorned in icy splendor,\n",
|
||||
"A glistening spectacle, a sight to render.\n",
|
||||
"\n",
|
||||
"Snowflakes twirl, like dancers on a stage,\n",
|
||||
"Creating a symphony, a winter montage.\n",
|
||||
"Their silent whispers, a sweet serenade,\n",
|
||||
"As they dance and twirl, a snowy cascade.\n",
|
||||
"\n",
|
||||
"In the hush of dawn, a frosty morn,\n",
|
||||
"Snow sparkles bright, like diamonds reborn.\n",
|
||||
"Each flake unique, in its own design,\n",
|
||||
"A masterpiece crafted by the divine.\n",
|
||||
"\n",
|
||||
"So let us revel in this wintry bliss,\n",
|
||||
"As snowflakes fall, with a gentle kiss.\n",
|
||||
"For in their embrace, we find a peace profound,\n",
|
||||
"A frozen world, with magic all around."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Tell me a short poem about snow\"):\n",
|
||||
" sys.stdout.write(chunk)\n",
|
||||
" sys.stdout.flush()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aefe6df7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -4,9 +4,9 @@ All functionality related to [Google Cloud Platform](https://cloud.google.com/)

## Chat models

### ChatGoogleGenerativeAI
### Google AI

Access `Gemini` models such as `gemini-pro` and `gemini-pro-vision` through the `ChatGoogleGenerativeAI` class.
Access GoogleAI `Gemini` models such as `gemini-pro` and `gemini-pro-vision` through the `ChatGoogleGenerativeAI` class.

```bash
pip install -U langchain-google-genai
```
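A short usage sketch of the class named in the hunk above, mirroring the chat notebook touched earlier in this comparison; it assumes `GOOGLE_API_KEY` is set in the environment:

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# Requires GOOGLE_API_KEY in the environment, as in the notebook above.
llm = ChatGoogleGenerativeAI(model="gemini-pro")
print(llm.invoke("Write a haiku about LangChain.").content)
```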
docs/docs/integrations/providers/nv_aiplay.mdx (new file, 39 lines)
@@ -0,0 +1,39 @@
# NVIDIA AI Playground

> [NVIDIA AI Playground](https://www.nvidia.com/en-us/research/ai-playground/) gives users easy access to hosted endpoints for generative AI models like Llama-2, Mistral, etc. This example demonstrates how to use LangChain to interact with supported AI Playground models.

These models are provided via the `langchain-nvidia-aiplay` package.

## Installation

```bash
pip install -U langchain-nvidia-aiplay
```

## Setup and Authentication

- Create a free account at [NVIDIA GPU Cloud](https://catalog.ngc.nvidia.com/).
- Navigate to `Catalog > AI Foundation Models > (Model with API endpoint)`.
- Select `API` and generate the key `NVIDIA_API_KEY`.

```bash
export NVIDIA_API_KEY=nvapi-XXXXXXXXXXXXXXXXXXXXXXXXXX
```

```python
from langchain_nvidia_aiplay import ChatNVAIPlay

llm = ChatNVAIPlay(model="mixtral_8x7b")
result = llm.invoke("Write a ballad about LangChain.")
print(result.content)
```

## Using NVIDIA AI Playground Models

A selection of NVIDIA AI Playground models is supported directly in LangChain with familiar APIs.

The supported models can be found [in NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/). In addition, a selection of models can be retrieved from `langchain.<llms/chat_models>.nv_aiplay`, which pulls in default model options based on their use cases.

**The following may be useful examples to help you get started:**
- **[`ChatNVAIPlay` Model](/docs/integrations/chat/nv_aiplay).**
- **[`NVAIPlayEmbedding` Model for RAG Workflows](/docs/integrations/text_embeddings/nv_aiplay).**
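Following the pointers above, a minimal sketch of pairing the chat model with a prompt in an LCEL chain; only `ChatNVAIPlay` and the `mixtral_8x7b` model name come from the page, the prompt itself is illustrative:

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_aiplay import ChatNVAIPlay

prompt = ChatPromptTemplate.from_template("Summarize in one sentence: {topic}")
llm = ChatNVAIPlay(model="mixtral_8x7b")  # needs NVIDIA_API_KEY, as described above

chain = prompt | llm
print(chain.invoke({"topic": "the NVIDIA AI Playground integration"}).content)
```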
docs/docs/integrations/text_embedding/google_generative_ai.ipynb (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "afab8b36-10bb-4795-bc98-75ab2d2081bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Google Generative AI Embeddings\n",
|
||||
"\n",
|
||||
"Connect to Google's generative AI embeddings service using the `GoogleGenerativeAIEmbeddings` class, found in the [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) package."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "63545b38-9d56-4312-8f61-8d4f1e7a3b1b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d2f6a3cd-379f-4dff-a449-d3a9f3196f2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -U langchain-google-genai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25f3f88e-164e-400d-b371-9fa488baba19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Credentials"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec89153f-8999-4aab-a21b-0bfba1cc3893",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"if \"GOOGLE_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"GOOGLE_API_KEY\"] = getpass(\"Provide your Google API key here\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f2437b22-e364-418a-8c13-490a026cb7b5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Usage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "eedc551e-a1f3-4fd8-8d65-4e0784c4441b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[0.05636945, 0.0048285457, -0.0762591, -0.023642512, 0.05329321]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_google_genai import GoogleGenerativeAIEmbeddings\n",
|
||||
"\n",
|
||||
"embeddings = GoogleGenerativeAIEmbeddings(model=\"models/embedding-001\")\n",
|
||||
"vector = embeddings.embed_query(\"hello, world!\")\n",
|
||||
"vector[:5]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b2bed60-e7bd-4e48-83d6-1c87001f98bd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Batch\n",
|
||||
"\n",
|
||||
"You can also embed multiple strings at once for a processing speedup:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "6ec53aba-404f-4778-acd9-5d6664e79ed2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(3, 768)"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vectors = embeddings.embed_documents(\n",
|
||||
" [\n",
|
||||
" \"Today is Monday\",\n",
|
||||
" \"Today is Tuesday\",\n",
|
||||
" \"Today is April Fools day\",\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"len(vectors), len(vectors[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1482486f-5617-498a-8a44-1974d3212dda",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Task type\n",
|
||||
"`GoogleGenerativeAIEmbeddings` optionally support a `task_type`, which currently must be one of:\n",
|
||||
"\n",
|
||||
"- task_type_unspecified\n",
|
||||
"- retrieval_query\n",
|
||||
"- retrieval_document\n",
|
||||
"- semantic_similarity\n",
|
||||
"- classification\n",
|
||||
"- clustering\n",
|
||||
"\n",
|
||||
"By default, we use `retrieval_document` in the `embed_documents` method and `retrieval_query` in the `embed_query` method. If you provide a task type, we will use that for all methods."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "a223bb25-2b1b-418e-a570-2f543083132e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install --quiet matplotlib scikit-learn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "f1f077db-8eb4-49f7-8866-471a8528dcdb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_embeddings = GoogleGenerativeAIEmbeddings(\n",
|
||||
" model=\"models/embedding-001\", task_type=\"retrieval_query\"\n",
|
||||
")\n",
|
||||
"doc_embeddings = GoogleGenerativeAIEmbeddings(\n",
|
||||
" model=\"models/embedding-001\", task_type=\"retrieval_document\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "79bd4a5e-75ba-413c-befa-86167c938caf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"All of these will be embedded with the 'retrieval_query' task set\n",
|
||||
"```python\n",
|
||||
"query_vecs = [query_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n",
|
||||
"```\n",
|
||||
"All of these will be embedded with the 'retrieval_document' task set\n",
|
||||
"```python\n",
|
||||
"doc_vecs = [doc_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9e1fae5e-0f84-4812-89f5-7d4d71affbc1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In retrieval, relative distance matters. In the image above, you can see the difference in similarity scores between the \"relevant doc\" and \"simil stronger delta between the similar query and relevant doc on the latter case."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
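The task-type section of the notebook above compares query-style and document-style embeddings by their similarity scores. A small sketch of that comparison using the classes shown in the notebook and scikit-learn, which the notebook installs; the example strings are made up:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from sklearn.metrics.pairwise import cosine_similarity

query_embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", task_type="retrieval_query"
)
doc_embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", task_type="retrieval_document"
)

query = "What day is it today?"  # illustrative query
doc = "Today is Monday"          # illustrative document

q_vec = query_embeddings.embed_query(query)
d_vec = doc_embeddings.embed_query(doc)  # embedded with the document task type
print(cosine_similarity([q_vec], [d_vec])[0][0])
```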
556
docs/docs/integrations/text_embedding/nv_aiplay.ipynb
Normal file
556
docs/docs/integrations/text_embedding/nv_aiplay.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -2,15 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "5a8c5767-adfe-4b9d-a665-a898756d7a6c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Databricks Vector Search\n",
|
||||
"\n",
|
||||
@@ -21,15 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "746cfacd-fb30-48fd-96a5-bbcc0d15ae49",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Install `databricks-vectorsearch` and related Python packages used in this notebook."
|
||||
]
|
||||
@@ -37,15 +21,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9258a3e7-e050-4390-9d3f-9adff1460dab",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install langchain-core databricks-vectorsearch openai tiktoken"
|
||||
@@ -53,15 +29,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "f4f09d6d-002d-4cb0-a664-0a83bd2a13da",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Use `OpenAIEmbeddings` for the embeddings."
|
||||
]
|
||||
@@ -69,15 +37,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "f11b902d-a772-45e0-bbd9-526218b717cc",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
@@ -88,15 +48,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "59b568f3-8db2-427e-9a4a-1df6fa7a1739",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Split documents and get embeddings."
|
||||
]
|
||||
@@ -104,15 +56,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "b28e1c7b-eae4-4be8-abbd-8433c7557dc2",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
@@ -130,15 +74,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "e8fcdda1-208a-45c9-816e-ff0d2c8f59d6",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup Databricks Vector Search client"
|
||||
]
|
||||
@@ -146,15 +82,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9b87fff1-99e5-4d9f-aba3-d21a7ccc498e",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from databricks.vector_search.client import VectorSearchClient\n",
|
||||
@@ -181,15 +109,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "81090f87-3efd-4c1e-9f58-8d6adba7553d",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Direct Vector Access Index\n",
|
||||
"Direct Vector Access Index supports direct read and write of embedding vectors and metadata through a REST API or an SDK. For this index, you manage embedding vectors and index updates yourself."
|
||||
@@ -198,15 +118,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "9389ec6b-5885-411f-a26e-1a4b03651f5c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vector_search_endpoint_name = \"vector_search_demo_endpoint\"\n",
|
||||
@@ -232,15 +144,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "047a14c9-2f06-4f74-883d-815b2c69786c",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.vectorstores import DatabricksVectorSearch\n",
|
||||
@@ -252,15 +156,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "951bd581-2ced-497f-9c70-4fda902fd3a1",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Add docs to the index"
|
||||
]
|
||||
@@ -268,15 +164,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "1e85f235-901f-4cf5-845f-5dbf4ce42078",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dvs.add_documents(docs)"
|
||||
@@ -284,15 +172,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "8bea6f0a-b305-455a-acba-99cc8c9350b5",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Similarity search"
|
||||
]
|
||||
@@ -300,15 +180,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "25c5a044-a61a-4929-9e65-a0f0462925df",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
@@ -318,15 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "46e3f41b-dac2-4bed-91cb-a3914c25d275",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Work with Delta Sync Index\n",
|
||||
"\n",
|
||||
@@ -336,15 +200,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+cell": {
|
||||
"cellMetadata": {},
|
||||
"inputWidgets": {},
|
||||
"nuid": "0c1f448e-77ca-41ce-887c-15948e866a0e",
|
||||
"showTitle": false,
|
||||
"title": ""
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dvs_delta_sync = DatabricksVectorSearch(\"catalog_name.schema_name.delta_sync_index\")\n",
|
||||
@@ -353,15 +209,6 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"application/vnd.databricks.v1+notebook": {
|
||||
"dashboards": [],
|
||||
"language": "python",
|
||||
"notebookMetadata": {
|
||||
"pythonIndentUnit": 2
|
||||
},
|
||||
"notebookName": "databricks_vector_search",
|
||||
"widgets": {}
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
|
||||
docs/scripts/ipynb_to_md.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# based on outerbounds/nbdoc

from nbdev.export import nbglob
from nbconvert import MarkdownExporter
from nbconvert.preprocessors import Preprocessor
from pathlib import Path
import re


class WriteTitle(Preprocessor):
    """Modify the code-fence with the filename upon %%writefile cell magic."""

    pattern = r"(^[\S\s]*%%writefile\s)(\S+)\n"

    def preprocess_cell(self, cell, resources, index):
        print("here")
        m = re.match(self.pattern, cell.source)
        if m:
            filename = m.group(2)
            ext = filename.split(".")[-1]
            cell.metadata.magics_language = f'{ext} title="{filename}"'
            cell.metadata.script = True
            cell.metadata.file_ext = ext
            cell.metadata.filename = filename
            cell.outputs = []
        return cell, resources


def get_exporter():
    # c = Config()
    # c.MarkdownExporter.preprocessors = [WriteTitle]
    exporter = MarkdownExporter(
        config={"MarkdownExporter": {"preprocessors": [WriteTitle]}}
    )
    return exporter


def process_file(fname: Path, force: bool = False) -> None:
    fname_rel = fname.relative_to(basedir)
    fname_out_ipynb = outdir / fname_rel
    fname_out = fname_out_ipynb.with_suffix(".md")

    if (
        force
        or not fname_out.exists()
        or fname.stat().st_mtime > fname_out.stat().st_mtime
    ):
        print(f"Converting {fname_rel} to markdown")
        exporter = get_exporter()
        output, _ = exporter.from_filename(fname)
        fname_out.write_text(output)
        print(fname_out)


if __name__ == "__main__":
    # parallel process
    basedir = Path(__file__).parent.parent / "docs"
    outdir = Path(__file__).parent.parent.parent.parent / "_dist" / "docs"
    files = nbglob(basedir, recursive=True)

    fname = files[0]
    process_file(fname, True)

    # for fname in files:
    #     process_file(fname)

    # print(fname_out)
    # for fname in files:
    #     fname_out = fname.with_suffix('.md')
@@ -1,3 +1,4 @@
-e ../libs/langchain
-e ../libs/community
-e ../libs/core
urllib3==1.26.18
@@ -89,7 +89,14 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    elif role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])
    elif role == "tool":
        return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict["content"],
            tool_call_id=_dict["tool_call_id"],
            additional_kwargs=additional_kwargs,
        )
    else:
        return ChatMessage(content=_dict["content"], role=role)
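A rough illustration of what the new branch above does: an OpenAI-style `tool` message dict becomes a `ToolMessage` whose optional `name` lands in `additional_kwargs`. The sample dict is invented; `ToolMessage` comes from `langchain_core`:

```python
from langchain_core.messages import ToolMessage

# Example tool-result dict in the OpenAI message format (values are made up).
_dict = {"role": "tool", "content": "42", "tool_call_id": "call_001", "name": "calculator"}

additional_kwargs = {}
if "name" in _dict:
    additional_kwargs["name"] = _dict["name"]

message = ToolMessage(
    content=_dict["content"],
    tool_call_id=_dict["tool_call_id"],
    additional_kwargs=additional_kwargs,
)
print(message)
```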
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms import BaseLLM
|
||||
@@ -13,7 +15,9 @@ from langchain_community.utilities.vertexai import create_retry_decorator
|
||||
|
||||
def completion_with_retry(
|
||||
llm: GooglePalm,
|
||||
*args: Any,
|
||||
prompt: LanguageModelInput,
|
||||
is_gemini: bool = False,
|
||||
stream: bool = False,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@@ -23,10 +27,23 @@ def completion_with_retry(
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return llm.client.generate_text(*args, **kwargs)
|
||||
def _completion_with_retry(
|
||||
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
||||
) -> Any:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if is_gemini:
|
||||
return llm.client.generate_content(
|
||||
contents=prompt, stream=stream, generation_config=generation_config
|
||||
)
|
||||
return llm.client.generate_text(prompt=prompt, **kwargs)
|
||||
|
||||
return _completion_with_retry(*args, **kwargs)
|
||||
return _completion_with_retry(
|
||||
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _is_gemini_model(model_name: str) -> bool:
|
||||
return "gemini" in model_name
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
@@ -42,11 +59,16 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
@deprecated("0.0.351", alternative="langchain_google_genai.GoogleGenerativeAI")
|
||||
class GooglePalm(BaseLLM, BaseModel):
|
||||
"""Google PaLM models."""
|
||||
"""
|
||||
DEPRECATED: Use `langchain_google_genai.GoogleGenerativeAI` instead.
|
||||
|
||||
Google PaLM models.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: Optional[str]
|
||||
google_api_key: Optional[SecretStr]
|
||||
model_name: str = "models/text-bison-001"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
@@ -67,6 +89,11 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
max_retries: int = 6
|
||||
"""The maximum number of retries to make when generating."""
|
||||
|
||||
@property
|
||||
def is_gemini(self) -> bool:
|
||||
"""Returns whether a model is belongs to a Gemini family or not."""
|
||||
return _is_gemini_model(self.model_name)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
@@ -86,18 +113,25 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
model_name = values["model_name"]
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
|
||||
if _is_gemini_model(model_name):
|
||||
values["client"] = genai.GenerativeModel(model_name=model_name)
|
||||
else:
|
||||
values["client"] = genai
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`."
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
@@ -119,30 +153,76 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
generations: List[List[Generation]] = []
|
||||
generation_config = {
|
||||
"stop_sequences": stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"candidate_count": self.n,
|
||||
}
|
||||
for prompt in prompts:
|
||||
completion = completion_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stop_sequences=stop,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
candidate_count=self.n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prompt_generations = []
|
||||
for candidate in completion.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
if self.is_gemini:
|
||||
res = completion_with_retry(
|
||||
self,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
candidates = [
|
||||
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
||||
]
|
||||
generations.append([Generation(text=c) for c in candidates])
|
||||
else:
|
||||
res = completion_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=False,
|
||||
run_manager=run_manager,
|
||||
**generation_config,
|
||||
)
|
||||
prompt_generations = []
|
||||
for candidate in res.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if stop:
|
||||
generation_config["stop_sequences"] = stop
|
||||
for stream_resp in completion_with_retry(
|
||||
self,
|
||||
prompt,
|
||||
stream=True,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
):
|
||||
chunk = GenerationChunk(text=stream_resp.text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
@@ -159,5 +239,7 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
if self.is_gemini:
|
||||
raise ValueError("Counting tokens is not yet supported!")
|
||||
result = self.client.count_text_tokens(model=self.model_name, prompt=text)
|
||||
return result["token_count"]
|
||||
|
||||
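Taken together, the GooglePalm changes above route `gemini-*` model names to `genai.GenerativeModel` and add a `_stream` method. A minimal usage sketch based on the integration tests later in this comparison; it assumes `GOOGLE_API_KEY` is set and the `google-generativeai` package is installed:

```python
from langchain_community.llms.google_palm import GooglePalm

# "gemini" in the model name makes is_gemini True, so generate_content is used.
llm = GooglePalm(model_name="gemini-pro", temperature=0)

for chunk in llm.stream("Please say foo:"):
    print(chunk, end="")
```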
@@ -60,7 +60,14 @@ class BaseModel(Base):
    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)


_classes: Any = None


def _get_embedding_collection_store() -> Any:
    global _classes
    if _classes is not None:
        return _classes

    from pgvector.sqlalchemy import Vector

    class CollectionStore(BaseModel):

@@ -126,7 +133,9 @@ def _get_embedding_collection_store() -> Any:
        # custom_id : any user defined id
        custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

    return EmbeddingStore, CollectionStore
    _classes = (EmbeddingStore, CollectionStore)

    return _classes


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
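The hunk above memoizes the dynamically built SQLAlchemy models in a module-level `_classes` tuple, so repeated calls reuse the same classes instead of redefining them (presumably to avoid declaring the same tables twice). A stripped-down sketch of the same pattern with placeholder classes:

```python
from typing import Optional, Tuple

_classes: Optional[Tuple[type, type]] = None  # filled in on first use


def _get_classes() -> Tuple[type, type]:
    """Build the heavyweight classes once, then return the cached pair."""
    global _classes
    if _classes is not None:
        return _classes

    class CollectionStore:  # placeholder for the real ORM model
        pass

    class EmbeddingStore:  # placeholder for the real ORM model
        pass

    _classes = (EmbeddingStore, CollectionStore)
    return _classes


assert _get_classes() is _get_classes()  # the same objects on every call
```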
@@ -1,4 +1,4 @@
"""Test Google PaLM Text API wrapper.
"""Test Google GenerativeAI API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
@@ -6,35 +6,68 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to

from pathlib import Path

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.llms.google_palm import GooglePalm
from langchain_community.llms.loading import load_llm

model_names = [None, "models/text-bison-001", "gemini-pro"]


def test_google_palm_call() -> None:
    """Test valid call to Google PaLM text API."""
    llm = GooglePalm(max_output_tokens=10)

@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
    """Test valid call to Google GenerativeAI text API."""
    if model_name:
        llm = GooglePalm(max_output_tokens=10, model_name=model_name)
    else:
        llm = GooglePalm(max_output_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
    assert llm._llm_type == "google_palm"
    assert llm.model_name == "models/text-bison-001"
    if model_name and "gemini" in model_name:
        assert llm.client.model_name == "models/gemini-pro"
    else:
        assert llm.model_name == "models/text-bison-001"


def test_google_palm_generate() -> None:
    llm = GooglePalm(temperature=0.3, n=2)
@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
    n = 1 if model_name == "gemini-pro" else 2
    if model_name:
        llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
    else:
        llm = GooglePalm(temperature=0.3, n=n)
    output = llm.generate(["Say foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2
    assert len(output.generations[0]) == n


def test_google_palm_get_num_tokens() -> None:
def test_google_generativeai_get_num_tokens() -> None:
    llm = GooglePalm()
    output = llm.get_num_tokens("How are you?")
    assert output == 4


async def test_google_generativeai_agenerate() -> None:
    llm = GooglePalm(temperature=0, model_name="gemini-pro")
    output = await llm.agenerate(["Please say foo:"])
    assert isinstance(output, LLMResult)


def test_generativeai_stream() -> None:
    llm = GooglePalm(temperature=0, model_name="gemini-pro")
    outputs = list(llm.stream("Please say foo:"))
    assert isinstance(outputs[0], str)


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a Google PaLM LLM."""
    llm = GooglePalm(max_output_tokens=10)
@@ -30,6 +30,24 @@ class FakeTracer(BaseTracer):
        self.runs.append(run)


def _compare_run_with_error(run: Run, expected_run: Run) -> None:
    if run.child_runs:
        assert len(expected_run.child_runs) == len(run.child_runs)
        for received, expected in zip(run.child_runs, expected_run.child_runs):
            _compare_run_with_error(received, expected)
    received = run.dict(exclude={"child_runs"})
    received_err = received.pop("error")
    expected = expected_run.dict(exclude={"child_runs"})
    expected_err = expected.pop("error")

    assert received == expected
    if expected_err is not None:
        assert received_err is not None
        assert expected_err in received_err
    else:
        assert received_err is None


@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run."""
@@ -328,7 +346,8 @@ def test_tracer_llm_run_on_error() -> None:

    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
    assert len(tracer.runs) == 1
    _compare_run_with_error(tracer.runs[0], compare_run)


@freeze_time("2023-01-01")
@@ -364,7 +383,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
    tracer = FakeTracerWithLlmErrorCallback()
    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.error_run == compare_run
    _compare_run_with_error(tracer.error_run, compare_run)


@freeze_time("2023-01-01")
@@ -394,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:

    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
    _compare_run_with_error(tracer.runs[0], compare_run)


@freeze_time("2023-01-01")
@@ -425,7 +444,7 @@ def test_tracer_tool_run_on_error() -> None:

    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
    _compare_run_with_error(tracer.runs[0], compare_run)


@freeze_time("2023-01-01")
@@ -568,4 +587,6 @@ def test_tracer_nested_runs_on_error() -> None:
            ),
        ],
    )
    assert tracer.runs == [compare_run] * 3
    assert len(tracer.runs) == 3
    for run in tracer.runs:
        _compare_run_with_error(run, compare_run)
@@ -98,6 +98,15 @@ class FakeTracer(BaseTracer):
|
||||
return load_default_session()
|
||||
|
||||
|
||||
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
|
||||
received = run.dict()
|
||||
received_err = received.pop("error")
|
||||
expected = expected_run.dict()
|
||||
expected_err = expected.pop("error")
|
||||
assert received == expected
|
||||
assert expected_err in received_err
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
@@ -376,7 +385,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@@ -404,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@@ -433,136 +442,7 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
llm_uuid3 = uuid4()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_tool_start(
|
||||
serialized={"name": "tool"},
|
||||
input_str="test",
|
||||
run_id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "chain"},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "tool"},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{'name': 'tool'}",
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid3),
|
||||
parent_uuid=str(tool_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -2,9 +2,20 @@
from __future__ import annotations

import logging
import sys
import traceback
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Union,
    cast,
)
from uuid import UUID

from tenacity import RetryCallState
@@ -45,6 +56,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
    def _persist_run(self, run: Run) -> None:
        """Persist a run."""

    @staticmethod
    def _get_stacktrace(error: BaseException) -> str:
        """Get the stacktrace of the parent error."""
        msg = repr(error)
        try:
            if sys.version_info < (3, 10):
                tb = traceback.format_exception(
                    error.__class__, error, error.__traceback__
                )
            else:
                tb = traceback.format_exception(error)
            return (msg + "\n\n".join(tb)).strip()
        except:  # noqa: E722
            return msg

    def _start_trace(self, run: Run) -> None:
        """Start a trace for a run."""
        if run.parent_run_id:
@@ -220,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
    ) -> Run:
        """Handle an error for an LLM run."""
        llm_run = self._get_run(run_id, run_type="llm")
        llm_run.error = repr(error)
        llm_run.error = self._get_stacktrace(error)
        llm_run.end_time = datetime.utcnow()
        llm_run.events.append({"name": "error", "time": llm_run.end_time})
        self._end_trace(llm_run)
@@ -296,7 +322,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
    ) -> Run:
        """Handle an error for a chain run."""
        chain_run = self._get_run(run_id)
        chain_run.error = repr(error)
        chain_run.error = self._get_stacktrace(error)
        chain_run.end_time = datetime.utcnow()
        chain_run.events.append({"name": "error", "time": chain_run.end_time})
        if inputs is not None:
@@ -361,7 +387,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
    ) -> Run:
        """Handle an error for a tool run."""
        tool_run = self._get_run(run_id, run_type="tool")
        tool_run.error = repr(error)
        tool_run.error = self._get_stacktrace(error)
        tool_run.end_time = datetime.utcnow()
        tool_run.events.append({"name": "error", "time": tool_run.end_time})
        self._end_trace(tool_run)
@@ -414,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
    ) -> Run:
        """Run when Retriever errors."""
        retrieval_run = self._get_run(run_id, run_type="retriever")
        retrieval_run.error = repr(error)
        retrieval_run.error = self._get_stacktrace(error)
        retrieval_run.end_time = datetime.utcnow()
        retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
        self._end_trace(retrieval_run)
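For context on the version branch in `_get_stacktrace` above: `traceback.format_exception` only accepts a single exception object from Python 3.10 onward, while older interpreters need the `(type, value, traceback)` form. A small self-contained sketch of the same fallback:

```python
import sys
import traceback


def format_error(error: BaseException) -> str:
    """Return repr(error) followed by its formatted traceback, as the tracer now does."""
    msg = repr(error)
    try:
        if sys.version_info < (3, 10):
            tb = traceback.format_exception(type(error), error, error.__traceback__)
        else:
            tb = traceback.format_exception(error)
        return (msg + "\n\n".join(tb)).strip()
    except Exception:
        return msg


try:
    raise ValueError("boom")
except ValueError as exc:
    print(format_error(exc))
```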
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-core"
version = "0.1.0"
version = "0.1.1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -630,13 +630,14 @@ def test_lambda_schemas() -> None:
|
||||
}
|
||||
|
||||
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
|
||||
assert RunnableLambda(
|
||||
second_lambda, # type: ignore[arg-type]
|
||||
).input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||
}
|
||||
assert (
|
||||
RunnableLambda(second_lambda).input_schema.schema() # type: ignore[arg-type]
|
||||
== {
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||
}
|
||||
)
|
||||
|
||||
def get_value(input): # type: ignore[no-untyped-def]
|
||||
return input["variable_name"]
|
||||
@@ -3624,33 +3625,32 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@@ -3746,33 +3746,31 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@@ -3941,7 +3939,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
||||
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
@@ -3968,7 +3966,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ def _create_tool_message(
    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=content,
        additional_kwargs={"name": agent_action.tool},
    )
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -32,11 +34,12 @@ from langchain_core.tracers.evaluation import (
|
||||
)
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langsmith.client import Client
|
||||
from langsmith.evaluation import RunEvaluator
|
||||
from langsmith.evaluation import EvaluationResult, RunEvaluator
|
||||
from langsmith.run_helpers import as_runnable, is_traceable_function
|
||||
from langsmith.schemas import Dataset, DataType, Example
|
||||
from langsmith.schemas import Dataset, DataType, Example, TracerSession
|
||||
from langsmith.utils import LangSmithError
|
||||
from requests import HTTPError
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.base import Chain
|
||||
@@ -919,9 +922,12 @@ def _prepare_eval_run(
|
||||
project_name: str,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tuple[MCF, str, Dataset, List[Example]]:
|
||||
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||
examples = list(client.list_examples(dataset_id=dataset.id))
|
||||
if not examples:
|
||||
raise ValueError(f"Dataset {dataset_name} has no example rows.")
|
||||
|
||||
try:
|
||||
project_extra: dict = {"metadata": project_metadata} if project_metadata else {}
|
||||
@@ -953,111 +959,159 @@ run_on_dataset(
|
||||
f"View all tests for Dataset {dataset_name} at:\n{dataset.url}",
|
||||
flush=True,
|
||||
)
|
||||
examples = list(client.list_examples(dataset_id=dataset.id))
|
||||
if not examples:
|
||||
raise ValueError(f"Dataset {dataset_name} has no example rows.")
|
||||
return wrapped_model, project_name, dataset, examples
|
||||
return wrapped_model, project, dataset, examples
|
||||
|
||||
|
||||
def _prepare_run_on_dataset(
|
||||
client: Client,
|
||||
dataset_name: str,
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
project_name: Optional[str],
|
||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
input_mapper: Optional[Callable[[Dict], Any]] = None,
|
||||
concurrency_level: int = 5,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
|
||||
project_name = project_name or name_generation.random_name()
|
||||
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||
client,
|
||||
dataset_name,
|
||||
llm_or_chain_factory,
|
||||
project_name,
|
||||
project_metadata=project_metadata,
|
||||
tags=tags,
|
||||
)
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
|
||||
run_evaluators = _setup_evaluation(
|
||||
wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
|
||||
)
|
||||
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
|
||||
progress_bar = progress.ProgressBarCallback(len(examples))
|
||||
configs = [
|
||||
RunnableConfig(
|
||||
callbacks=[
|
||||
LangChainTracer(
|
||||
project_name=project_name,
|
||||
client=client,
|
||||
use_threading=False,
|
||||
example_id=example.id,
|
||||
),
|
||||
EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [],
|
||||
client=client,
|
||||
example_id=example.id,
|
||||
max_concurrency=0,
|
||||
),
|
||||
progress_bar,
|
||||
],
|
||||
tags=tags or [],
|
||||
max_concurrency=concurrency_level,
|
||||
class _RowResult(TypedDict, total=False):
|
||||
"""A dictionary of the results for a single example row."""
|
||||
|
||||
feedback: Optional[List[EvaluationResult]]
|
||||
execution_time: Optional[float]
|
||||
run_id: Optional[str]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DatasetRunContainer:
|
||||
"""A container to help manage the state of a eval run."""
|
||||
|
||||
client: Client
|
||||
project: TracerSession
|
||||
wrapped_model: MCF
|
||||
examples: List[Example]
|
||||
configs: List[RunnableConfig]
|
||||
|
||||
def _merge_test_outputs(
|
||||
self,
|
||||
batch_results: list,
|
||||
all_eval_results: Dict[str, _RowResult],
|
||||
) -> dict:
|
||||
results: dict = {}
|
||||
for example, output in zip(self.examples, batch_results):
|
||||
row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
|
||||
results[str(example.id)] = {
|
||||
"input": example.inputs,
|
||||
"feedback": row_result.get("feedback", []),
|
||||
"execution_time": row_result.get("execution_time"),
|
||||
"run_id": row_result.get("run_id"),
|
||||
}
|
||||
if isinstance(output, EvalError):
|
||||
results[str(example.id)]["Error"] = output.Error
|
||||
else:
|
||||
results[str(example.id)]["output"] = output
|
||||
if example.outputs:
|
||||
results[str(example.id)]["reference"] = example.outputs
|
||||
return results
|
||||
|
||||
def _collect_metrics(self) -> Dict[str, _RowResult]:
|
||||
all_eval_results: dict = {}
|
||||
for c in self.configs:
|
||||
for callback in cast(list, c["callbacks"]):
|
||||
if isinstance(callback, EvaluatorCallbackHandler):
|
||||
eval_results = callback.logged_eval_results
|
||||
for (_, example_id), v in eval_results.items():
|
||||
all_eval_results.setdefault(str(example_id), {}).update(
|
||||
{"feedback": v}
|
||||
)
|
||||
elif isinstance(callback, LangChainTracer):
|
||||
run = callback.latest_run
|
||||
execution_time = (
|
||||
(run.end_time - run.start_time).total_seconds()
|
||||
if run and run.end_time
|
||||
else None
|
||||
)
|
||||
run_id = str(run.id) if run else None
|
||||
all_eval_results.setdefault(str(callback.example_id), {}).update(
|
||||
{
|
||||
"execution_time": execution_time,
|
||||
"run_id": run_id,
|
||||
}
|
||||
)
|
||||
return cast(Dict[str, _RowResult], all_eval_results)
|
||||
|
||||
def _collect_test_results(
|
||||
self,
|
||||
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
|
||||
) -> TestResult:
|
||||
wait_for_all_evaluators()
|
||||
all_eval_results = self._collect_metrics()
|
||||
results = self._merge_test_outputs(batch_results, all_eval_results)
|
||||
return TestResult(
|
||||
project_name=self.project.name,
|
||||
results=results,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
return wrapped_model, project_name, examples, configs
|
||||
|
||||
def finish(self, batch_results: list, verbose: bool = False) -> TestResult:
|
||||
results = self._collect_test_results(batch_results)
|
||||
if verbose:
|
||||
try:
|
||||
agg_feedback = results.get_aggregate_feedback()
|
||||
_display_aggregate_results(agg_feedback)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
|
||||
try:
|
||||
# Closing the project permits name changing and metric optimizations
|
||||
self.client.update_project(self.project.id, end_time=datetime.utcnow())
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to close project: {repr(e)}")
|
||||
return results
|
||||
|
||||
def _collect_test_results(
|
||||
examples: List[Example],
|
||||
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
|
||||
configs: List[RunnableConfig],
|
||||
project_name: str,
|
||||
) -> TestResult:
|
||||
wait_for_all_evaluators()
|
||||
all_eval_results = {}
|
||||
all_execution_time = {}
|
||||
all_run_ids = {}
|
||||
for c in configs:
|
||||
for callback in cast(list, c["callbacks"]):
|
||||
if isinstance(callback, EvaluatorCallbackHandler):
|
||||
eval_results = callback.logged_eval_results
|
||||
all_eval_results.update(
|
||||
{example_id: v for (_, example_id), v in eval_results.items()}
|
||||
)
|
||||
elif isinstance(callback, LangChainTracer):
|
||||
run = callback.latest_run
|
||||
example_id = callback.example_id
|
||||
run_id = str(run.id) if run else None
|
||||
execution_time = (
|
||||
(run.end_time - run.start_time).total_seconds()
|
||||
if run and run.end_time
|
||||
else None
|
||||
)
|
||||
all_execution_time[str(example_id)] = execution_time
|
||||
all_run_ids[str(example_id)] = run_id
|
||||
|
||||
results: dict = {}
|
||||
for example, output in zip(examples, batch_results):
|
||||
feedback = all_eval_results.get(str(example.id), [])
|
||||
results[str(example.id)] = {
|
||||
"input": example.inputs,
|
||||
"feedback": feedback,
|
||||
"execution_time": all_execution_time.get(str(example.id)),
|
||||
"run_id": all_run_ids.get(str(example.id)),
|
||||
}
|
||||
if isinstance(output, EvalError):
|
||||
results[str(example.id)]["Error"] = output.Error
|
||||
else:
|
||||
results[str(example.id)]["output"] = output
|
||||
if example.outputs:
|
||||
results[str(example.id)]["reference"] = example.outputs
|
||||
return TestResult(
|
||||
project_name=project_name,
|
||||
results=results,
|
||||
)
|
||||
@classmethod
|
||||
def prepare(
|
||||
cls,
|
||||
client: Client,
|
||||
dataset_name: str,
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
project_name: Optional[str],
|
||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
input_mapper: Optional[Callable[[Dict], Any]] = None,
|
||||
concurrency_level: int = 5,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> _DatasetRunContainer:
|
||||
project_name = project_name or name_generation.random_name()
|
||||
wrapped_model, project, dataset, examples = _prepare_eval_run(
|
||||
client,
|
||||
dataset_name,
|
||||
llm_or_chain_factory,
|
||||
project_name,
|
||||
project_metadata=project_metadata,
|
||||
tags=tags,
|
||||
)
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
|
||||
run_evaluators = _setup_evaluation(
|
||||
wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
|
||||
)
|
||||
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
|
||||
progress_bar = progress.ProgressBarCallback(len(examples))
|
||||
configs = [
|
||||
RunnableConfig(
|
||||
callbacks=[
|
||||
LangChainTracer(
|
||||
project_name=project.name,
|
||||
client=client,
|
||||
use_threading=False,
|
||||
example_id=example.id,
|
||||
),
|
||||
EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [],
|
||||
client=client,
|
||||
example_id=example.id,
|
||||
max_concurrency=0,
|
||||
),
|
||||
progress_bar,
|
||||
],
|
||||
tags=tags or [],
|
||||
max_concurrency=concurrency_level,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
return cls(
|
||||
client=client,
|
||||
project=project,
|
||||
wrapped_model=wrapped_model,
|
||||
examples=examples,
|
||||
configs=configs,
|
||||
)
|
||||
|
||||
|
||||
def _is_jupyter_environment() -> bool:
|
||||
@@ -1125,7 +1179,7 @@ async def arun_on_dataset(
|
||||
removal="0.0.305",
|
||||
)
|
||||
client = client or Client()
|
||||
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
|
||||
container = _DatasetRunContainer.prepare(
|
||||
client,
|
||||
dataset_name,
|
||||
llm_or_chain_factory,
|
||||
@@ -1137,26 +1191,18 @@ async def arun_on_dataset(
|
||||
project_metadata=project_metadata,
|
||||
)
|
||||
batch_results = await runnable_utils.gather_with_concurrency(
|
||||
configs[0].get("max_concurrency"),
|
||||
container.configs[0].get("max_concurrency"),
|
||||
*map(
|
||||
functools.partial(
|
||||
_arun_llm_or_chain,
|
||||
llm_or_chain_factory=wrapped_model,
|
||||
llm_or_chain_factory=container.wrapped_model,
|
||||
input_mapper=input_mapper,
|
||||
),
|
||||
examples,
|
||||
configs,
|
||||
container.examples,
|
||||
container.configs,
|
||||
),
|
||||
)
|
||||
results = _collect_test_results(examples, batch_results, configs, project_name)
|
||||
if verbose:
|
||||
try:
|
||||
agg_feedback = results.get_aggregate_feedback()
|
||||
print("\n Eval quantiles:")
|
||||
print(agg_feedback)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
|
||||
return results
|
||||
return container.finish(batch_results, verbose=verbose)
|
||||
|
||||
|
||||
def run_on_dataset(
|
||||
@@ -1185,7 +1231,7 @@ def run_on_dataset(
|
||||
removal="0.0.305",
|
||||
)
|
||||
client = client or Client()
|
||||
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
|
||||
container = _DatasetRunContainer.prepare(
|
||||
client,
|
||||
dataset_name,
|
||||
llm_or_chain_factory,
|
||||
@@ -1201,33 +1247,26 @@ def run_on_dataset(
|
||||
_run_llm_or_chain(
|
||||
example,
|
||||
config,
|
||||
llm_or_chain_factory=wrapped_model,
|
||||
llm_or_chain_factory=container.wrapped_model,
|
||||
input_mapper=input_mapper,
|
||||
)
|
||||
for example, config in zip(examples, configs)
|
||||
for example, config in zip(container.examples, container.configs)
|
||||
]
|
||||
else:
|
||||
with runnable_config.get_executor_for_config(configs[0]) as executor:
|
||||
with runnable_config.get_executor_for_config(container.configs[0]) as executor:
|
||||
batch_results = list(
|
||||
executor.map(
|
||||
functools.partial(
|
||||
_run_llm_or_chain,
|
||||
llm_or_chain_factory=wrapped_model,
|
||||
llm_or_chain_factory=container.wrapped_model,
|
||||
input_mapper=input_mapper,
|
||||
),
|
||||
examples,
|
||||
configs,
|
||||
container.examples,
|
||||
container.configs,
|
||||
)
|
||||
)
|
||||
|
||||
results = _collect_test_results(examples, batch_results, configs, project_name)
|
||||
if verbose:
|
||||
try:
|
||||
agg_feedback = results.get_aggregate_feedback()
|
||||
_display_aggregate_results(agg_feedback)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
|
||||
return results
|
||||
return container.finish(batch_results, verbose=verbose)
|
||||
|
||||
|
||||
_RUN_ON_DATASET_DOCSTRING = """
|
||||
|
||||
40
libs/langchain/poetry.lock
generated
40
libs/langchain/poetry.lock
generated
@@ -3133,7 +3133,6 @@ optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
|
||||
files = [
|
||||
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
|
||||
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3447,7 +3446,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-community"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = "Community contributed LangChain integrations."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@@ -3475,7 +3474,7 @@ url = "../community"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@@ -3485,7 +3484,7 @@ develop = true
|
||||
[package.dependencies]
|
||||
anyio = ">=3,<5"
|
||||
jsonpatch = "^1.33"
|
||||
langsmith = "~0.0.63"
|
||||
langsmith = "~0.0.70"
|
||||
packaging = "^23.2"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
@@ -3501,13 +3500,13 @@ url = "../core"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.0.63"
|
||||
version = "0.0.70"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langsmith-0.0.63-py3-none-any.whl", hash = "sha256:43a521dd10d8405ac21a0b959e3de33e2270e4abe6c73cc4036232a6990a0793"},
|
||||
{file = "langsmith-0.0.63.tar.gz", hash = "sha256:ddb2dfadfad3e05151ed8ba1643d1c516024b80fbd0c6263024400ced06a3768"},
|
||||
{file = "langsmith-0.0.70-py3-none-any.whl", hash = "sha256:a0d4cac3af94fe44c2ef3814c32b6740f92aebe267e395d62e62040bc5bad343"},
|
||||
{file = "langsmith-0.0.70.tar.gz", hash = "sha256:3a546c45e67f6600d6669ef63f1f58b772e505703126338ad4f22fe0e2bbf677"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3727,16 +3726,6 @@ files = [
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
|
||||
@@ -5262,8 +5251,6 @@ files = [
|
||||
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
|
||||
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
|
||||
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
|
||||
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
|
||||
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
|
||||
@@ -5306,7 +5293,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
|
||||
@@ -5315,8 +5301,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
|
||||
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
|
||||
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
|
||||
@@ -6305,7 +6289,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||
@@ -6313,15 +6296,8 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||
@@ -6338,7 +6314,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||
@@ -6346,7 +6321,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||
@@ -9103,4 +9077,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "0b232a037505cefcdf2203edc9d750e70e2e52a297475490022402994c3036a3"
|
||||
content-hash = "e93141191088db7b4aec1a976ebd8cb20075e26d4a987bf97c0495ad865b7460"
|
||||
|
||||
@@ -80,7 +80,7 @@ cassio = {version = "^0.1.0", optional = true}
|
||||
sympy = {version = "^1.12", optional = true}
|
||||
rapidfuzz = {version = "^3.1.1", optional = true}
|
||||
jsonschema = {version = ">1", optional = true}
|
||||
langsmith = "~0.0.63"
|
||||
langsmith = "~0.0.70"
|
||||
rank-bm25 = {version = "^0.2.2", optional = true}
|
||||
geopandas = {version = "^0.13.1", optional = true}
|
||||
gitpython = {version = "^3.1.32", optional = true}
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from langchain.agents.format_scratchpad.openai_tools import (
|
||||
format_to_openai_tool_messages,
|
||||
)
|
||||
from langchain.agents.output_parsers.openai_tools import (
|
||||
parse_ai_message_to_openai_tool_action,
|
||||
)
|
||||
|
||||
|
||||
def test_calls_convert_agent_action_to_messages() -> None:
|
||||
additional_kwargs1 = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abcd12345",
|
||||
"function": {"arguments": '{"a": 3, "b": 5}', "name": "add"},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
message1 = AIMessage(content="", additional_kwargs=additional_kwargs1)
|
||||
|
||||
actions1 = parse_ai_message_to_openai_tool_action(message1)
|
||||
additional_kwargs2 = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abcd54321",
|
||||
"function": {"arguments": '{"a": 3, "b": 5}', "name": "subtract"},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
message2 = AIMessage(content="", additional_kwargs=additional_kwargs2)
|
||||
actions2 = parse_ai_message_to_openai_tool_action(message2)
|
||||
|
||||
additional_kwargs3 = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abcd67890",
|
||||
"function": {"arguments": '{"a": 3, "b": 5}', "name": "multiply"},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_abcd09876",
|
||||
"function": {"arguments": '{"a": 3, "b": 5}', "name": "divide"},
|
||||
"type": "function",
|
||||
},
|
||||
],
|
||||
}
|
||||
message3 = AIMessage(content="", additional_kwargs=additional_kwargs3)
|
||||
actions3 = parse_ai_message_to_openai_tool_action(message3)
|
||||
# for mypy
|
||||
assert isinstance(actions1, list)
|
||||
assert isinstance(actions2, list)
|
||||
assert isinstance(actions3, list)
|
||||
|
||||
intermediate_steps = [
|
||||
(actions1[0], "observation1"),
|
||||
(actions2[0], "observation2"),
|
||||
(actions3[0], "observation3"),
|
||||
(actions3[1], "observation4"),
|
||||
]
|
||||
expected_messages = [
|
||||
message1,
|
||||
ToolMessage(
|
||||
tool_call_id="call_abcd12345",
|
||||
content="observation1",
|
||||
additional_kwargs={"name": "add"},
|
||||
),
|
||||
message2,
|
||||
ToolMessage(
|
||||
tool_call_id="call_abcd54321",
|
||||
content="observation2",
|
||||
additional_kwargs={"name": "subtract"},
|
||||
),
|
||||
message3,
|
||||
ToolMessage(
|
||||
tool_call_id="call_abcd67890",
|
||||
content="observation3",
|
||||
additional_kwargs={"name": "multiply"},
|
||||
),
|
||||
ToolMessage(
|
||||
tool_call_id="call_abcd09876",
|
||||
content="observation4",
|
||||
additional_kwargs={"name": "divide"},
|
||||
),
|
||||
]
|
||||
output = format_to_openai_tool_messages(intermediate_steps)
|
||||
assert output == expected_messages
|
||||
|
||||
|
||||
def test_handles_empty_input_list() -> None:
|
||||
output = format_to_openai_tool_messages([])
|
||||
assert output == []
|
||||
@@ -4,10 +4,17 @@ This package contains the LangChain integrations for Gemini through their genera

## Installation

```python
```bash
pip install -U langchain-google-genai
```

### Image utilities
To use image utility methods, like loading images from GCS urls, install with extras group 'images':

```bash
pip install -e "langchain-google-genai[images]"
```

## Chat Models

This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
@@ -56,3 +63,16 @@ The value of `image_url` can be any of the following:
- A local file path
- A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
- A PIL image



## Embeddings

This package also adds support for google's embeddings models.

```
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("hello, world!")
```
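As a rough illustration of the multimodal `image_url` inputs listed in the README hunk above (not part of the original README; the vision model name and message layout are assumptions):

```python
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# "gemini-pro-vision" is assumed here as the image-capable model name.
llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
message = HumanMessage(
    content=[
        {"type": "text", "text": "What's shown in this image?"},
        {"type": "image_url", "image_url": "https://example.com/picture.png"},
    ]
)
llm.invoke([message])
```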
@@ -1,3 +1,65 @@
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
"""**LangChain Google Generative AI Integration**

__all__ = ["ChatGoogleGenerativeAI"]
This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities.

**Chat Models**

The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.

**LLMs**

The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.

**Embeddings**

The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
**Installation**

To install the package, use pip:

```python
pip install -U langchain-google-genai
```
## Using Chat Models

After setting up your environment with the required API key, you can interact with the Google Gemini models.

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```

## Using LLMs

The package also supports generating text with Google's models.

```python
from langchain_google_genai import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-pro")
llm.invoke("Once upon a time, a library called LangChain")
```

## Embedding Generation

The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("hello, world!")
```
""" # noqa: E501
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI

__all__ = [
    "ChatGoogleGenerativeAI",
    "GoogleGenerativeAIEmbeddings",
    "GoogleGenerativeAI",
]
@@ -0,0 +1,4 @@
class GoogleGenerativeAIError(Exception):
    """
    Custom exception class for errors associated with the `Google GenAI` API.
    """
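Because the integration-specific exceptions below (e.g. `ChatGoogleGenerativeAIError`) subclass this shared base, callers can handle them with a single except clause; a hedged sketch:

```python
from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError


def flaky_call() -> None:
    # hypothetical failure, for illustration only
    raise ChatGoogleGenerativeAIError("response was blocked")


try:
    flaky_call()
except GoogleGenerativeAIError as err:  # the shared base catches the chat-specific error too
    print(f"Google GenAI call failed: {err}")
```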
@@ -5,7 +5,6 @@ import logging
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
@@ -22,6 +21,8 @@ from typing import (
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -38,7 +39,7 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@@ -48,11 +49,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from langchain_google_genai._common import GoogleGenerativeAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
IMAGE_TYPES: Tuple = ()
|
||||
try:
|
||||
import PIL
|
||||
@@ -63,8 +61,10 @@ except ImportError:
|
||||
PIL = None # type: ignore
|
||||
Image = None # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ChatGoogleGenerativeAIError(Exception):
|
||||
|
||||
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
|
||||
"""
|
||||
Custom exception class for errors associated with the `Google GenAI` API.
|
||||
|
||||
@@ -106,7 +106,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
)
|
||||
|
||||
|
||||
def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes a chat generation method with retry logic using tenacity.
|
||||
|
||||
@@ -139,7 +139,7 @@ def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
return _chat_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes a chat generation method with retry logic using tenacity.
|
||||
|
||||
@@ -269,8 +269,6 @@ def _convert_to_parts(
|
||||
content: Sequence[Union[str, dict]],
|
||||
) -> List[genai.types.PartType]:
|
||||
"""Converts a list of LangChain messages into a google parts."""
|
||||
import google.generativeai as genai
|
||||
|
||||
parts = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
@@ -410,8 +408,7 @@ def _response_to_result(
|
||||
class ChatGoogleGenerativeAI(BaseChatModel):
|
||||
"""`Google Generative AI` Chat models API.
|
||||
|
||||
To use you must have the google.generativeai Python package installed and
|
||||
either:
|
||||
To use, you must have either:
|
||||
|
||||
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
|
||||
2. Pass your API key using the google_api_key kwarg to the ChatGoogleGenerativeAI
|
||||
@@ -435,7 +432,7 @@ Supported examples:
|
||||
max_output_tokens: int = Field(default=None, description="Max output tokens")
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: Optional[str] = None
|
||||
google_api_key: Optional[SecretStr] = None
|
||||
temperature: Optional[float] = None
|
||||
"""Run inference with this temperature. Must by in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
@@ -487,17 +484,9 @@ Supported examples:
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
except ImportError:
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
"Could not import google.generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`"
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
genai.configure(api_key=google_api_key)
|
||||
if (
|
||||
values.get("temperature") is not None
|
||||
and not 0 <= values["temperature"] <= 1
|
||||
@@ -560,7 +549,7 @@ Supported examples:
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = chat_with_retry(
|
||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||
**params,
|
||||
generation_method=self._generation_method,
|
||||
)
|
||||
@@ -574,7 +563,7 @@ Supported examples:
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = await achat_with_retry(
|
||||
response: genai.types.GenerateContentResponse = await _achat_with_retry(
|
||||
**params,
|
||||
generation_method=self._async_generation_method,
|
||||
)
|
||||
@@ -588,7 +577,7 @@ Supported examples:
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = chat_with_retry(
|
||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||
**params,
|
||||
generation_method=self._generation_method,
|
||||
stream=True,
|
||||
@@ -614,7 +603,7 @@ Supported examples:
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
async for chunk in await achat_with_retry(
|
||||
async for chunk in await _achat_with_retry(
|
||||
**params,
|
||||
generation_method=self._async_generation_method,
|
||||
stream=True,
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_google_genai._common import GoogleGenerativeAIError
|
||||
|
||||
|
||||
class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
||||
"""`Google Generative AI Embeddings`.
|
||||
|
||||
To use, you must have either:
|
||||
|
||||
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
|
||||
2. Pass your API key using the google_api_key kwarg to the GoogleGenerativeAIEmbeddings
|
||||
constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
|
||||
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
|
||||
embeddings.embed_query("What's our Q1 revenue?")
|
||||
"""
|
||||
|
||||
model: str = Field(
|
||||
...,
|
||||
description="The name of the embedding model to use. "
|
||||
"Example: models/embedding-001",
|
||||
)
|
||||
task_type: Optional[str] = Field(
|
||||
None,
|
||||
description="The task type. Valid options include: "
|
||||
"task_type_unspecified, retrieval_query, retrieval_document, "
|
||||
"semantic_similarity, classification, and clustering",
|
||||
)
|
||||
google_api_key: Optional[SecretStr] = Field(
|
||||
None,
|
||||
description="The Google API key to use. If not provided, "
|
||||
"the GOOGLE_API_KEY environment variable will be used.",
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validates that the python package exists in environment."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
genai.configure(api_key=google_api_key)
|
||||
return values
|
||||
|
||||
def _embed(
|
||||
self, texts: List[str], task_type: str, title: Optional[str] = None
|
||||
) -> List[List[float]]:
|
||||
task_type = task_type or self.task_type or "retrieval_document"
|
||||
try:
|
||||
result = genai.embed_content(
|
||||
model=self.model,
|
||||
content=texts,
|
||||
task_type=task_type,
|
||||
title=title,
|
||||
)
|
||||
except Exception as e:
|
||||
raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
|
||||
return result["embedding"]
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], batch_size: int = 5
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of strings. Vertex AI currently
|
||||
sets a max batch size of 5 strings.
|
||||
|
||||
Args:
|
||||
texts: List[str] The list of strings to embed.
|
||||
batch_size: [int] The batch size of embeddings to send to the model
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
task_type = self.task_type or "retrieval_document"
|
||||
return self._embed(texts, task_type=task_type)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
task_type = self.task_type or "retrieval_query"
|
||||
return self._embed([text], task_type=task_type)[0]
|
||||
262
libs/partners/google-genai/langchain_google_genai/llms.py
Normal file
262
libs/partners/google-genai/langchain_google_genai/llms.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import google.api_core
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: BaseLLM,
|
||||
*,
|
||||
max_retries: int = 1,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Creates a retry decorator for Vertex / Palm LLMs."""
|
||||
|
||||
errors = [
|
||||
google.api_core.exceptions.ResourceExhausted,
|
||||
google.api_core.exceptions.ServiceUnavailable,
|
||||
google.api_core.exceptions.Aborted,
|
||||
google.api_core.exceptions.DeadlineExceeded,
|
||||
google.api_core.exceptions.GoogleAPIError,
|
||||
]
|
||||
decorator = create_base_retry_decorator(
|
||||
error_types=errors, max_retries=max_retries, run_manager=run_manager
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def _completion_with_retry(
|
||||
llm: GoogleGenerativeAI,
|
||||
prompt: LanguageModelInput,
|
||||
is_gemini: bool = False,
|
||||
stream: bool = False,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(
|
||||
llm, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(
|
||||
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
||||
) -> Any:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if is_gemini:
|
||||
return llm.client.generate_content(
|
||||
contents=prompt, stream=stream, generation_config=generation_config
|
||||
)
|
||||
return llm.client.generate_text(prompt=prompt, **kwargs)
|
||||
|
||||
return _completion_with_retry(
|
||||
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _is_gemini_model(model_name: str) -> bool:
|
||||
return "gemini" in model_name
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
"""Strip erroneous leading spaces from text.
|
||||
|
||||
The PaLM API will sometimes erroneously return a single leading space in all
|
||||
lines > 1. This function strips that space.
|
||||
"""
|
||||
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
|
||||
if has_leading_space:
|
||||
return text.replace("\n ", "\n")
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
class GoogleGenerativeAI(BaseLLM, BaseModel):
|
||||
"""Google GenerativeAI models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_google_genai import GoogleGenerativeAI
|
||||
llm = GoogleGenerativeAI(model="gemini-pro")
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str = Field(
|
||||
...,
|
||||
description="""The name of the model to use.
|
||||
Supported examples:
|
||||
- gemini-pro
|
||||
- models/text-bison-001""",
|
||||
)
|
||||
"""Model name to use."""
|
||||
google_api_key: Optional[SecretStr] = None
|
||||
temperature: float = 0.7
|
||||
"""Run inference with this temperature. Must by in the closed interval
|
||||
[0.0, 1.0]."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
max_output_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
||||
If unset, will default to 64."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
max_retries: int = 6
|
||||
"""The maximum number of retries to make when generating."""
|
||||
|
||||
@property
|
||||
def is_gemini(self) -> bool:
|
||||
"""Returns whether a model is belongs to a Gemini family or not."""
|
||||
return _is_gemini_model(self.model)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
model_name = values["model"]
|
||||
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
|
||||
if _is_gemini_model(model_name):
|
||||
values["client"] = genai.GenerativeModel(model_name=model_name)
|
||||
else:
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
|
||||
raise ValueError("max_output_tokens must be greater than zero")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations: List[List[Generation]] = []
|
||||
generation_config = {
|
||||
"stop_sequences": stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"candidate_count": self.n,
|
||||
}
|
||||
for prompt in prompts:
|
||||
if self.is_gemini:
|
||||
res = _completion_with_retry(
|
||||
self,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
candidates = [
|
||||
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
||||
]
|
||||
generations.append([Generation(text=c) for c in candidates])
|
||||
else:
|
||||
res = _completion_with_retry(
|
||||
self,
|
||||
model=self.model,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=False,
|
||||
run_manager=run_manager,
|
||||
**generation_config,
|
||||
)
|
||||
prompt_generations = []
|
||||
for candidate in res.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if stop:
|
||||
generation_config["stop_sequences"] = stop
|
||||
for stream_resp in _completion_with_retry(
|
||||
self,
|
||||
prompt,
|
||||
stream=True,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
):
|
||||
chunk = GenerationChunk(text=stream_resp.text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "google_palm"
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
if self.is_gemini:
|
||||
raise ValueError("Counting tokens is not yet supported!")
|
||||
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
||||
return result["token_count"]
|
||||
50
libs/partners/google-genai/poetry.lock
generated
50
libs/partners/google-genai/poetry.lock
generated
@@ -546,6 +546,51 @@ files = [
|
||||
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.2"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "numpy-1.26.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3703fc9258a4a122d17043e57b35e5ef1c5a5837c3db8be396c82e04c1cf9b0f"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc392fdcbd21d4be6ae1bb4475a03ce3b025cd49a9be5345d76d7585aea69440"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36340109af8da8805d8851ef1d74761b3b88e81a9bd80b290bbfed61bd2b4f75"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc008217145b3d77abd3e4d5ef586e3bdfba8fe17940769f8aa09b99e856c00"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ced40d4e9e18242f70dd02d739e44698df3dcb010d31f495ff00a31ef6014fe"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b272d4cecc32c9e19911891446b72e986157e6a1809b7b56518b4f3755267523"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-win32.whl", hash = "sha256:22f8fc02fdbc829e7a8c578dd8d2e15a9074b630d4da29cda483337e300e3ee9"},
|
||||
{file = "numpy-1.26.2-cp310-cp310-win_amd64.whl", hash = "sha256:26c9d33f8e8b846d5a65dd068c14e04018d05533b348d9eaeef6c1bd787f9919"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b96e7b9c624ef3ae2ae0e04fa9b460f6b9f17ad8b4bec6d7756510f1f6c0c841"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa18428111fb9a591d7a9cc1b48150097ba6a7e8299fb56bdf574df650e7d1f1"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06fa1ed84aa60ea6ef9f91ba57b5ed963c3729534e6e54055fc151fad0423f0a"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca5482c3dbdd051bcd1fce8034603d6ebfc125a7bd59f55b40d8f5d246832b"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:854ab91a2906ef29dc3925a064fcd365c7b4da743f84b123002f6139bcb3f8a7"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f43740ab089277d403aa07567be138fc2a89d4d9892d113b76153e0e412409f8"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-win32.whl", hash = "sha256:a2bbc29fcb1771cd7b7425f98b05307776a6baf43035d3b80c4b0f29e9545186"},
|
||||
{file = "numpy-1.26.2-cp311-cp311-win_amd64.whl", hash = "sha256:2b3fca8a5b00184828d12b073af4d0fc5fdd94b1632c2477526f6bd7842d700d"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a4cd6ed4a339c21f1d1b0fdf13426cb3b284555c27ac2f156dfdaaa7e16bfab0"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d5244aabd6ed7f312268b9247be47343a654ebea52a60f002dc70c769048e75"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3cdb4d9c70e6b8c0814239ead47da00934666f668426fc6e94cce869e13fd7"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa317b2325f7aa0a9471663e6093c210cb2ae9c0ad824732b307d2c51983d5b6"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:174a8880739c16c925799c018f3f55b8130c1f7c8e75ab0a6fa9d41cab092fd6"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f79b231bf5c16b1f39c7f4875e1ded36abee1591e98742b05d8a0fb55d8a3eec"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-win32.whl", hash = "sha256:4a06263321dfd3598cacb252f51e521a8cb4b6df471bb12a7ee5cbab20ea9167"},
|
||||
{file = "numpy-1.26.2-cp312-cp312-win_amd64.whl", hash = "sha256:b04f5dc6b3efdaab541f7857351aac359e6ae3c126e2edb376929bd3b7f92d7e"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4eb8df4bf8d3d90d091e0146f6c28492b0be84da3e409ebef54349f71ed271ef"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a13860fdcd95de7cf58bd6f8bc5a5ef81c0b0625eb2c9a783948847abbef2c2"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64308ebc366a8ed63fd0bf426b6a9468060962f1a4339ab1074c228fa6ade8e3"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d73a3abcac238250091b11caef9ad12413dab01669511779bc9b29261dd50210"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b361d369fc7e5e1714cf827b731ca32bff8d411212fccd29ad98ad622449cc36"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-win32.whl", hash = "sha256:bd3f0091e845164a20bd5a326860c840fe2af79fa12e0469a12768a3ec578d80"},
|
||||
{file = "numpy-1.26.2-cp39-cp39-win_amd64.whl", hash = "sha256:2beef57fb031dcc0dc8fa4fe297a742027b954949cabb52a2a376c144e5e6060"},
|
||||
{file = "numpy-1.26.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1cc3d5029a30fb5f06704ad6b23b35e11309491c999838c31f124fee32107c79"},
|
||||
{file = "numpy-1.26.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94cc3c222bb9fb5a12e334d0479b97bb2df446fbe622b470928f5284ffca3f8d"},
|
||||
{file = "numpy-1.26.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe6b44fb8fcdf7eda4ef4461b97b3f63c466b27ab151bec2366db8b197387841"},
|
||||
{file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "23.2"
|
||||
@@ -1226,7 +1271,10 @@ files = [
|
||||
[package.extras]
|
||||
watchmedo = ["PyYAML (>=3.10)"]
|
||||
|
||||
[extras]
|
||||
images = ["pillow"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "7753b9e2cb62c5b4dac124f0ff43027232c45138dbf07fdacc3c320b82367dad"
|
||||
content-hash = "f3b43f02c7300c3003347dbdfa9c07ddba988aab1387eda3efa02b2351c868d9"
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-google-genai"
|
||||
version = "0.0.2"
|
||||
version = "0.0.4"
|
||||
description = "An integration package connecting Google's genai package and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
langchain-core = "^0.1"
|
||||
google-generativeai = "^0.3.1"
|
||||
pillow = { version = "^10.1.0", optional = true }
|
||||
|
||||
[tool.poetry.extras]
|
||||
images = ["pillow"]
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
@@ -16,11 +21,12 @@ optional = true
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
numpy = "^1.26.2"
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
@@ -32,6 +38,8 @@ codespell = "^2.2.0"
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
pillow = "^10.1.0"
|
||||
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
@@ -41,7 +49,7 @@ ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
types-requests = "^2.28.11.5"
|
||||
types-google-cloud-ndb = "^2.2.0.1"
|
||||
types-pillow = "^10.1.0.2"
|
||||
@@ -50,7 +58,7 @@ types-pillow = "^10.1.0.2"
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
pillow = "^10.1.0"
|
||||
types-requests = "^2.31.0.10"
|
||||
types-pillow = "^10.1.0.2"
|
||||
@@ -58,19 +66,16 @@ types-google-cloud-ndb = "^2.2.0.1"
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain_google_genai._common import GoogleGenerativeAIError
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
|
||||
_MODEL = "models/embedding-001"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query",
|
||||
[
|
||||
"Hi",
|
||||
"This is a longer query string to test the embedding functionality of the"
|
||||
" model against the pickle rick?",
|
||||
],
|
||||
)
|
||||
def test_embed_query_different_lengths(query: str) -> None:
|
||||
"""Test embedding queries of different lengths."""
|
||||
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
|
||||
result = model.embed_query(query)
|
||||
assert len(result) == 768
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query",
|
||||
[
|
||||
"Hi",
|
||||
"This is a longer query string to test the embedding functionality of the"
|
||||
" model against the pickle rick?",
|
||||
],
|
||||
)
|
||||
async def test_aembed_query_different_lengths(query: str) -> None:
|
||||
"""Test embedding queries of different lengths."""
|
||||
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
|
||||
result = await model.aembed_query(query)
|
||||
assert len(result) == 768
|
||||
|
||||
|
||||
def test_embed_documents() -> None:
|
||||
"""Test embedding a query."""
|
||||
model = GoogleGenerativeAIEmbeddings(
|
||||
model=_MODEL,
|
||||
)
|
||||
result = model.embed_documents(["Hello world", "Good day, world"])
|
||||
assert len(result) == 2
|
||||
assert len(result[0]) == 768
|
||||
assert len(result[1]) == 768
|
||||
|
||||
|
||||
async def test_aembed_documents() -> None:
|
||||
"""Test embedding a query."""
|
||||
model = GoogleGenerativeAIEmbeddings(
|
||||
model=_MODEL,
|
||||
)
|
||||
result = await model.aembed_documents(["Hello world", "Good day, world"])
|
||||
assert len(result) == 2
|
||||
assert len(result[0]) == 768
|
||||
assert len(result[1]) == 768
|
||||
|
||||
|
||||
def test_invalid_model_error_handling() -> None:
|
||||
"""Test error handling with an invalid model name."""
|
||||
with pytest.raises(GoogleGenerativeAIError):
|
||||
GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query("Hello world")
|
||||
|
||||
|
||||
def test_invalid_api_key_error_handling() -> None:
|
||||
"""Test error handling with an invalid API key."""
|
||||
with pytest.raises(GoogleGenerativeAIError):
|
||||
GoogleGenerativeAIEmbeddings(
|
||||
model=_MODEL, google_api_key="invalid_key"
|
||||
).embed_query("Hello world")
|
||||
|
||||
|
||||
def test_embed_documents_consistency() -> None:
|
||||
"""Test embedding consistency for the same document."""
|
||||
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
|
||||
doc = "Consistent document for testing"
|
||||
result1 = model.embed_documents([doc])
|
||||
result2 = model.embed_documents([doc])
|
||||
assert result1 == result2
|
||||
|
||||
|
||||
def test_embed_documents_quality() -> None:
|
||||
"""Smoke test embedding quality by comparing similar and dissimilar documents."""
|
||||
model = GoogleGenerativeAIEmbeddings(model=_MODEL)
|
||||
similar_docs = ["Document A", "Similar Document A"]
|
||||
dissimilar_docs = ["Document A", "Completely Different Zebra"]
|
||||
similar_embeddings = model.embed_documents(similar_docs)
|
||||
dissimilar_embeddings = model.embed_documents(dissimilar_docs)
|
||||
similar_distance = np.linalg.norm(
|
||||
np.array(similar_embeddings[0]) - np.array(similar_embeddings[1])
|
||||
)
|
||||
dissimilar_distance = np.linalg.norm(
|
||||
np.array(dissimilar_embeddings[0]) - np.array(dissimilar_embeddings[1])
|
||||
)
|
||||
assert similar_distance < dissimilar_distance
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Test Google GenerativeAI API wrapper.
|
||||
|
||||
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
|
||||
valid API key.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_google_genai.llms import GoogleGenerativeAI
|
||||
|
||||
model_names = [None, "models/text-bison-001", "gemini-pro"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_call(model_name: str) -> None:
|
||||
"""Test valid call to Google GenerativeAI text API."""
|
||||
if model_name:
|
||||
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
|
||||
else:
|
||||
llm = GoogleGenerativeAI(max_output_tokens=10)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
assert llm._llm_type == "google_palm"
|
||||
if model_name and "gemini" in model_name:
|
||||
assert llm.client.model_name == "models/gemini-pro"
|
||||
else:
|
||||
assert llm.model == "models/text-bison-001"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_generate(model_name: str) -> None:
|
||||
n = 1 if model_name == "gemini-pro" else 2
|
||||
if model_name:
|
||||
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
|
||||
else:
|
||||
llm = GoogleGenerativeAI(temperature=0.3, n=n)
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == n
|
||||
|
||||
|
||||
def test_google_generativeai_get_num_tokens() -> None:
|
||||
llm = GoogleGenerativeAI()
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
||||
|
||||
async def test_google_generativeai_agenerate() -> None:
|
||||
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
def test_generativeai_stream() -> None:
|
||||
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture
|
||||
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
||||
|
||||
@@ -22,3 +23,16 @@ def test_integration_initialization() -> None:
|
||||
temperature=0.7,
|
||||
candidate_count=2,
|
||||
)
|
||||
|
||||
|
||||
def test_api_key_is_string() -> None:
|
||||
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
|
||||
assert isinstance(chat.google_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
|
||||
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
|
||||
print(chat.google_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Test embeddings model integration."""
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture
|
||||
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
|
||||
|
||||
def test_integration_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
GoogleGenerativeAIEmbeddings(
|
||||
model="models/embedding-001",
|
||||
google_api_key="...",
|
||||
)
|
||||
GoogleGenerativeAIEmbeddings(
|
||||
model="models/embedding-001",
|
||||
google_api_key="...",
|
||||
task_type="retrieval_document",
|
||||
)
|
||||
|
||||
|
||||
def test_api_key_is_string() -> None:
|
||||
embeddings = GoogleGenerativeAIEmbeddings(
|
||||
model="models/embedding-001",
|
||||
google_api_key="secret-api-key",
|
||||
)
|
||||
assert isinstance(embeddings.google_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
|
||||
embeddings = GoogleGenerativeAIEmbeddings(
|
||||
model="models/embedding-001",
|
||||
google_api_key="secret-api-key",
|
||||
)
|
||||
print(embeddings.google_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
@@ -2,6 +2,8 @@ from langchain_google_genai import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
"GoogleGenerativeAIEmbeddings",
|
||||
"GoogleGenerativeAI",
|
||||
]
|
||||
|
||||
|
||||
|
||||
1
libs/partners/nvidia-aiplay/.gitignore
vendored
Normal file
1
libs/partners/nvidia-aiplay/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
__pycache__
|
||||
21
libs/partners/nvidia-aiplay/LICENSE
Normal file
21
libs/partners/nvidia-aiplay/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
62
libs/partners/nvidia-aiplay/Makefile
Normal file
62
libs/partners/nvidia-aiplay/Makefile
Normal file
@@ -0,0 +1,62 @@
|
||||
.PHONY: all format lint test tests integration_tests help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
check_imports: $(shell find langchain_nvidia_aiplay -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
integration_tests:
|
||||
poetry run pytest tests/integration_tests
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_nvidia_aiplay
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/lint_imports.sh
|
||||
poetry run ruff .
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
358
libs/partners/nvidia-aiplay/README.md
Normal file
358
libs/partners/nvidia-aiplay/README.md
Normal file
@@ -0,0 +1,358 @@
|
||||
# langchain-nvidia-aiplay
|
||||
|
||||
The `langchain-nvidia-aiplay` package contains LangChain integrations for chat models and embeddings powered by the NVIDIA AI Playground.
|
||||
|
||||
>[NVIDIA AI Playground](https://www.nvidia.com/en-us/research/ai-playground/) gives users easy access to hosted endpoints for generative AI models like Llama-2, SteerLM, Mistral, etc. Using the API, you can query NVCR (NVIDIA Container Registry) function endpoints and get quick results from a DGX-hosted cloud compute environment. All models are source-accessible and can be deployed on your own compute cluster.
|
||||
|
||||
Below is an example of how to use some common chat model functionality.
|
||||
|
||||
## Installation
|
||||
|
||||
|
||||
```python
|
||||
%pip install -U --quiet langchain-nvidia-aiplay
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
**To get started:**
|
||||
1. Create a free account with the [NVIDIA GPU Cloud](https://catalog.ngc.nvidia.com/) service, which hosts AI solution catalogs, containers, models, etc.
|
||||
2. Navigate to `Catalog > AI Foundation Models > (Model with API endpoint)`.
|
||||
3. Select the `API` option and click `Generate Key`.
|
||||
4. Save the generated key as `NVIDIA_API_KEY`. From there, you should have access to the endpoints.
|
||||
|
||||
|
||||
```python
|
||||
import getpass
|
||||
import os
|
||||
|
||||
if not os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
|
||||
nvidia_api_key = getpass.getpass("Enter your NVIDIA AIPLAY API key: ")
|
||||
assert nvidia_api_key.startswith("nvapi-"), f"{nvidia_api_key[:5]}... is not a valid key"
|
||||
os.environ["NVIDIA_API_KEY"] = nvidia_api_key
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
## Core LC Chat Interface
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
|
||||
llm = ChatNVAIPlay(model="mixtral_8x7b")
|
||||
result = llm.invoke("Write a ballad about LangChain.")
|
||||
print(result.content)
|
||||
```
|
||||
|
||||
|
||||
## Stream, Batch, and Async
|
||||
|
||||
These models natively support streaming, and as is the case with all LangChain LLMs they expose a batch method to handle concurrent requests, as well as async methods for invoke, stream, and batch. Below are a few examples.
|
||||
|
||||
|
||||
```python
|
||||
print(llm.batch(["What's 2*3?", "What's 2*6?"]))
|
||||
# Or via the async API
|
||||
# await llm.abatch(["What's 2*3?", "What's 2*6?"])
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
for chunk in llm.stream("How far can a seagull fly in one day?"):
|
||||
# Show the token separations
|
||||
print(chunk.content, end="|")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
async for chunk in llm.astream("How long does it take for monarch butterflies to migrate?"):
|
||||
print(chunk.content, end="|")
|
||||
```
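
The async entry points mentioned above can also be awaited directly. The following is a short hedged sketch (assuming you are not already inside a running event loop, so `asyncio.run` is usable):

```python
import asyncio

from langchain_nvidia_aiplay import ChatNVAIPlay

llm = ChatNVAIPlay(model="mixtral_8x7b")


async def main() -> None:
    # Single async invocation
    result = await llm.ainvoke("What's 2*3?")
    print(result.content)

    # Concurrent requests via the async batch API
    results = await llm.abatch(["What's 2*3?", "What's 2*6?"])
    print([r.content for r in results])


asyncio.run(main())
```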
|
||||
|
||||
## Supported models
|
||||
|
||||
Querying `available_models` will give you the full list of models offered by your API credentials.
|
||||
|
||||
The `playground_` prefix is optional.
|
||||
|
||||
|
||||
```python
|
||||
list(llm.available_models)
|
||||
|
||||
|
||||
# ['playground_llama2_13b',
|
||||
# 'playground_llama2_code_13b',
|
||||
# 'playground_clip',
|
||||
# 'playground_fuyu_8b',
|
||||
# 'playground_mistral_7b',
|
||||
# 'playground_nvolveqa_40k',
|
||||
# 'playground_yi_34b',
|
||||
# 'playground_nemotron_steerlm_8b',
|
||||
# 'playground_nv_llama2_rlhf_70b',
|
||||
# 'playground_llama2_code_34b',
|
||||
# 'playground_mixtral_8x7b',
|
||||
# 'playground_neva_22b',
|
||||
# 'playground_steerlm_llama_70b',
|
||||
# 'playground_nemotron_qa_8b',
|
||||
# 'playground_sdxl']
|
||||
```
|
||||
|
||||
|
||||
## Model types
|
||||
|
||||
All of the models above are supported and can be accessed via `ChatNVAIPlay`.
|
||||
|
||||
Some model types support unique prompting techniques and chat messages. We will review a few important ones below.
|
||||
|
||||
|
||||
**To find out more about a specific model, please navigate to the API section of an AI Playground model [as linked here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/codellama-13b/api).**
|
||||
|
||||
### General Chat
|
||||
|
||||
Models such as `llama2_13b` and `mixtral_8x7b` are good all-around models that you can use with any LangChain chat messages. An example is below.
|
||||
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful AI assistant named Fred."),
|
||||
("user", "{input}")
|
||||
]
|
||||
)
|
||||
chain = (
|
||||
prompt
|
||||
| ChatNVAIPlay(model="llama2_13b")
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
for txt in chain.stream({"input": "What's your name?"}):
|
||||
print(txt, end="")
|
||||
```
|
||||
|
||||
|
||||
### Code Generation
|
||||
|
||||
These models accept the same arguments and input structure as regular chat models, but they tend to perform better on code-generation and structured code tasks. An example of this is `llama2_code_13b`.
|
||||
|
||||
|
||||
```python
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an expert coding AI. Respond only in valid python; no narration whatsoever."),
|
||||
("user", "{input}")
|
||||
]
|
||||
)
|
||||
chain = (
|
||||
prompt
|
||||
| ChatNVAIPlay(model="llama2_code_13b")
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
for txt in chain.stream({"input": "How do I solve this fizz buzz problem?"}):
|
||||
print(txt, end="")
|
||||
```
|
||||
|
||||
## Steering LLMs
|
||||
|
||||
> [SteerLM-optimized models](https://developer.nvidia.com/blog/announcing-steerlm-a-simple-and-practical-technique-to-customize-llms-during-inference/) support "dynamic steering" of model outputs at inference time.
|
||||
|
||||
This lets you "control" the complexity, verbosity, and creativity of the model via integer labels on a scale from 0 to 9. Under the hood, these are passed as a special type of assistant message to the model.
|
||||
|
||||
The "steer" models support this type of input, such as `steerlm_llama_70b`
|
||||
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
|
||||
llm = ChatNVAIPlay(model="steerlm_llama_70b")
|
||||
# Try making it uncreative and not verbose
|
||||
complex_result = llm.invoke(
|
||||
"What's a PB&J?",
|
||||
labels={"creativity": 0, "complexity": 3, "verbosity": 0}
|
||||
)
|
||||
print("Un-creative\n")
|
||||
print(complex_result.content)
|
||||
|
||||
# Try making it very creative and verbose
|
||||
print("\n\nCreative\n")
|
||||
creative_result = llm.invoke(
|
||||
"What's a PB&J?",
|
||||
labels={"creativity": 9, "complexity": 3, "verbosity": 9}
|
||||
)
|
||||
print(creative_result.content)
|
||||
```
|
||||
|
||||
|
||||
#### Use within LCEL
|
||||
|
||||
The labels are passed as invocation params. You can `bind` these to the LLM using the `bind` method to include them within a declarative, functional chain. Below is an example.
|
||||
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful AI assistant named Fred."),
|
||||
("user", "{input}")
|
||||
]
|
||||
)
|
||||
chain = (
|
||||
prompt
|
||||
| ChatNVAIPlay(model="steerlm_llama_70b").bind(labels={"creativity": 9, "complexity": 0, "verbosity": 9})
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
for txt in chain.stream({"input": "Why is a PB&J?"}):
|
||||
print(txt, end="")
|
||||
```
|
||||
|
||||
## Multimodal
|
||||
|
||||
NVIDIA also supports multimodal inputs, meaning you can provide both images and text for the model to reason over.
|
||||
|
||||
These models also accept `labels`, similar to the Steering LLMs above. In addition to `creativity`, `complexity`, and `verbosity`, these models support a `quality` toggle.
|
||||
|
||||
An example model supporting multimodal inputs is `playground_neva_22b`.
|
||||
|
||||
These models accept LangChain's standard image formats. Below are examples.
|
||||
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
image_url = "https://picsum.photos/seed/kitten/300/200"
|
||||
image_content = requests.get(image_url).content
|
||||
```
|
||||
|
||||
Initialize the model like so:
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
|
||||
llm = ChatNVAIPlay(model="playground_neva_22b")
|
||||
```
|
||||
|
||||
#### Passing an image as a URL
|
||||
|
||||
|
||||
```python
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
llm.invoke(
|
||||
[
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Describe this image:"},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
])
|
||||
])
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
### You can specify the labels for steering here as well. You can try setting a low verbosity, for instance
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
llm.invoke(
|
||||
[
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Describe this image:"},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
])
|
||||
],
|
||||
labels={
|
||||
"creativity": 0,
|
||||
"quality": 9,
|
||||
"complexity": 0,
|
||||
"verbosity": 0
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
#### Passing an image as a base64 encoded string
|
||||
|
||||
|
||||
```python
|
||||
import base64
|
||||
b64_string = base64.b64encode(image_content).decode('utf-8')
|
||||
llm.invoke(
|
||||
[
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Describe this image:"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_string}"}},
|
||||
])
|
||||
])
|
||||
```
|
||||
|
||||
#### Directly within the string
|
||||
|
||||
The NVIDIA API uniquely accepts images as base64-encoded strings inlined within `<img>` HTML tags. While this isn't interoperable with other LLMs, you can directly prompt the model accordingly.
|
||||
|
||||
|
||||
```python
|
||||
base64_with_mime_type = f"data:image/png;base64,{b64_string}"
|
||||
llm.invoke(
|
||||
f'What\'s in this image?\n<img src="{base64_with_mime_type}" />'
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## RAG: Context models
|
||||
|
||||
NVIDIA also has Q&A models that support a special "context" chat message containing retrieved context (such as documents within a RAG chain). This is useful to avoid prompt-injecting the model.
|
||||
|
||||
**Note:** Only "user" (human) and "context" chat messages are supported for these models, not system or AI messages useful in conversational flows.
|
||||
|
||||
The `_qa_` models like `nemotron_qa_8b` support this.
|
||||
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.messages import ChatMessage
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
ChatMessage(role="context", content="Parrots and Cats have signed the peace accord."),
|
||||
("user", "{input}")
|
||||
]
|
||||
)
|
||||
llm = ChatNVAIPlay(model="nemotron_qa_8b")
|
||||
chain = (
|
||||
prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
)
|
||||
chain.invoke({"input": "What was signed?"})
|
||||
```
|
||||
|
||||
## Embeddings
|
||||
|
||||
You can also connect to embeddings models through this package. Below is an example:
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import NVAIPlayEmbeddings
|
||||
|
||||
embedder = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
embedder.embed_query("What's the temperature today?")
|
||||
embedder.embed_documents([
|
||||
"The temperature is 42 degrees.",
|
||||
"Class is dismissed at 9 PM."
|
||||
])
|
||||
```
|
||||
|
||||
By default the embedding model will use the "passage" type for documents and the "query" type for queries, but you can fix the type on the instance.
|
||||
|
||||
```python
|
||||
query_embedder = NVAIPlayEmbeddings(model="nvolveqa_40k", model_type="query")
|
||||
doc_embedder = NVAIPlayEmbeddings(model="nvolveqa_40k", model_type="passage")
|
||||
```
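
As a rough, illustrative sketch (not part of the package), the query/passage embedders above can be combined with a simple cosine similarity to rank documents against a query; `numpy` is assumed to be installed.

```python
import numpy as np

from langchain_nvidia_aiplay import NVAIPlayEmbeddings

query_embedder = NVAIPlayEmbeddings(model="nvolveqa_40k", model_type="query")
doc_embedder = NVAIPlayEmbeddings(model="nvolveqa_40k", model_type="passage")

docs = ["The temperature is 42 degrees.", "Class is dismissed at 9 PM."]
doc_vectors = np.array(doc_embedder.embed_documents(docs))
query_vector = np.array(query_embedder.embed_query("How hot is it today?"))

# Cosine similarity of the query against each document; higher means more relevant.
scores = doc_vectors @ query_vector / (
    np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(query_vector)
)
print(docs[int(np.argmax(scores))])
```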
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
**LangChain NVIDIA AI Playground Integration**
|
||||
|
||||
This comprehensive module integrates NVIDIA's state-of-the-art AI Playground, featuring advanced models for conversational AI and semantic embeddings, into the LangChain framework. It provides robust classes for seamless interaction with NVIDIA's AI models, particularly tailored for enriching conversational experiences and enhancing semantic understanding in various applications.
|
||||
|
||||
**Features:**
|
||||
|
||||
1. **Chat Models (`ChatNVAIPlay`):** This class serves as the primary interface for interacting with NVIDIA AI Playground's chat models. Users can effortlessly utilize NVIDIA's advanced models like 'Mistral' to engage in rich, context-aware conversations, applicable across diverse domains from customer support to interactive storytelling.
|
||||
|
||||
2. **Semantic Embeddings (`NVAIPlayEmbeddings`):** The module offers capabilities to generate sophisticated embeddings using NVIDIA's AI models. These embeddings are instrumental for tasks like semantic analysis, text similarity assessments, and contextual understanding, significantly enhancing the depth of NLP applications.
|
||||
|
||||
**Installation:**
|
||||
|
||||
Install this module easily using pip:
|
||||
|
||||
```python
|
||||
pip install langchain-nvidia-aiplay
|
||||
```
|
||||
|
||||
## Utilizing Chat Models:
|
||||
|
||||
After setting up the environment, interact with NVIDIA AI Playground models:
|
||||
```python
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
|
||||
ai_chat_model = ChatNVAIPlay(model="llama2_13b")
|
||||
response = ai_chat_model.invoke("Tell me about the LangChain integration.")
|
||||
```
|
||||
|
||||
## Generating Semantic Embeddings:
|
||||
|
||||
Use NVIDIA's models for creating embeddings, useful in various NLP tasks:
|
||||
|
||||
```python
|
||||
from langchain_nvidia_aiplay import NVAIPlayEmbeddings
|
||||
|
||||
embed_model = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
embedding_output = embed_model.embed_query("Exploring AI capabilities.")
|
||||
```
|
||||
""" # noqa: E501
|
||||
|
||||
from langchain_nvidia_aiplay.chat_models import ChatNVAIPlay
|
||||
from langchain_nvidia_aiplay.embeddings import NVAIPlayEmbeddings
|
||||
|
||||
__all__ = ["ChatNVAIPlay", "NVAIPlayEmbeddings"]
|
||||
525
libs/partners/nvidia-aiplay/langchain_nvidia_aiplay/_common.py
Normal file
525
libs/partners/nvidia-aiplay/langchain_nvidia_aiplay/_common.py
Normal file
@@ -0,0 +1,525 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from requests.models import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVCRModel(BaseModel):
|
||||
|
||||
"""
|
||||
Underlying Client for interacting with the AI Playground API.
|
||||
Leveraged by the NVAIPlayBaseModel to provide a simple requests-oriented interface.
|
||||
Direct abstraction over NGC-recommended streaming/non-streaming Python solutions.
|
||||
|
||||
NOTE: AI Playground does not currently support raw text continuation.
|
||||
"""
|
||||
|
||||
## Core defaults. These probably should not be changed
|
||||
fetch_url_format: str = Field("https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/")
|
||||
call_invoke_base: str = Field("https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions")
|
||||
get_session_fn: Callable = Field(requests.Session)
|
||||
get_asession_fn: Callable = Field(aiohttp.ClientSession)
|
||||
|
||||
nvidia_api_key: SecretStr = Field(
|
||||
...,
|
||||
description="API key for NVIDIA AI Playground. Should start with `nvapi-`",
|
||||
)
|
||||
is_staging: bool = Field(False, description="Whether to use staging API")
|
||||
|
||||
## Generation arguments
|
||||
max_tries: int = Field(5, ge=1)
|
||||
headers_tmpl: dict = Field(
|
||||
...,
|
||||
description="Headers template for API calls."
|
||||
" Should contain `call` and `stream` keys.",
|
||||
)
|
||||
_available_functions: Optional[List[dict]] = PrivateAttr(default=None)
|
||||
_available_models: Optional[dict] = PrivateAttr(default=None)
|
||||
|
||||
@property
|
||||
def headers(self) -> dict:
|
||||
"""Return headers with API key injected"""
|
||||
headers_ = self.headers_tmpl.copy()
|
||||
for header in headers_.values():
|
||||
if "{nvidia_api_key}" in header["Authorization"]:
|
||||
header["Authorization"] = header["Authorization"].format(
|
||||
nvidia_api_key=self.nvidia_api_key.get_secret_value(),
|
||||
)
|
||||
return headers_
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_model(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate and update model arguments, including API key and formatting"""
|
||||
values["nvidia_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"nvidia_api_key",
|
||||
"NVIDIA_API_KEY",
|
||||
)
|
||||
if "nvapi-" not in values.get("nvidia_api_key", ""):
|
||||
raise ValueError("Invalid NVAPI key detected. Should start with `nvapi-`")
|
||||
is_staging = "nvapi-stg-" in values["nvidia_api_key"]
|
||||
values["is_staging"] = is_staging
|
||||
if "headers_tmpl" not in values:
|
||||
values["headers_tmpl"] = {
|
||||
"call": {
|
||||
"Authorization": "Bearer {nvidia_api_key}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
"stream": {
|
||||
"Authorization": "Bearer {nvidia_api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
values["fetch_url_format"] = cls._stagify(
|
||||
is_staging,
|
||||
values.get(
|
||||
"fetch_url_format", "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/"
|
||||
),
|
||||
)
|
||||
values["call_invoke_base"] = cls._stagify(
|
||||
is_staging,
|
||||
values.get(
|
||||
"call_invoke_base",
|
||||
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions",
|
||||
),
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def available_models(self) -> dict:
|
||||
"""List the available models that can be invoked."""
|
||||
if self._available_models is not None:
|
||||
return self._available_models
|
||||
live_fns = [v for v in self.available_functions if v.get("status") == "ACTIVE"]
|
||||
self._available_models = {v["name"]: v["id"] for v in live_fns}
|
||||
return self._available_models
|
||||
|
||||
@property
|
||||
def available_functions(self) -> List[dict]:
|
||||
"""List the available functions that can be invoked."""
|
||||
if self._available_functions is not None:
|
||||
return self._available_functions
|
||||
invoke_url = self._stagify(
|
||||
self.is_staging, "https://api.nvcf.nvidia.com/v2/nvcf/functions"
|
||||
)
|
||||
query_res = self.query(invoke_url)
|
||||
if "functions" not in query_res:
|
||||
raise ValueError(
|
||||
f"Unexpected response when querying {invoke_url}\n{query_res}"
|
||||
)
|
||||
self._available_functions = query_res["functions"]
|
||||
return self._available_functions
|
||||
|
||||
@classmethod
|
||||
def _stagify(cls, is_staging: bool, path: str) -> str:
|
||||
"""Helper method to switch between staging and production endpoints"""
|
||||
if is_staging and "stg.api" not in path:
|
||||
return path.replace("api.", "stg.api.")
|
||||
if not is_staging and "stg.api" in path:
|
||||
return path.replace("stg.api.", "api.")
|
||||
return path
|
||||
|
||||
####################################################################################
|
||||
## Core utilities for posting and getting from NVCR
|
||||
|
||||
def _post(self, invoke_url: str, payload: dict = {}) -> Tuple[Response, Any]:
|
||||
"""Method for posting to the AI Playground API."""
|
||||
call_inputs = {
|
||||
"url": invoke_url,
|
||||
"headers": self.headers["call"],
|
||||
"json": payload,
|
||||
"stream": False,
|
||||
}
|
||||
session = self.get_session_fn()
|
||||
response = session.post(**call_inputs)
|
||||
self._try_raise(response)
|
||||
return response, session
|
||||
|
||||
def _get(self, invoke_url: str, payload: dict = {}) -> Tuple[Response, Any]:
|
||||
"""Method for getting from the AI Playground API."""
|
||||
last_inputs = {
|
||||
"url": invoke_url,
|
||||
"headers": self.headers["call"],
|
||||
"json": payload,
|
||||
"stream": False,
|
||||
}
|
||||
session = self.get_session_fn()
|
||||
last_response = session.get(**last_inputs)
|
||||
self._try_raise(last_response)
|
||||
return last_response, session
|
||||
|
||||
def _wait(self, response: Response, session: Any) -> Response:
|
||||
"""Wait for a response from API after an initial response is made."""
|
||||
i = 1
|
||||
while response.status_code == 202:
|
||||
request_id = response.headers.get("NVCF-REQID", "")
|
||||
response = session.get(
|
||||
self.fetch_url_format + request_id,
|
||||
headers=self.headers["call"],
|
||||
)
|
||||
if response.status_code == 202:
|
||||
try:
|
||||
body = response.json()
|
||||
except ValueError:
|
||||
body = str(response)
|
||||
i += 1
if i > self.max_tries:
|
||||
raise ValueError(f"Failed to get response with {i} tries: {body}")
|
||||
self._try_raise(response)
|
||||
return response
|
||||
|
||||
def _try_raise(self, response: Response) -> None:
|
||||
"""Try to raise an error from a response"""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
try:
|
||||
rd = response.json()
|
||||
except json.JSONDecodeError:
|
||||
rd = response.__dict__
|
||||
rd = rd.get("_content", rd)
|
||||
if isinstance(rd, bytes):
|
||||
rd = rd.decode("utf-8")[5:] ## lop of data: prefix ??
|
||||
try:
|
||||
rd = json.loads(rd)
|
||||
except Exception:
|
||||
rd = {"detail": rd}
|
||||
title = f"[{rd.get('status', '###')}] {rd.get('title', 'Unknown Error')}"
|
||||
body = f"{rd.get('detail', rd.get('type', rd))}"
|
||||
raise Exception(f"{title}\n{body}") from e
|
||||
|
||||
####################################################################################
|
||||
## Simple query interface to show the set of model options
|
||||
|
||||
def query(self, invoke_url: str, payload: dict = {}) -> dict:
|
||||
"""Simple method for an end-to-end get query. Returns result dictionary"""
|
||||
response, session = self._get(invoke_url, payload)
|
||||
response = self._wait(response, session)
|
||||
output = self._process_response(response)[0]
|
||||
return output
|
||||
|
||||
def _process_response(self, response: Union[str, Response]) -> List[dict]:
|
||||
"""General-purpose response processing for single responses and streams"""
|
||||
if hasattr(response, "json"): ## For single response (i.e. non-streaming)
|
||||
try:
|
||||
return [response.json()]
|
||||
except json.JSONDecodeError:
|
||||
response = str(response.__dict__)
|
||||
if isinstance(response, str): ## For set of responses (i.e. streaming)
|
||||
msg_list = []
|
||||
for msg in response.split("\n\n"):
|
||||
if "{" not in msg:
|
||||
continue
|
||||
msg_list += [json.loads(msg[msg.find("{") :])]
|
||||
return msg_list
|
||||
raise ValueError(f"Received ill-formed response: {response}")
|
||||
|
||||
def _get_invoke_url(
|
||||
self, model_name: Optional[str] = None, invoke_url: Optional[str] = None
|
||||
) -> str:
|
||||
"""Helper method to get invoke URL from a model name, URL, or endpoint stub"""
|
||||
if not invoke_url:
|
||||
if not model_name:
|
||||
raise ValueError("URL or model name must be specified to invoke")
|
||||
if model_name in self.available_models:
|
||||
invoke_url = self.available_models[model_name]
|
||||
elif f"playground_{model_name}" in self.available_models:
|
||||
invoke_url = self.available_models[f"playground_{model_name}"]
|
||||
else:
|
||||
available_models_str = "\n".join(
|
||||
[f"{k} - {v}" for k, v in self.available_models.items()]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown model name {model_name} specified."
|
||||
"\nAvailable models are:\n"
|
||||
f"{available_models_str}"
|
||||
)
|
||||
if not invoke_url:
|
||||
# For mypy
|
||||
raise ValueError("URL or model name must be specified to invoke")
|
||||
# Why is this even needed?
|
||||
if "http" not in invoke_url:
|
||||
invoke_url = f"{self.call_invoke_base}/{invoke_url}"
|
||||
return invoke_url
|
||||
|
||||
####################################################################################
|
||||
## Generation interface to allow users to generate new values from endpoints
|
||||
|
||||
def get_req(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
payload: dict = {},
|
||||
invoke_url: Optional[str] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Response:
|
||||
"""Post to the API."""
|
||||
invoke_url = self._get_invoke_url(model_name, invoke_url)
|
||||
if payload.get("stream", False) is True:
|
||||
payload = {**payload, "stream": False}
|
||||
response, session = self._post(invoke_url, payload)
|
||||
return self._wait(response, session)
|
||||
|
||||
def get_req_generation(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
payload: dict = {},
|
||||
invoke_url: Optional[str] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> dict:
|
||||
"""Method for an end-to-end post query with NVCR post-processing."""
|
||||
response = self.get_req(model_name, payload, invoke_url)
|
||||
output, _ = self.postprocess(response, stop=stop)
|
||||
return output
|
||||
|
||||
def postprocess(
|
||||
self, response: Union[str, Response], stop: Optional[Sequence[str]] = None
|
||||
) -> Tuple[dict, bool]:
|
||||
"""Parses a response from the AI Playground API.
|
||||
Strongly assumes that the API will return a single response.
|
||||
"""
|
||||
msg_list = self._process_response(response)
|
||||
msg, is_stopped = self._aggregate_msgs(msg_list)
|
||||
msg, is_stopped = self._early_stop_msg(msg, is_stopped, stop=stop)
|
||||
return msg, is_stopped
|
||||
|
||||
def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
|
||||
"""Dig out relevant details of aggregated message"""
|
||||
content_buffer: Dict[str, Any] = dict()
|
||||
content_holder: Dict[Any, Any] = dict()
|
||||
is_stopped = False
|
||||
for msg in msg_list:
|
||||
if "choices" in msg:
|
||||
## Tease out ['choices'][0]...['delta'/'message']
|
||||
msg = msg.get("choices", [{}])[0]
|
||||
is_stopped = msg.get("finish_reason", "") == "stop"
|
||||
msg = msg.get("delta", msg.get("message", {"content": ""}))
|
||||
elif "data" in msg:
|
||||
## Tease out ['data'][0]...['embedding']
|
||||
msg = msg.get("data", [{}])[0]
|
||||
content_holder = msg
|
||||
for k, v in msg.items():
|
||||
if k in ("content",) and k in content_buffer:
|
||||
content_buffer[k] += v
|
||||
else:
|
||||
content_buffer[k] = v
|
||||
if is_stopped:
|
||||
break
|
||||
content_holder = {**content_holder, **content_buffer}
|
||||
return content_holder, is_stopped
|
||||
|
||||
def _early_stop_msg(
|
||||
self, msg: dict, is_stopped: bool, stop: Optional[Sequence[str]] = None
|
||||
) -> Tuple[dict, bool]:
|
||||
"""Try to early-terminate streaming or generation by iterating over stop list"""
|
||||
content = msg.get("content", "")
|
||||
if content and stop:
|
||||
for stop_str in stop:
|
||||
if stop_str and stop_str in content:
|
||||
msg["content"] = content[: content.find(stop_str) + 1]
|
||||
is_stopped = True
|
||||
return msg, is_stopped
|
||||
|
||||
####################################################################################
|
||||
## Streaming interface to allow you to iterate through progressive generations
|
||||
|
||||
def get_req_stream(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
payload: dict = {},
|
||||
invoke_url: Optional[str] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Iterator:
|
||||
invoke_url = self._get_invoke_url(model, invoke_url)
|
||||
if payload.get("stream", True) is False:
|
||||
payload = {**payload, "stream": True}
|
||||
last_inputs = {
|
||||
"url": invoke_url,
|
||||
"headers": self.headers["stream"],
|
||||
"json": payload,
|
||||
"stream": True,
|
||||
}
|
||||
response = self.get_session_fn().post(**last_inputs)
|
||||
self._try_raise(response)
|
||||
call = self.copy()
|
||||
|
||||
def out_gen() -> Generator[dict, Any, Any]:
|
||||
## Good for client, since it allows self.last_input
|
||||
for line in response.iter_lines():
|
||||
if line and line.strip() != b"data: [DONE]":
|
||||
line = line.decode("utf-8")
|
||||
msg, final_line = call.postprocess(line, stop=stop)
|
||||
yield msg
|
||||
if final_line:
|
||||
break
|
||||
self._try_raise(response)
|
||||
|
||||
return (r for r in out_gen())
|
||||
|
||||
####################################################################################
|
||||
## Asynchronous streaming interface to allow multiple generations to happen at once.
|
||||
|
||||
async def get_req_astream(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
payload: dict = {},
|
||||
invoke_url: Optional[str] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> AsyncIterator:
|
||||
invoke_url = self._get_invoke_url(model, invoke_url)
|
||||
if payload.get("stream", True) is False:
|
||||
payload = {**payload, "stream": True}
|
||||
last_inputs = {
|
||||
"url": invoke_url,
|
||||
"headers": self.headers["stream"],
|
||||
"json": payload,
|
||||
}
|
||||
async with self.get_asession_fn() as session:
|
||||
async with session.post(**last_inputs) as response:
|
||||
self._try_raise(response)
|
||||
async for line in response.content.iter_any():
|
||||
if line and line.strip() != b"data: [DONE]":
|
||||
line = line.decode("utf-8")
|
||||
msg, final_line = self.postprocess(line, stop=stop)
|
||||
yield msg
|
||||
if final_line:
|
||||
break
|
||||
|
||||
|
||||
class _NVAIPlayClient(BaseModel):
|
||||
"""
|
||||
Higher-Level Client for interacting with AI Playground API with argument defaults.
|
||||
Is subclassed by ChatNVAIPlay to provide a simple LangChain interface.
|
||||
"""
|
||||
|
||||
client: NVCRModel = Field(NVCRModel)
|
||||
|
||||
model: str = Field(..., description="Name of the model to invoke")
|
||||
|
||||
temperature: float = Field(0.2, le=1.0, gt=0.0)
|
||||
top_p: float = Field(0.7, le=1.0, ge=0.0)
|
||||
max_tokens: int = Field(1024, le=1024, ge=32)
|
||||
|
||||
####################################################################################
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_client(cls, values: Any) -> Any:
|
||||
"""Validate and update client arguments, including API key and formatting"""
|
||||
if not values.get("client"):
|
||||
values["client"] = NVCRModel(**values)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def available_functions(self) -> List[dict]:
|
||||
"""Map the available functions that can be invoked."""
|
||||
return self.client.available_functions
|
||||
|
||||
@property
|
||||
def available_models(self) -> dict:
|
||||
"""Map the available models that can be invoked."""
|
||||
return self.client.available_models
|
||||
|
||||
def get_model_details(self, model: Optional[str] = None) -> dict:
|
||||
"""Get more meta-details about a model retrieved by a given name"""
|
||||
if model is None:
|
||||
model = self.model
|
||||
model_key = self.client._get_invoke_url(model).split("/")[-1]
|
||||
known_fns = self.client.available_functions
|
||||
fn_spec = [f for f in known_fns if f.get("id") == model_key][0]
|
||||
return fn_spec
|
||||
|
||||
def get_generation(
|
||||
self,
|
||||
inputs: Sequence[Dict],
|
||||
labels: Optional[dict] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Call to client generate method with call scope"""
|
||||
payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
|
||||
out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
|
||||
return out
|
||||
|
||||
def get_stream(
|
||||
self,
|
||||
inputs: Sequence[Dict],
|
||||
labels: Optional[dict] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator:
|
||||
"""Call to client stream method with call scope"""
|
||||
payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
|
||||
return self.client.get_req_stream(self.model, stop=stop, payload=payload)
|
||||
|
||||
def get_astream(
|
||||
self,
|
||||
inputs: Sequence[Dict],
|
||||
labels: Optional[dict] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator:
|
||||
"""Call to client astream methods with call scope"""
|
||||
payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
|
||||
return self.client.get_req_astream(self.model, stop=stop, payload=payload)
|
||||
|
||||
def get_payload(
|
||||
self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
|
||||
) -> dict:
|
||||
"""Generates payload for the _NVAIPlayClient API to send to service."""
|
||||
return {
|
||||
**self.preprocess(inputs=inputs, labels=labels),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
|
||||
"""Prepares a message or list of messages for the payload"""
|
||||
messages = [self.prep_msg(m) for m in inputs]
|
||||
if labels:
|
||||
# (WFH) Labels are currently (?) always passed as an assistant
|
||||
# suffix message, but this API seems less stable.
|
||||
messages += [{"labels": labels, "role": "assistant"}]
|
||||
return {"messages": messages}
|
||||
|
||||
def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
|
||||
"""Helper Method: Ensures a message is a dictionary with a role and content."""
|
||||
if isinstance(msg, str):
|
||||
# (WFH) this shouldn't ever be reached but leaving this here bcs
|
||||
# it's a Chesterton's fence I'm unwilling to touch
|
||||
return dict(role="user", content=msg)
|
||||
if isinstance(msg, dict):
|
||||
if msg.get("content", None) is None:
|
||||
raise ValueError(f"Message {msg} has no content")
|
||||
return msg
|
||||
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
|
||||
@@ -0,0 +1,207 @@
|
||||
"""Chat Model Components Derived from ChatModel/NVAIPlay"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import SimpleChatModel
|
||||
from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
from langchain_nvidia_aiplay import _common as nv_aiplay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_openai_parts_format(part: dict) -> bool:
|
||||
return "type" in part
|
||||
|
||||
|
||||
def _is_url(s: str) -> bool:
|
||||
try:
|
||||
result = urllib.parse.urlparse(s)
|
||||
return all([result.scheme, result.netloc])
|
||||
except Exception as e:
|
||||
logger.debug(f"Unable to parse URL: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_b64(s: str) -> bool:
|
||||
return s.startswith("data:image")
|
||||
|
||||
|
||||
def _url_to_b64_string(image_source: str) -> str:
|
||||
b64_template = "data:image/png;base64,{b64_string}"
|
||||
try:
|
||||
if _is_url(image_source):
|
||||
response = requests.get(image_source)
|
||||
response.raise_for_status()
|
||||
encoded = base64.b64encode(response.content).decode("utf-8")
|
||||
return b64_template.format(b64_string=encoded)
|
||||
elif _is_b64(image_source):
|
||||
return image_source
|
||||
elif os.path.exists(image_source):
|
||||
with open(image_source, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("utf-8")
|
||||
return b64_template.format(b64_string=encoded)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The provided string is not a valid URL, base64, or file path."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to process the provided image source: {e}")
|
||||
|
||||
|
||||
class ChatNVAIPlay(nv_aiplay._NVAIPlayClient, SimpleChatModel):
|
||||
"""NVAIPlay chat model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_nvidia_aiplay import ChatNVAIPlay
|
||||
|
||||
|
||||
model = ChatNVAIPlay(model="llama2_13b")
|
||||
response = model.invoke("Hello")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of NVIDIA AI Playground Interface."""
|
||||
return "chat-nvidia-ai-playground"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
labels: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Invoke on a single list of chat messages."""
|
||||
inputs = self.custom_preprocess(messages)
|
||||
responses = self.get_generation(
|
||||
inputs=inputs, stop=stop, labels=labels, **kwargs
|
||||
)
|
||||
outputs = self.custom_postprocess(responses)
|
||||
return outputs
|
||||
|
||||
def _get_filled_chunk(
|
||||
self, text: str, role: Optional[str] = "assistant"
|
||||
) -> ChatGenerationChunk:
|
||||
"""Fill the generation chunk."""
|
||||
return ChatGenerationChunk(message=ChatMessageChunk(content=text, role=role))
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
labels: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Allows streaming to model!"""
|
||||
inputs = self.custom_preprocess(messages)
|
||||
for response in self.get_stream(
|
||||
inputs=inputs, stop=stop, labels=labels, **kwargs
|
||||
):
|
||||
chunk = self._get_filled_chunk(self.custom_postprocess(response))
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
labels: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
inputs = self.custom_preprocess(messages)
|
||||
async for response in self.get_astream(
|
||||
inputs=inputs, stop=stop, labels=labels, **kwargs
|
||||
):
|
||||
chunk = self._get_filled_chunk(self.custom_postprocess(response))
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
def custom_preprocess(
|
||||
self, msg_list: Sequence[BaseMessage]
|
||||
) -> List[Dict[str, str]]:
|
||||
# The previous author had a lot of custom preprocessing here
|
||||
# but I'm just going to assume it's a list
|
||||
return [self.preprocess_msg(m) for m in msg_list]
|
||||
|
||||
def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
string_array: list = []
|
||||
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
string_array.append(part)
|
||||
elif isinstance(part, Mapping):
|
||||
# OpenAI Format
|
||||
if _is_openai_parts_format(part):
|
||||
if part["type"] == "text":
|
||||
string_array.append(str(part["text"]))
|
||||
elif part["type"] == "image_url":
|
||||
img_url = part["image_url"]
|
||||
if isinstance(img_url, dict):
|
||||
if "url" not in img_url:
|
||||
raise ValueError(
|
||||
f"Unrecognized message image format: {img_url}"
|
||||
)
|
||||
img_url = img_url["url"]
|
||||
b64_string = _url_to_b64_string(img_url)
|
||||
string_array.append(f'<img src="{b64_string}" />')
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized message part type: {part['type']}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized message part format: {part}")
|
||||
return "".join(string_array)
|
||||
|
||||
def preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:
|
||||
## (WFH): Previous author added a bunch of
|
||||
# custom processing here, but I'm just going to support
|
||||
# the LCEL api.
|
||||
if isinstance(msg, BaseMessage):
|
||||
role_convert = {"ai": "assistant", "human": "user"}
|
||||
if isinstance(msg, ChatMessage):
|
||||
role = msg.role
|
||||
else:
|
||||
role = msg.type
|
||||
role = role_convert.get(role, role)
|
||||
content = self._process_content(msg.content)
|
||||
return {"role": role, "content": content}
|
||||
raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")
|
||||
|
||||
def custom_postprocess(self, msg: dict) -> str:
|
||||
if "content" in msg:
|
||||
return msg["content"]
|
||||
logger.warning(
|
||||
f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
|
||||
)
|
||||
return str(msg)
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Embeddings Components Derived from ChatModel/NVAIPlay"""
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
import langchain_nvidia_aiplay._common as nvaiplay_common
|
||||
|
||||
|
||||
class NVAIPlayEmbeddings(BaseModel, Embeddings):
|
||||
"""NVIDIA's AI Playground NVOLVE Question-Answer Asymmetric Model."""
|
||||
|
||||
client: nvaiplay_common.NVCRModel = Field(nvaiplay_common.NVCRModel)
|
||||
model: str = Field(
|
||||
..., description="The embedding model to use. Example: nvolveqa_40k"
|
||||
)
|
||||
max_length: int = Field(2048, ge=1, le=2048)
|
||||
max_batch_size: int = Field(default=50)
|
||||
model_type: Optional[Literal["passage", "query"]] = Field(
|
||||
"passage", description="The type of text to be embedded."
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def _validate_client(cls, values: Any) -> Any:
|
||||
if "client" not in values:
|
||||
values["client"] = nvaiplay_common.NVCRModel()
|
||||
return values
|
||||
|
||||
@property
|
||||
def available_models(self) -> dict:
|
||||
"""Map the available models that can be invoked."""
|
||||
return self.client.available_models
|
||||
|
||||
def _embed(
|
||||
self, texts: List[str], model_type: Literal["passage", "query"]
|
||||
) -> List[List[float]]:
|
||||
"""Embed a single text entry to either passage or query type"""
|
||||
response = self.client.get_req(
|
||||
model_name=self.model,
|
||||
payload={
|
||||
"input": texts,
|
||||
"model": model_type,
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
data = result["data"]
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected a list of embeddings. Got: {data}")
|
||||
embedding_list = [(res["embedding"], res["index"]) for res in data]
|
||||
return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Input pathway for query embeddings."""
|
||||
return self._embed([text], model_type=self.model_type or "query")[0]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Input pathway for document embeddings."""
|
||||
# From https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/nvolve-40k/documentation
|
||||
# The input must not exceed the 2048 max input characters and inputs above 512
|
||||
# model tokens will be truncated. The input array must not exceed 50 input
|
||||
# strings.
|
||||
all_embeddings = []
|
||||
for i in range(0, len(texts), self.max_batch_size):
|
||||
batch = texts[i : i + self.max_batch_size]
|
||||
truncated = [
|
||||
text[: self.max_length] if len(text) > self.max_length else text
|
||||
for text in batch
|
||||
]
|
||||
all_embeddings.extend(
|
||||
self._embed(truncated, model_type=self.model_type or "passage")
|
||||
)
|
||||
return all_embeddings
|
||||
1235
libs/partners/nvidia-aiplay/poetry.lock
generated
Normal file
1235
libs/partners/nvidia-aiplay/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
92
libs/partners/nvidia-aiplay/pyproject.toml
Normal file
92
libs/partners/nvidia-aiplay/pyproject.toml
Normal file
@@ -0,0 +1,92 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-nvidia-aiplay"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting NVidia AIPlay and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/nvidia-aiplay"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.0"
|
||||
aiohttp = "^3.9.1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
types-requests = "^2.31.0.10"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
17
libs/partners/nvidia-aiplay/scripts/check_imports.py
Normal file
17
libs/partners/nvidia-aiplay/scripts/check_imports.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
||||
27
libs/partners/nvidia-aiplay/scripts/check_pydantic.sh
Executable file
27
libs/partners/nvidia-aiplay/scripts/check_pydantic.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
||||
17
libs/partners/nvidia-aiplay/scripts/lint_imports.sh
Executable file
17
libs/partners/nvidia-aiplay/scripts/lint_imports.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
||||
0
libs/partners/nvidia-aiplay/tests/__init__.py
Normal file
0
libs/partners/nvidia-aiplay/tests/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Test ChatNVAIPlay chat model."""
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_nvidia_aiplay.chat_models import ChatNVAIPlay
|
||||
|
||||
|
||||
def test_chat_aiplay() -> None:
|
||||
"""Test ChatNVAIPlay wrapper."""
|
||||
chat = ChatNVAIPlay(
|
||||
model="llama2_13b",
|
||||
temperature=0.7,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_aiplay_model() -> None:
|
||||
"""Test GeneralChat wrapper handles model."""
|
||||
chat = ChatNVAIPlay(model="mistral")
|
||||
assert chat.model == "mistral"
|
||||
|
||||
|
||||
def test_chat_aiplay_system_message() -> None:
|
||||
"""Test GeneralChat wrapper with system message."""
|
||||
chat = ChatNVAIPlay(model="llama2_13b", max_tokens=36)
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
human_message = HumanMessage(content="Hello")
|
||||
response = chat([system_message, human_message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
## TODO: Not sure if we want to support the n syntax. Trash or keep test
|
||||
|
||||
|
||||
def test_aiplay_streaming() -> None:
|
||||
"""Test streaming tokens from aiplay."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=36)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_aiplay_astream() -> None:
|
||||
"""Test streaming tokens from aiplay."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=35)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_aiplay_abatch() -> None:
|
||||
"""Test streaming tokens from GeneralChat."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=36)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_aiplay_abatch_tags() -> None:
|
||||
"""Test batch tokens from GeneralChat."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=55)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_aiplay_batch() -> None:
|
||||
"""Test batch tokens from GeneralChat."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=60)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_aiplay_ainvoke() -> None:
|
||||
"""Test invoke tokens from GeneralChat."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=60)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_aiplay_invoke() -> None:
|
||||
"""Test invoke tokens from GeneralChat."""
|
||||
llm = ChatNVAIPlay(model="llama2_13b", max_tokens=60)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
@@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Test NVIDIA AI Playground Embeddings.
|
||||
|
||||
Note: These tests are designed to validate the functionality of NVAIPlayEmbeddings.
|
||||
"""
|
||||
from langchain_nvidia_aiplay import NVAIPlayEmbeddings
|
||||
|
||||
|
||||
def test_nvai_play_embedding_documents() -> None:
|
||||
"""Test NVAIPlay embeddings for documents."""
|
||||
documents = ["foo bar"]
|
||||
embedding = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 1024  # Assuming embedding size is 1024
|
||||
|
||||
|
||||
def test_nvai_play_embedding_documents_multiple() -> None:
|
||||
"""Test NVAIPlay embeddings for multiple documents."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert all(len(doc) == 1024 for doc in output)
|
||||
|
||||
|
||||
def test_nvai_play_embedding_query() -> None:
|
||||
"""Test NVAIPlay embeddings for a single query."""
|
||||
query = "foo bar"
|
||||
embedding = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
output = embedding.embed_query(query)
|
||||
assert len(output) == 1024
|
||||
|
||||
|
||||
async def test_nvai_play_embedding_async_query() -> None:
|
||||
"""Test NVAIPlay async embeddings for a single query."""
|
||||
query = "foo bar"
|
||||
embedding = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
output = await embedding.aembed_query(query)
|
||||
assert len(output) == 1024
|
||||
|
||||
|
||||
async def test_nvai_play_embedding_async_documents() -> None:
|
||||
"""Test NVAIPlay async embeddings for multiple documents."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = NVAIPlayEmbeddings(model="nvolveqa_40k")
|
||||
output = await embedding.aembed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert all(len(doc) == 1024 for doc in output)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
|
||||
from langchain_nvidia_aiplay.chat_models import ChatNVAIPlay
|
||||
|
||||
|
||||
def test_integration_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatNVAIPlay(
|
||||
model="llama2_13b",
|
||||
nvidia_api_key="nvapi-...",
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
max_tokens=50,
|
||||
)
|
||||
ChatNVAIPlay(model="mistral", nvidia_api_key="nvapi-...")
|
||||
@@ -0,0 +1,7 @@
|
||||
from langchain_nvidia_aiplay import __all__
|
||||
|
||||
EXPECTED_ALL = ["ChatNVAIPlay", "NVAIPlayEmbeddings"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
||||
10
poetry.lock
generated
10
poetry.lock
generated
@@ -2119,13 +2119,13 @@ test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "p
|
||||
|
||||
[[package]]
|
||||
name = "nbconvert"
|
||||
version = "7.8.0"
|
||||
version = "7.12.0"
|
||||
description = "Converting Jupyter Notebooks"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "nbconvert-7.8.0-py3-none-any.whl", hash = "sha256:aec605e051fa682ccc7934ccc338ba1e8b626cfadbab0db592106b630f63f0f2"},
|
||||
{file = "nbconvert-7.8.0.tar.gz", hash = "sha256:f5bc15a1247e14dd41ceef0c0a3bc70020e016576eb0578da62f1c5b4f950479"},
|
||||
{file = "nbconvert-7.12.0-py3-none-any.whl", hash = "sha256:5b6c848194d270cc55fb691169202620d7b52a12fec259508d142ecbe4219310"},
|
||||
{file = "nbconvert-7.12.0.tar.gz", hash = "sha256:b1564bd89f69a74cd6398b0362da94db07aafb991b7857216a766204a71612c0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2152,7 +2152,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp
|
||||
qtpdf = ["nbconvert[qtpng]"]
|
||||
qtpng = ["pyqtwebengine (>=5.15)"]
|
||||
serve = ["tornado (>=6.1)"]
|
||||
test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pre-commit", "pytest", "pytest-dependency"]
|
||||
test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest"]
|
||||
webpdf = ["playwright"]
|
||||
|
||||
[[package]]
|
||||
@@ -3971,4 +3971,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "581c178796dbb76589632e687d353a336ca23b3cdda7075720660b479dc85fa2"
|
||||
content-hash = "01838f8ac4fb3d5ac59517aa5b24e55f7167736fba952ebaa1216991b3972512"
|
||||
|
||||
@@ -27,6 +27,7 @@ myst-nb = "^0.17.1"
|
||||
linkchecker = "^10.2.1"
|
||||
sphinx-copybutton = "^0.5.1"
|
||||
nbdoc = "^0.0.82"
|
||||
nbconvert = "^7.12.0"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
3
templates/propositional-retrieval/.gitignore
vendored
Normal file
3
templates/propositional-retrieval/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
docs/img_*.jpg
|
||||
chroma_db_proposals
|
||||
multi_vector_retriever_metadata
|
||||
21
templates/propositional-retrieval/LICENSE
Normal file
21
templates/propositional-retrieval/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
81
templates/propositional-retrieval/README.md
Normal file
81
templates/propositional-retrieval/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# propositional-retrieval
|
||||
|
||||
This template demonstrates the multi-vector indexing strategy proposed in Chen et al.'s [Dense X Retrieval: What Retrieval Granularity Should We Use?](https://arxiv.org/abs/2312.06648). The prompt, which you can [try out on the hub](https://smith.langchain.com/hub/wfh/proposal-indexing), directs an LLM to generate de-contextualized "propositions" that can be vectorized to increase retrieval accuracy. You can see the full definition in `proposal_chain.py`.
|
||||
|
||||

|
||||
|
||||
## Storage
|
||||
|
||||
For this demo, we index a simple academic paper using the RecursiveUrlLoader, and store all retriever information locally (using chroma and a bytestore stored on the local filesystem). You can modify the storage layer in `storage.py`.
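As a rough sketch of what such a modification could look like (the paths below are placeholders, not part of the template), the retriever is simply a `MultiVectorRetriever` wired to a Chroma vector store and a local byte store:

```python
from pathlib import Path

from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import Chroma

root = Path(__file__).parent  # hypothetical location for the stores

vectorstore = Chroma(
    collection_name="proposals",
    persist_directory=str(root / "my_chroma_db"),
    embedding_function=OpenAIEmbeddings(),
)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    byte_store=LocalFileStore(str(root / "my_docstore")),
    id_key="doc_id",
)
```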
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Set the `OPENAI_API_KEY` environment variable to access `gpt-3.5` and the OpenAI Embeddings classes.
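If you prefer, it can also be set from Python for local experimentation (the value below is a placeholder):

```python
import os

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; substitute your own key
```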
|
||||
|
||||
## Indexing
|
||||
|
||||
Create the index by running the following:
|
||||
|
||||
```python
|
||||
poetry install
|
||||
poetry run python propositional_retrieval/ingest.py
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use this package, you should first have the LangChain CLI installed:
|
||||
|
||||
```shell
|
||||
pip install -U langchain-cli
|
||||
```
|
||||
|
||||
To create a new LangChain project and install this as the only package, you can do:
|
||||
|
||||
```shell
|
||||
langchain app new my-app --package propositional-retrieval
|
||||
```
|
||||
|
||||
If you want to add this to an existing project, you can just run:
|
||||
|
||||
```shell
|
||||
langchain app add propositional-retrieval
|
||||
```
|
||||
|
||||
And add the following code to your `server.py` file:
|
||||
|
||||
```python
|
||||
from propositional_retrieval import chain
|
||||
|
||||
add_routes(app, chain, path="/propositional-retrieval")
|
||||
```
|
||||
|
||||
(Optional) Let's now configure LangSmith.
|
||||
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||
LangSmith is currently in private beta; you can sign up [here](https://smith.langchain.com/).
|
||||
If you don't have access, you can skip this section.
|
||||
|
||||
```shell
|
||||
export LANGCHAIN_TRACING_V2=true
|
||||
export LANGCHAIN_API_KEY=<your-api-key>
|
||||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||
```
|
||||
|
||||
If you are inside this directory, then you can spin up a LangServe instance directly by:
|
||||
|
||||
```shell
|
||||
langchain serve
|
||||
```
|
||||
|
||||
This will start the FastAPI app with a server running locally at
|
||||
[http://localhost:8000](http://localhost:8000)
|
||||
|
||||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
|
||||
We can access the playground at [http://127.0.0.1:8000/propositional-retrieval/playground](http://127.0.0.1:8000/propositional-retrieval/playground)
|
||||
|
||||
We can access the template from code with:
|
||||
|
||||
```python
|
||||
from langserve.client import RemoteRunnable
|
||||
|
||||
runnable = RemoteRunnable("http://localhost:8000/propositional-retrieval")
|
||||
```
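A question can then be sent to the deployed chain, for example:

```python
runnable.invoke("How are transformers related to convolutional neural networks?")
```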
|
||||
BIN
templates/propositional-retrieval/_images/retriever_diagram.png
Normal file
BIN
templates/propositional-retrieval/_images/retriever_diagram.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 375 KiB |
2859
templates/propositional-retrieval/poetry.lock
generated
Normal file
2859
templates/propositional-retrieval/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "681a5d1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run Template\n",
|
||||
"\n",
|
||||
"In `server.py`, set -\n",
|
||||
"```\n",
|
||||
"from fastapi import FastAPI\n",
|
||||
"from langserve import add_routes\n",
|
||||
"from propositional_retrieval import chain\n",
|
||||
"\n",
|
||||
"app = FastAPI(\n",
|
||||
" title=\"LangChain Server\",\n",
|
||||
" version=\"1.0\",\n",
|
||||
" description=\"Retriever and Generator for RAG Chroma Dense Retrieval\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"add_routes(app, chain, path=\"/propositional-retrieval\")\n",
|
||||
"\n",
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" import uvicorn\n",
|
||||
"\n",
|
||||
" uvicorn.run(app, host=\"localhost\", port=8000)\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d774be2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve.client import RemoteRunnable\n",
|
||||
"\n",
|
||||
"rag_app = RemoteRunnable(\"http://localhost:8001/propositional-retrieval\")\n",
|
||||
"rag_app.invoke(\"How are transformers related to convolutional neural networks?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
from propositional_retrieval.chain import chain
|
||||
from propositional_retrieval.proposal_chain import proposition_chain
|
||||
|
||||
__all__ = ["chain", "proposition_chain"]
|
||||
@@ -0,0 +1,67 @@
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.load import load
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from propositional_retrieval.constants import DOCSTORE_ID_KEY
|
||||
from propositional_retrieval.storage import get_multi_vector_retriever
|
||||
|
||||
|
||||
def format_docs(docs: list) -> str:
|
||||
loaded_docs = [load(doc) for doc in docs]
|
||||
return "\n".join(
|
||||
[
|
||||
f"<Document id={i}>\n{doc.page_content}\n</Document>"
|
||||
for i, doc in enumerate(loaded_docs)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def rag_chain(retriever):
|
||||
"""
|
||||
The RAG chain
|
||||
|
||||
:param retriever: A function that retrieves the necessary context for the model.
|
||||
:return: A chain of functions representing the RAG process.
|
||||
"""
|
||||
model = ChatOpenAI(temperature=0, model="gpt-4-1106-preview", max_tokens=1024)
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are an AI assistant. Answer based on the retrieved documents:"
|
||||
"\n<Documents>\n{context}\n</Documents>",
|
||||
),
|
||||
("user", "{question}?"),
|
||||
]
|
||||
)
|
||||
|
||||
# Define the RAG pipeline
|
||||
chain = (
|
||||
{
|
||||
"context": retriever | format_docs,
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| model
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
# Create the multi-vector retriever
|
||||
retriever = get_multi_vector_retriever(DOCSTORE_ID_KEY)
|
||||
|
||||
# Create RAG chain
|
||||
chain = rag_chain(retriever)
|
||||
|
||||
|
||||
# Add typing for input
|
||||
class Question(BaseModel):
|
||||
__root__: str
|
||||
|
||||
|
||||
chain = chain.with_types(input_type=Question)
|
||||
@@ -0,0 +1 @@
|
||||
DOCSTORE_ID_KEY = "doc_id"
|
||||
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Sequence
|
||||
|
||||
from bs4 import BeautifulSoup as Soup
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from propositional_retrieval.constants import DOCSTORE_ID_KEY
|
||||
from propositional_retrieval.proposal_chain import proposition_chain
|
||||
from propositional_retrieval.storage import get_multi_vector_retriever
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_documents(
|
||||
retriever,
|
||||
propositions: Sequence[Sequence[str]],
|
||||
docs: Sequence[Document],
|
||||
id_key: str = DOCSTORE_ID_KEY,
|
||||
):
|
||||
doc_ids = [
|
||||
str(uuid.uuid5(uuid.NAMESPACE_DNS, doc.metadata["source"])) for doc in docs
|
||||
]
|
||||
prop_docs = [
|
||||
Document(page_content=prop, metadata={id_key: doc_ids[i]})
|
||||
for i, props in enumerate(propositions)
|
||||
for prop in props
|
||||
if prop
|
||||
]
|
||||
retriever.vectorstore.add_documents(prop_docs)
|
||||
retriever.docstore.mset(list(zip(doc_ids, docs)))
|
||||
|
||||
|
||||
def create_index(
|
||||
docs: Sequence[Document],
|
||||
indexer: Runnable,
|
||||
docstore_id_key: str = DOCSTORE_ID_KEY,
|
||||
):
|
||||
"""
|
||||
Create retriever that indexes docs and their propositions
|
||||
|
||||
:param docs: Documents to index
|
||||
:param indexer: Runnable creates additional propositions per doc
|
||||
:param docstore_id_key: Key to use to store the docstore id
|
||||
:return: Retriever
|
||||
"""
|
||||
logger.info("Creating multi-vector retriever")
|
||||
retriever = get_multi_vector_retriever(docstore_id_key)
|
||||
propositions = indexer.batch(
|
||||
[{"input": doc.page_content} for doc in docs], {"max_concurrency": 10}
|
||||
)
|
||||
|
||||
add_documents(
|
||||
retriever,
|
||||
propositions,
|
||||
docs,
|
||||
id_key=docstore_id_key,
|
||||
)
|
||||
|
||||
return retriever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# For our example, we'll load docs from the web
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter # noqa
|
||||
from langchain_community.document_loaders.recursive_url_loader import (
|
||||
RecursiveUrlLoader,
|
||||
) # noqa
|
||||
|
||||
# The attention is all you need paper
|
||||
# Could add more parsing here, as it's very raw.
|
||||
loader = RecursiveUrlLoader(
|
||||
"https://ar5iv.labs.arxiv.org/html/1706.03762",
|
||||
max_depth=2,
|
||||
extractor=lambda x: Soup(x, "html.parser").text,
|
||||
)
|
||||
data = loader.load()
|
||||
logger.info(f"Loaded {len(data)} documents")
|
||||
|
||||
# Split
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=0)
|
||||
all_splits = text_splitter.split_documents(data)
|
||||
logger.info(f"Split into {len(all_splits)} documents")
|
||||
|
||||
# Create retriever
|
||||
retriever_multi_vector_img = create_index(
|
||||
all_splits,
|
||||
proposition_chain,
|
||||
DOCSTORE_ID_KEY,
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
import logging
|
||||
|
||||
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Modified from the paper to be more robust to benign prompt injection
|
||||
# https://arxiv.org/abs/2312.06648
|
||||
# @misc{chen2023dense,
|
||||
# title={Dense X Retrieval: What Retrieval Granularity Should We Use?},
|
||||
# author={Tong Chen and Hongwei Wang and Sihao Chen and Wenhao Yu and Kaixin Ma
|
||||
# and Xinran Zhao and Hongming Zhang and Dong Yu},
|
||||
# year={2023},
|
||||
# eprint={2312.06648},
|
||||
# archivePrefix={arXiv},
|
||||
# primaryClass={cs.CL}
|
||||
# }
|
||||
PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""Decompose the "Content" into clear and simple propositions, ensuring they are interpretable out of
|
||||
context.
|
||||
1. Split compound sentence into simple sentences. Maintain the original phrasing from the input
|
||||
whenever possible.
|
||||
2. For any named entity that is accompanied by additional descriptive information, separate this
|
||||
information into its own distinct proposition.
|
||||
3. Decontextualize the proposition by adding necessary modifier to nouns or entire sentences
|
||||
and replacing pronouns (e.g., "it", "he", "she", "they", "this", "that") with the full name of the
|
||||
entities they refer to.
|
||||
4. Present the results as a list of strings, formatted in JSON.
|
||||
|
||||
Example:
|
||||
|
||||
Input: Title: Ēostre. Section: Theories and interpretations, Connection to Easter Hares. Content:
|
||||
The earliest evidence for the Easter Hare (Osterhase) was recorded in south-west Germany in
|
||||
1678 by the professor of medicine Georg Franck von Franckenau, but it remained unknown in
|
||||
other parts of Germany until the 18th century. Scholar Richard Sermon writes that "hares were
|
||||
frequently seen in gardens in spring, and thus may have served as a convenient explanation for the
|
||||
origin of the colored eggs hidden there for children. Alternatively, there is a European tradition
|
||||
that hares laid eggs, since a hare’s scratch or form and a lapwing’s nest look very similar, and
|
||||
both occur on grassland and are first seen in the spring. In the nineteenth century the influence
|
||||
of Easter cards, toys, and books was to make the Easter Hare/Rabbit popular throughout Europe.
|
||||
German immigrants then exported the custom to Britain and America where it evolved into the
|
||||
Easter Bunny."
|
||||
Output: [ "The earliest evidence for the Easter Hare was recorded in south-west Germany in
|
||||
1678 by Georg Franck von Franckenau.", "Georg Franck von Franckenau was a professor of
|
||||
medicine.", "The evidence for the Easter Hare remained unknown in other parts of Germany until
|
||||
the 18th century.", "Richard Sermon was a scholar.", "Richard Sermon writes a hypothesis about
|
||||
the possible explanation for the connection between hares and the tradition during Easter", "Hares
|
||||
were frequently seen in gardens in spring.", "Hares may have served as a convenient explanation
|
||||
for the origin of the colored eggs hidden in gardens for children.", "There is a European tradition
|
||||
that hares laid eggs.", "A hare’s scratch or form and a lapwing’s nest look very similar.", "Both
|
||||
hares and lapwing’s nests occur on grassland and are first seen in the spring.", "In the nineteenth
|
||||
century the influence of Easter cards, toys, and books was to make the Easter Hare/Rabbit popular
|
||||
throughout Europe.", "German immigrants exported the custom of the Easter Hare/Rabbit to
|
||||
Britain and America.", "The custom of the Easter Hare/Rabbit evolved into the Easter Bunny in
|
||||
Britain and America."]""", # noqa
|
||||
),
|
||||
("user", "Decompose the following:\n{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_propositions(tool_calls: list) -> list:
|
||||
if not tool_calls:
|
||||
raise ValueError("No tool calls found")
|
||||
return tool_calls[0]["args"]["propositions"]
|
||||
|
||||
|
||||
def empty_proposals(x):
|
||||
# Model couldn't generate proposals
|
||||
return []
|
||||
|
||||
|
||||
proposition_chain = (
|
||||
PROMPT
|
||||
| ChatOpenAI(model="gpt-3.5-turbo-16k").bind(
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "decompose_content",
|
||||
"description": "Return the decomposed propositions",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"propositions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
},
|
||||
"required": ["propositions"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice={"type": "function", "function": {"name": "decompose_content"}},
|
||||
)
|
||||
| JsonOutputToolsParser()
|
||||
| get_propositions
|
||||
).with_fallbacks([RunnableLambda(empty_proposals)])
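# Hypothetical usage of the chain above (illustrative only; assumes an
# OPENAI_API_KEY is set in the environment):
#
#     propositions = proposition_chain.invoke(
#         {"input": "Kepler discovered that planetary orbits are ellipses, which he published in 1609."}
#     )
#     # e.g. ["Kepler discovered that planetary orbits are ellipses.",
#     #       "Kepler published his discovery of elliptical orbits in 1609."]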
|
||||
@@ -0,0 +1,38 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.storage import LocalFileStore
|
||||
from langchain_community.vectorstores import Chroma
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_multi_vector_retriever(docstore_id_key: str):
|
||||
"""Create the composed retriever object."""
|
||||
vectorstore = get_vectorstore()
|
||||
store = get_docstore()
|
||||
return MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
byte_store=store,
|
||||
id_key=docstore_id_key,
|
||||
)
|
||||
|
||||
|
||||
def get_vectorstore(collection_name: str = "proposals"):
|
||||
"""Get the vectorstore used for this example."""
|
||||
return Chroma(
|
||||
collection_name=collection_name,
|
||||
persist_directory=str(Path(__file__).parent.parent / "chroma_db_proposals"),
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
)
|
||||
|
||||
|
||||
def get_docstore():
|
||||
"""Get the metadata store used for this example."""
|
||||
return LocalFileStore(
|
||||
str(Path(__file__).parent.parent / "multi_vector_retriever_metadata")
|
||||
)
|
||||
35
templates/propositional-retrieval/pyproject.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
[tool.poetry]
|
||||
name = "propositional-retrieval"
|
||||
version = "0.1.0"
|
||||
description = "Dense retrieval using vectorized propositions."
|
||||
authors = [
|
||||
"William Fu-Hinthorn <will@langchain.dev>",
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain = ">=0.0.350"
|
||||
openai = "<2"
|
||||
tiktoken = ">=0.5.1"
|
||||
chromadb = ">=0.4.14"
|
||||
bs4 = "^0.0.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-cli = ">=0.0.15"
|
||||
|
||||
[tool.langserve]
|
||||
export_module = "propositional_retrieval"
|
||||
export_attr = "chain"
|
||||
|
||||
[tool.templates-hub]
|
||||
use-case = "rag"
|
||||
author = "LangChain"
|
||||
integrations = ["OpenAI", "Chroma"]
|
||||
tags = ["vectordbs"]
|
||||
|
||||
[build-system]
|
||||
requires = [
|
||||
"poetry-core",
|
||||
]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
0
templates/propositional-retrieval/tests/__init__.py
Normal file
1
templates/rag-chroma-multi-modal-multi-vector/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
docs/img_*.jpg
|
||||
21
templates/rag-chroma-multi-modal-multi-vector/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
108
templates/rag-chroma-multi-modal-multi-vector/README.md
Normal file
@@ -0,0 +1,108 @@
|
||||
|
||||
# rag-chroma-multi-modal-multi-vector
|
||||
|
||||
Presentations (slide decks, etc.) contain visual content that challenges conventional RAG.
|
||||
|
||||
Multi-modal LLMs unlock new ways to build apps over visual content like presentations.
|
||||
|
||||
This template performs multi-modal RAG using Chroma with the multi-vector retriever (see [blog](https://blog.langchain.dev/multi-modal-rag-template/)):
|
||||
|
||||
* Extracts the slides as images
|
||||
* Uses GPT-4V to summarize each image
|
||||
* Embeds the image summaries with a link to the original images
|
||||
* Retrieves relevant images based on similarity between the image summaries and the user input
|
||||
* Finally, passes those images to GPT-4V for answer synthesis
|
||||
|
||||
## Storage
|
||||
|
||||
We will use Upstash to store the images, which offers Redis with a REST API.
|
||||
|
||||
Simply log in [here](https://upstash.com/) and create a database.
|
||||
|
||||
This will give you a REST API with:
|
||||
|
||||
* UPSTASH_URL
|
||||
* UPSTASH_TOKEN
|
||||
|
||||
Set `UPSTASH_URL` and `UPSTASH_TOKEN` as environment variables to access your database.
|
||||
|
||||
We will use Chroma to store and index the image summaries, which will be created locally in the template directory.
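For orientation, here is a minimal sketch of how the two stores might be wired into the multi-vector retriever (the collection name, path, and `doc_id` key mirror this template's `ingest.py`; treat it as illustrative rather than exact):

```python
import os

from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import UpstashRedisByteStore
from langchain.vectorstores import Chroma

# Chroma indexes the image summaries locally in the template directory.
vectorstore = Chroma(
    collection_name="image_summaries",
    persist_directory="./chroma_db_multi_modal",
    embedding_function=OpenAIEmbeddings(),
)

# Upstash (Redis over a REST API) holds the raw base64-encoded images.
store = UpstashRedisByteStore(
    url=os.environ["UPSTASH_URL"],
    token=os.environ["UPSTASH_TOKEN"],
)

retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    byte_store=store,
    id_key="doc_id",
)
```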
|
||||
|
||||
## Input
|
||||
|
||||
Supply a slide deck as a PDF in the `/docs` directory.
|
||||
|
||||
Create your vectorstore (Chroma) and populate Upstash with:
|
||||
|
||||
```
|
||||
poetry install
|
||||
python ingest.py
|
||||
```
|
||||
|
||||
## LLM
|
||||
|
||||
The app will retrieve images using embeddings of their GPT-4V summaries, and pass the retrieved images to GPT-4V for answer synthesis.
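A minimal sketch of the GPT-4V call used for answer synthesis (the message format mirrors this template's `chain.py`; the base64 string is a placeholder):

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage

model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1024)

b64_image = "<base64-encoded JPEG>"  # placeholder: a retrieved slide image

msg = model.invoke(
    [
        HumanMessage(
            content=[
                {"type": "text", "text": "What was the quarterly revenue growth?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                },
            ]
        )
    ]
)
print(msg.content)
```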
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Set the `OPENAI_API_KEY` environment variable to access OpenAI GPT-4V.
|
||||
|
||||
Set `UPSTASH_URL` and `UPSTASH_TOKEN` as environment variables to access your database.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this package, you should first have the LangChain CLI installed:
|
||||
|
||||
```shell
|
||||
pip install -U langchain-cli
|
||||
```
|
||||
|
||||
To create a new LangChain project and install this as the only package, you can do:
|
||||
|
||||
```shell
|
||||
langchain app new my-app --package rag-chroma-multi-modal-multi-vector
|
||||
```
|
||||
|
||||
If you want to add this to an existing project, you can just run:
|
||||
|
||||
```shell
|
||||
langchain app add rag-chroma-multi-modal-multi-vector
|
||||
```
|
||||
|
||||
And add the following code to your `server.py` file:
|
||||
```python
|
||||
from rag_chroma_multi_modal_multi_vector import chain as rag_chroma_multi_modal_chain_mv
|
||||
|
||||
add_routes(app, rag_chroma_multi_modal_chain_mv, path="/rag-chroma-multi-modal-multi-vector")
|
||||
```
|
||||
|
||||
(Optional) Let's now configure LangSmith.
|
||||
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||
LangSmith is currently in private beta; you can sign up [here](https://smith.langchain.com/).
|
||||
If you don't have access, you can skip this section.
|
||||
|
||||
```shell
|
||||
export LANGCHAIN_TRACING_V2=true
|
||||
export LANGCHAIN_API_KEY=<your-api-key>
|
||||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||
```
|
||||
|
||||
If you are inside this directory, then you can spin up a LangServe instance directly by:
|
||||
|
||||
```shell
|
||||
langchain serve
|
||||
```
|
||||
|
||||
This will start the FastAPI app with a server running locally at
|
||||
[http://localhost:8000](http://localhost:8000)
|
||||
|
||||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
|
||||
We can access the playground at [http://127.0.0.1:8000/rag-chroma-multi-modal-multi-vector/playground](http://127.0.0.1:8000/rag-chroma-multi-modal-multi-vector/playground)
|
||||
|
||||
We can access the template from code with:
|
||||
|
||||
```python
|
||||
from langserve.client import RemoteRunnable
|
||||
|
||||
runnable = RemoteRunnable("http://localhost:8000/rag-chroma-multi-modal-multi-vector")
|
||||
```
|
||||
Binary file not shown.
197
templates/rag-chroma-multi-modal-multi-vector/ingest.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.storage import UpstashRedisByteStore
|
||||
from langchain.vectorstores import Chroma
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def image_summarize(img_base64, prompt):
|
||||
"""
|
||||
Make image summary
|
||||
|
||||
:param img_base64: Base64 encoded string for image
|
||||
:param prompt: Text prompt for summarization
|
||||
:return: Image summary
|
||||
|
||||
"""
|
||||
chat = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1024)
|
||||
|
||||
msg = chat.invoke(
|
||||
[
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
|
||||
},
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
return msg.content
|
||||
|
||||
|
||||
def generate_img_summaries(img_base64_list):
|
||||
"""
|
||||
Generate summaries for images
|
||||
|
||||
:param img_base64_list: Base64 encoded images
|
||||
:return: List of image summaries and processed images
|
||||
"""
|
||||
|
||||
# Store image summaries
|
||||
image_summaries = []
|
||||
processed_images = []
|
||||
|
||||
# Prompt
|
||||
prompt = """You are an assistant tasked with summarizing images for retrieval. \
|
||||
These summaries will be embedded and used to retrieve the raw image. \
|
||||
Give a concise summary of the image that is well optimized for retrieval."""
|
||||
|
||||
# Apply summarization to images
|
||||
for i, base64_image in enumerate(img_base64_list):
|
||||
try:
|
||||
image_summaries.append(image_summarize(base64_image, prompt))
|
||||
processed_images.append(base64_image)
|
||||
except Exception as e:
|
||||
print(f"Error with image {i+1}: {e}")
|
||||
|
||||
return image_summaries, processed_images
|
||||
|
||||
|
||||
def get_images_from_pdf(pdf_path):
|
||||
"""
|
||||
Extract images from each page of a PDF document and save as JPEG files.
|
||||
|
||||
:param pdf_path: A string representing the path to the PDF file.
|
||||
"""
|
||||
pdf = pdfium.PdfDocument(pdf_path)
|
||||
n_pages = len(pdf)
|
||||
pil_images = []
|
||||
for page_number in range(n_pages):
|
||||
page = pdf.get_page(page_number)
|
||||
bitmap = page.render(scale=1, rotation=0, crop=(0, 0, 0, 0))
|
||||
pil_image = bitmap.to_pil()
|
||||
pil_images.append(pil_image)
|
||||
return pil_images
|
||||
|
||||
|
||||
def resize_base64_image(base64_string, size=(128, 128)):
|
||||
"""
|
||||
Resize an image encoded as a Base64 string
|
||||
|
||||
:param base64_string: Base64 string
|
||||
:param size: Image size
|
||||
:return: Re-sized Base64 string
|
||||
"""
|
||||
# Decode the Base64 string
|
||||
img_data = base64.b64decode(base64_string)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
# Resize the image
|
||||
resized_img = img.resize(size, Image.LANCZOS)
|
||||
|
||||
# Save the resized image to a bytes buffer
|
||||
buffered = io.BytesIO()
|
||||
resized_img.save(buffered, format=img.format)
|
||||
|
||||
# Encode the resized image to Base64
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def convert_to_base64(pil_image):
|
||||
"""
|
||||
Convert PIL images to Base64 encoded strings
|
||||
|
||||
:param pil_image: PIL image
|
||||
:return: Re-sized Base64 string
|
||||
"""
|
||||
|
||||
buffered = BytesIO()
|
||||
pil_image.save(buffered, format="JPEG") # You can change the format if needed
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
img_str = resize_base64_image(img_str, size=(960, 540))
|
||||
return img_str
|
||||
|
||||
|
||||
def create_multi_vector_retriever(vectorstore, image_summaries, images):
|
||||
"""
|
||||
Create retriever that indexes summaries, but returns raw images or texts
|
||||
|
||||
:param vectorstore: Vectorstore to store embedded image summaries
|
||||
:param image_summaries: Image summaries
|
||||
:param images: Base64 encoded images
|
||||
:return: Retriever
|
||||
"""
|
||||
|
||||
# Initialize the storage layer for images
|
||||
UPSTASH_URL = os.getenv("UPSTASH_URL")
|
||||
UPSTASH_TOKEN = os.getenv("UPSTASH_TOKEN")
|
||||
store = UpstashRedisByteStore(url=UPSTASH_URL, token=UPSTASH_TOKEN)
|
||||
id_key = "doc_id"
|
||||
|
||||
# Create the multi-vector retriever
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore,
|
||||
byte_store=store,
|
||||
id_key=id_key,
|
||||
)
|
||||
|
||||
# Helper function to add documents to the vectorstore and docstore
|
||||
def add_documents(retriever, doc_summaries, doc_contents):
|
||||
doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
|
||||
summary_docs = [
|
||||
Document(page_content=s, metadata={id_key: doc_ids[i]})
|
||||
for i, s in enumerate(doc_summaries)
|
||||
]
|
||||
retriever.vectorstore.add_documents(summary_docs)
|
||||
retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
|
||||
|
||||
add_documents(retriever, image_summaries, images)
|
||||
|
||||
return retriever
|
||||
|
||||
|
||||
# Load PDF
|
||||
doc_path = Path(__file__).parent / "docs/DDOG_Q3_earnings_deck.pdf"
|
||||
rel_doc_path = doc_path.relative_to(Path.cwd())
|
||||
print("Extract slides as images")
|
||||
pil_images = get_images_from_pdf(rel_doc_path)
|
||||
|
||||
# Convert to b64
|
||||
images_base_64 = [convert_to_base64(i) for i in pil_images]
|
||||
|
||||
# Image summaries
|
||||
print("Generate image summaries")
|
||||
image_summaries, images_base_64_processed = generate_img_summaries(images_base_64)
|
||||
|
||||
# The vectorstore used to index the image summaries
|
||||
vectorstore_mvr = Chroma(
|
||||
collection_name="image_summaries",
|
||||
persist_directory=str(Path(__file__).parent / "chroma_db_multi_modal"),
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
)
|
||||
|
||||
# Create documents
|
||||
images_base_64_processed_documents = [
|
||||
Document(page_content=i) for i in images_base_64_processed
|
||||
]
|
||||
|
||||
# Create retriever
|
||||
retriever_multi_vector_img = create_multi_vector_retriever(
|
||||
vectorstore_mvr,
|
||||
image_summaries,
|
||||
images_base_64_processed_documents,
|
||||
)
|
||||
2949
templates/rag-chroma-multi-modal-multi-vector/poetry.lock
generated
Normal file
File diff suppressed because it is too large
38
templates/rag-chroma-multi-modal-multi-vector/pyproject.toml
Normal file
@@ -0,0 +1,38 @@
|
||||
[tool.poetry]
|
||||
name = "rag-chroma-multi-modal-multi-vector"
|
||||
version = "0.1.0"
|
||||
description = "Multi-modal RAG using Chroma and multi-vector retriever"
|
||||
authors = [
|
||||
"Lance Martin <lance@langchain.dev>",
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain = ">=0.0.350"
|
||||
openai = "<2"
|
||||
tiktoken = ">=0.5.1"
|
||||
chromadb = ">=0.4.14"
|
||||
pypdfium2 = ">=4.20.0"
|
||||
langchain-experimental = "^0.0.43"
|
||||
upstash-redis = ">=1.0.0"
|
||||
pillow = ">=10.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-cli = ">=0.0.15"
|
||||
|
||||
[tool.langserve]
|
||||
export_module = "rag_chroma_multi_modal_multi_vector"
|
||||
export_attr = "chain"
|
||||
|
||||
[tool.templates-hub]
|
||||
use-case = "rag"
|
||||
author = "LangChain"
|
||||
integrations = ["OpenAI", "Chroma"]
|
||||
tags = ["vectordbs"]
|
||||
|
||||
[build-system]
|
||||
requires = [
|
||||
"poetry-core",
|
||||
]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,52 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "681a5d1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run Template\n",
|
||||
"\n",
|
||||
"In `server.py`, set -\n",
|
||||
"```\n",
|
||||
"add_routes(app, chain_rag_conv, path=\"/rag-chroma-multi-modal-multi-vector\")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d774be2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve.client import RemoteRunnable\n",
|
||||
"\n",
|
||||
"rag_app = RemoteRunnable(\"http://localhost:8001/rag-chroma-multi-modal-multi-vector\")\n",
|
||||
"rag_app.invoke(\"What is the projected TAM for observability expected for each year through 2026?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
from rag_chroma_multi_modal_multi_vector.chain import chain
|
||||
|
||||
__all__ = ["chain"]
|
||||
@@ -0,0 +1,133 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
|
||||
from langchain.storage import UpstashRedisByteStore
|
||||
from langchain.vectorstores import Chroma
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def resize_base64_image(base64_string, size=(128, 128)):
|
||||
"""
|
||||
Resize an image encoded as a Base64 string.
|
||||
|
||||
:param base64_string: A Base64 encoded string of the image to be resized.
|
||||
:param size: A tuple representing the new size (width, height) for the image.
|
||||
:return: A Base64 encoded string of the resized image.
|
||||
"""
|
||||
img_data = base64.b64decode(base64_string)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
resized_img = img.resize(size, Image.LANCZOS)
|
||||
buffered = io.BytesIO()
|
||||
resized_img.save(buffered, format=img.format)
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def get_resized_images(docs):
|
||||
"""
|
||||
Resize images from base64-encoded strings.
|
||||
|
||||
:param docs: A list of base64-encoded images to be resized.
|
||||
:return: Dict containing a list of resized base64-encoded strings.
|
||||
"""
|
||||
b64_images = []
|
||||
for doc in docs:
|
||||
if isinstance(doc, Document):
|
||||
doc = doc.page_content
|
||||
resized_image = resize_base64_image(doc, size=(1280, 720))
|
||||
b64_images.append(resized_image)
|
||||
return {"images": b64_images}
|
||||
|
||||
|
||||
def img_prompt_func(data_dict, num_images=2):
|
||||
"""
|
||||
GPT-4V prompt for image analysis.
|
||||
|
||||
:param data_dict: A dict with images and a user-provided question.
|
||||
:param num_images: Number of images to include in the prompt.
|
||||
:return: A list containing message objects for each image and the text prompt.
|
||||
"""
|
||||
messages = []
|
||||
if data_dict["context"]["images"]:
|
||||
for image in data_dict["context"]["images"][:num_images]:
|
||||
image_message = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
|
||||
}
|
||||
messages.append(image_message)
|
||||
text_message = {
|
||||
"type": "text",
|
||||
"text": (
|
||||
"You are an analyst tasked with answering questions about visual content.\n"
|
||||
"You will be give a set of image(s) from a slide deck / presentation.\n"
|
||||
"Use this information to answer the user question. \n"
|
||||
f"User-provided question: {data_dict['question']}\n\n"
|
||||
),
|
||||
}
|
||||
messages.append(text_message)
|
||||
return [HumanMessage(content=messages)]
|
||||
|
||||
|
||||
def multi_modal_rag_chain(retriever):
|
||||
"""
|
||||
Multi-modal RAG chain.
|
||||
|
||||
:param retriever: A function that retrieves the necessary context for the model.
|
||||
:return: A chain of functions representing the multi-modal RAG process.
|
||||
"""
|
||||
# Initialize the multi-modal Large Language Model with specific parameters
|
||||
model = ChatOpenAI(temperature=0, model="gpt-4-vision-preview", max_tokens=1024)
|
||||
|
||||
# Define the RAG pipeline
|
||||
chain = (
|
||||
{
|
||||
"context": retriever | RunnableLambda(get_resized_images),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| RunnableLambda(img_prompt_func)
|
||||
| model
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
# Load chroma
|
||||
vectorstore_mvr = Chroma(
|
||||
collection_name="image_summaries",
|
||||
persist_directory=str(Path(__file__).parent.parent / "chroma_db_multi_modal"),
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
)
|
||||
|
||||
# Load redis
|
||||
UPSTASH_URL = os.getenv("UPSTASH_URL")
|
||||
UPSTASH_TOKEN = os.getenv("UPSTASH_TOKEN")
|
||||
store = UpstashRedisByteStore(url=UPSTASH_URL, token=UPSTASH_TOKEN)
|
||||
id_key = "doc_id"
|
||||
|
||||
# Create the multi-vector retriever
|
||||
retriever = MultiVectorRetriever(
|
||||
vectorstore=vectorstore_mvr,
|
||||
byte_store=store,
|
||||
id_key=id_key,
|
||||
)
|
||||
|
||||
# Create RAG chain
|
||||
chain = multi_modal_rag_chain(retriever)
|
||||
|
||||
|
||||
# Add typing for input
|
||||
class Question(BaseModel):
|
||||
__root__: str
|
||||
|
||||
|
||||
chain = chain.with_types(input_type=Question)
|
||||
1
templates/rag-gemini-multi-modal/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
docs/img_*.jpg
|
||||
21
templates/rag-gemini-multi-modal/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
106
templates/rag-gemini-multi-modal/README.md
Normal file
@@ -0,0 +1,106 @@
|
||||
|
||||
# rag-gemini-multi-modal
|
||||
|
||||
Presentations (slide decks, etc.) contain visual content that challenges conventional RAG.
|
||||
|
||||
Multi-modal LLMs unlock new ways to build apps over visual content like presentations.
|
||||
|
||||
This template performs multi-modal RAG using Chroma with multi-modal OpenCLIP embeddings and [Google Gemini](https://deepmind.google/technologies/gemini/#introduction).
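For orientation, a rough sketch of how retrieval over the multi-modal embeddings works (names and paths mirror this template's `ingest.py`; treat it as illustrative):

```python
from langchain.vectorstores import Chroma
from langchain_experimental.open_clip import OpenCLIPEmbeddings

# The same collection that ingest.py populates with slide images.
vectorstore_mmembd = Chroma(
    collection_name="multi-modal-rag",
    persist_directory="./chroma_db_multi_modal",
    embedding_function=OpenCLIPEmbeddings(
        model_name="ViT-H-14", checkpoint="laion2b_s32b_b79k"
    ),
)

# Text queries and images share one embedding space, so a plain
# similarity search returns the most relevant slide images.
docs = vectorstore_mmembd.as_retriever(search_kwargs={"k": 1}).get_relevant_documents(
    "What was the quarterly revenue growth?"
)
```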
|
||||
|
||||
## Input
|
||||
|
||||
Supply a slide deck as a PDF in the `/docs` directory.
|
||||
|
||||
Create your vectorstore with:
|
||||
|
||||
```
|
||||
poetry install
|
||||
python ingest.py
|
||||
```
|
||||
|
||||
## Embeddings
|
||||
|
||||
This template will use [OpenCLIP](https://github.com/mlfoundations/open_clip) multi-modal embeddings.
|
||||
|
||||
You can select different options (see results [here](https://github.com/mlfoundations/open_clip/blob/main/docs/openclip_results.csv)).
|
||||
|
||||
The first time you run the app, it will automatically download the multimodal embedding model.
|
||||
|
||||
By default, LangChain will use an embedding model with reasonably strong performance, `ViT-H-14`.
|
||||
|
||||
You can choose alternative `OpenCLIPEmbeddings` models in `ingest.py`:
|
||||
```
|
||||
vectorstore_mmembd = Chroma(
|
||||
collection_name="multi-modal-rag",
|
||||
persist_directory=str(re_vectorstore_path),
|
||||
embedding_function=OpenCLIPEmbeddings(
|
||||
model_name="ViT-H-14", checkpoint="laion2b_s32b_b79k"
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## LLM
|
||||
|
||||
The app will retrieve images using multi-modal embeddings, and pass them to Google Gemini.
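A minimal sketch of the Gemini call for answer synthesis (assuming `langchain-google-genai` is installed and `GOOGLE_API_KEY` is set; the message format follows the same multi-modal pattern used elsewhere in these templates, and the base64 string is a placeholder):

```python
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

model = ChatGoogleGenerativeAI(model="gemini-pro-vision")

b64_image = "<base64-encoded JPEG>"  # placeholder: a retrieved slide image

msg = model.invoke(
    [
        HumanMessage(
            content=[
                {"type": "text", "text": "What was the quarterly revenue growth?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                },
            ]
        )
    ]
)
print(msg.content)
```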
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Set the `GOOGLE_API_KEY` environment variable to access Gemini.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this package, you should first have the LangChain CLI installed:
|
||||
|
||||
```shell
|
||||
pip install -U langchain-cli
|
||||
```
|
||||
|
||||
To create a new LangChain project and install this as the only package, you can do:
|
||||
|
||||
```shell
|
||||
langchain app new my-app --package rag-gemini-multi-modal
|
||||
```
|
||||
|
||||
If you want to add this to an existing project, you can just run:
|
||||
|
||||
```shell
|
||||
langchain app add rag-gemini-multi-modal
|
||||
```
|
||||
|
||||
And add the following code to your `server.py` file:
|
||||
```python
|
||||
from rag_gemini_multi_modal import chain as rag_gemini_multi_modal_chain
|
||||
|
||||
add_routes(app, rag_gemini_multi_modal_chain, path="/rag-gemini-multi-modal")
|
||||
```
|
||||
|
||||
(Optional) Let's now configure LangSmith.
|
||||
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||
LangSmith is currently in private beta; you can sign up [here](https://smith.langchain.com/).
|
||||
If you don't have access, you can skip this section.
|
||||
|
||||
```shell
|
||||
export LANGCHAIN_TRACING_V2=true
|
||||
export LANGCHAIN_API_KEY=<your-api-key>
|
||||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||
```
|
||||
|
||||
If you are inside this directory, then you can spin up a LangServe instance directly by:
|
||||
|
||||
```shell
|
||||
langchain serve
|
||||
```
|
||||
|
||||
This will start the FastAPI app with a server running locally at
|
||||
[http://localhost:8000](http://localhost:8000)
|
||||
|
||||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
|
||||
We can access the playground at [http://127.0.0.1:8000/rag-gemini-multi-modal/playground](http://127.0.0.1:8000/rag-gemini-multi-modal/playground)
|
||||
|
||||
We can access the template from code with:
|
||||
|
||||
```python
|
||||
from langserve.client import RemoteRunnable
|
||||
|
||||
runnable = RemoteRunnable("http://localhost:8000/rag-gemini-multi-modal")
|
||||
```
|
||||
BIN
templates/rag-gemini-multi-modal/docs/DDOG_Q3_earnings_deck.pdf
Normal file
Binary file not shown.
58
templates/rag-gemini-multi-modal/ingest.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain_experimental.open_clip import OpenCLIPEmbeddings
|
||||
|
||||
|
||||
def get_images_from_pdf(pdf_path, img_dump_path):
|
||||
"""
|
||||
Extract images from each page of a PDF document and save as JPEG files.
|
||||
|
||||
:param pdf_path: A string representing the path to the PDF file.
|
||||
:param img_dump_path: A string representing the path to dump images.
|
||||
"""
|
||||
pdf = pdfium.PdfDocument(pdf_path)
|
||||
n_pages = len(pdf)
|
||||
for page_number in range(n_pages):
|
||||
page = pdf.get_page(page_number)
|
||||
bitmap = page.render(scale=1, rotation=0, crop=(0, 0, 0, 0))
|
||||
pil_image = bitmap.to_pil()
|
||||
pil_image.save(f"{img_dump_path}/img_{page_number + 1}.jpg", format="JPEG")
|
||||
|
||||
|
||||
# Load PDF
|
||||
doc_path = Path(__file__).parent / "docs/DDOG_Q3_earnings_deck.pdf"
|
||||
img_dump_path = Path(__file__).parent / "docs/"
|
||||
rel_doc_path = doc_path.relative_to(Path.cwd())
|
||||
rel_img_dump_path = img_dump_path.relative_to(Path.cwd())
|
||||
print("pdf index")
|
||||
get_images_from_pdf(rel_doc_path, rel_img_dump_path)
|
||||
print("done")
|
||||
vectorstore = Path(__file__).parent / "chroma_db_multi_modal"
|
||||
re_vectorstore_path = vectorstore.relative_to(Path.cwd())
|
||||
|
||||
# Load embedding function
|
||||
print("Loading embedding function")
|
||||
embedding = OpenCLIPEmbeddings(model_name="ViT-H-14", checkpoint="laion2b_s32b_b79k")
|
||||
|
||||
# Create chroma
|
||||
vectorstore_mmembd = Chroma(
|
||||
collection_name="multi-modal-rag",
|
||||
persist_directory=str(Path(__file__).parent / "chroma_db_multi_modal"),
|
||||
embedding_function=embedding,
|
||||
)
|
||||
|
||||
# Get image URIs
|
||||
image_uris = sorted(
|
||||
[
|
||||
os.path.join(rel_img_dump_path, image_name)
|
||||
for image_name in os.listdir(rel_img_dump_path)
|
||||
if image_name.endswith(".jpg")
|
||||
]
|
||||
)
|
||||
|
||||
# Add images
|
||||
print("Embedding images")
|
||||
vectorstore_mmembd.add_images(uris=image_uris)
|
||||
3714
templates/rag-gemini-multi-modal/poetry.lock
generated
Normal file
File diff suppressed because it is too large
39
templates/rag-gemini-multi-modal/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[tool.poetry]
|
||||
name = "rag-gemini-multi-modal"
|
||||
version = "0.1.0"
|
||||
description = "Multi-modal RAG using Gemini and OpenCLIP embeddings"
|
||||
authors = [
|
||||
"Lance Martin <lance@langchain.dev>",
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
langchain = ">=0.0.350"
|
||||
openai = "<2"
|
||||
tiktoken = ">=0.5.1"
|
||||
chromadb = ">=0.4.14"
|
||||
open-clip-torch = ">=2.23.0"
|
||||
torch = ">=2.1.0"
|
||||
pypdfium2 = ">=4.20.0"
|
||||
langchain-experimental = "^0.0.43"
|
||||
langchain-google-genai = ">=0.0.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-cli = ">=0.0.15"
|
||||
|
||||
[tool.langserve]
|
||||
export_module = "rag_gemini_multi_modal"
|
||||
export_attr = "chain"
|
||||
|
||||
[tool.templates-hub]
|
||||
use-case = "rag"
|
||||
author = "LangChain"
|
||||
integrations = ["OpenAI", "Chroma"]
|
||||
tags = ["vectordbs"]
|
||||
|
||||
[build-system]
|
||||
requires = [
|
||||
"poetry-core",
|
||||
]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,52 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "681a5d1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run Template\n",
|
||||
"\n",
|
||||
"In `server.py`, set -\n",
|
||||
"```\n",
|
||||
"add_routes(app, chain_rag_conv, path=\"/rag-gemini-multi-modal\")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d774be2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langserve.client import RemoteRunnable\n",
|
||||
"\n",
|
||||
"rag_app = RemoteRunnable(\"http://localhost:8001/rag-gemini-multi-modal\")\n",
|
||||
"rag_app.invoke(\"What is the projected TAM for observability expected for each year through 2026?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
from rag_gemini_multi_modal.chain import chain
|
||||
|
||||
__all__ = ["chain"]
|
||||
Some files were not shown because too many files have changed in this diff