Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-11 19:49:54 +00:00)

Compare commits: `harrison/a...harrison/r` (52 commits)
Commits in range (SHA1): 26d1cf468c, ccc7e97513, d6e245e2e1, d3aee0a181, e9483bfc88, 2f718319bb, 40ad26f030, 4eb649da19, 71b80e3ba5, 3d2250820c, bcb64527ae, 6a0898f89d, 3f991afa15, eb726fa333, 9951a6cc75, 447066163e, d908818b32, 0591854007, f893a82efa, fc9e1bcf9d, 3b8f58edfa, baafcbacd9, 641183a137, c8982b1a82, 262b45522d, 5083e7b1e2, d52a853518, 3ead58d1c1, 620be7322f, 54a0900093, 447517e125, 563d7806f5, 1055a74dc4, d6663aecec, 7b692d22bf, 2bf326b77d, e602c5ea95, 6a7c07010d, 6609750538, c9432e2055, 865cd63f81, 2d3064e6fd, 87838d0fcd, 0beb465adb, ffa69882c0, c61a05d6b0, f2a56d733c, 10e419d84e, 284b9ae850, bc7535a9c7, 50e9bfce9f, f726b428ad
14 .github/CONTRIBUTING.md (vendored)
@@ -9,7 +9,7 @@ to contributions, whether they be in the form of new features, improved infra, b

### 👩💻 Contributing Code

To contribute to this project, please follow a ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
Please do not try to push directly to this repo unless you are a maintainer.
Please do not try to push directly to this repo unless you are maintainer.

Please follow the checked-in pull request template when opening pull requests. Note related issues and tag relevant
maintainers.
@@ -21,7 +21,7 @@ It's essential that we maintain great documentation and testing. If you:
- Fix a bug
  - Add a relevant unit or integration test when possible. These live in `tests/unit_tests` and `tests/integration_tests`.
- Make an improvement
  - Update any affected example notebooks and documentation. These live in `docs`.
  - Update any affected example notebooks and documentation. These lives in `docs`.
  - Update unit and integration tests when relevant.
- Add a feature
  - Add a demo notebook in `docs/modules`.
@@ -43,7 +43,7 @@ If you start working on an issue, please assign it to yourself.
If you are adding an issue, please try to keep it focused on a single, modular bug/improvement/feature.
If two issues are related, or blocking, please link them rather than combining them.

We will try to keep these issues as up-to-date as possible, though
We will try to keep these issues as up to date as possible, though
with the rapid rate of development in this field some may get out of date.
If you notice this happening, please let us know.
@@ -63,7 +63,7 @@ we do not want these to get in the way of getting good code into the codebase.

This project uses [Poetry](https://python-poetry.org/) v1.5.1 as a dependency manager. Check out Poetry's [documentation on how to install it](https://python-poetry.org/docs/#installation) on your system before proceeding.

❗Note: If you use `Conda` or `Pyenv` as your environment/package manager, avoid dependency conflicts by doing the following first:
❗Note: If you use `Conda` or `Pyenv` as your environment / package manager, avoid dependency conflicts by doing the following first:
1. *Before installing Poetry*, create and activate a new Conda env (e.g. `conda create -n langchain python=3.9`)
2. Install Poetry v1.5.1 (see above)
3. Tell Poetry to use the virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
@@ -174,7 +174,7 @@ Langchain relies heavily on optional dependencies to keep the Langchain package

If you're adding a new dependency to Langchain, assume that it will be an optional dependency, and
that most users won't have it installed.

Users who do not have the dependency installed should be able to **import** your code without
Users that do not have the dependency installed should be able to **import** your code without
any side effects (no warnings, no errors, no exceptions).

To introduce the dependency to the pyproject.toml file correctly, please do the following:
@@ -188,7 +188,7 @@ To introduce the dependency to the pyproject.toml file correctly, please do the
```bash
poetry lock --no-update
```
4. Add a unit test that the very least attempts to import the new code. Ideally, the unit
4. Add a unit test that the very least attempts to import the new code. Ideally the unit
   test makes use of lightweight fixtures to test the logic of the code.
5. Please use the `@pytest.mark.requires(package_name)` decorator for any tests that require the dependency.
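
Items 4 and 5 can be illustrated with a short sketch. Everything named here is hypothetical (the `somepkg` dependency, the module path, and the test file are invented for illustration); the only convention taken from this guide is the `@pytest.mark.requires(package_name)` marker:

```python
# langchain/utilities/somepkg_tool.py (hypothetical module)
from typing import Any


def _import_somepkg() -> Any:
    """Import the optional dependency only when it is actually needed."""
    try:
        import somepkg  # hypothetical optional dependency
    except ImportError as e:
        raise ImportError(
            "Could not import somepkg. Please install it with `pip install somepkg`."
        ) from e
    return somepkg


class SomePkgTool:
    """Importable even when `somepkg` is not installed: no import-time side effects."""

    def run(self, query: str) -> str:
        somepkg = _import_somepkg()  # deferred import keeps module import side-effect free
        return somepkg.search(query)


# tests/unit_tests/utilities/test_somepkg_tool.py (hypothetical test file)
import pytest


def test_import() -> None:
    # At the very least, importing the new code must work without the dependency installed.
    from langchain.utilities.somepkg_tool import SomePkgTool  # noqa: F401


@pytest.mark.requires("somepkg")
def test_run() -> None:
    from langchain.utilities.somepkg_tool import SomePkgTool

    assert isinstance(SomePkgTool().run("hello"), str)
```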
@@ -238,7 +238,7 @@ If you add support for a new external API, please add a new integration test.

### Adding a Jupyter Notebook

If you are adding a Jupyter Notebook example, you'll want to install the optional `dev` dependencies.
If you are adding a Jupyter notebook example, you'll want to install the optional `dev` dependencies.

To install dev dependencies:
@@ -1,5 +1,5 @@
---
name: Documentation Lint
name: Imports

on:
  push:
@@ -18,5 +18,6 @@ jobs:
      - name: Run import check
        run: |
          # We should not encourage imports directly from main init file
          # Expect for hub
          git grep 'from langchain import' docs | grep -vE 'from langchain import (hub)' && exit 1 || exit 0
          # Expect for __version__ and hub
          # And of course expect for this file
          git grep 'from langchain import' | grep -vE 'from langchain import (__version__|hub)' | grep -v '.github/workflows/check-imports.yml' && exit 1 || exit 0
@@ -10,8 +10,5 @@ Any chain constructed this way will automatically have full sync, async, and str
#### [Interface](/docs/expression_language/interface)
The base interface shared by all LCEL objects

#### [How to](/docs/expression_language/how_to)
How to use core features of LCEL

#### [Cookbook](/docs/expression_language/cookbook)
Examples of common LCEL usage patterns
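
As a quick illustration of the "full sync, async, and streaming support" claim in the hunk above, a chain composed with LCEL exposes `invoke`, `stream`, and `ainvoke` directly. A minimal sketch, assuming an OpenAI API key is configured (the prompt and model choice are arbitrary):

```python
import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Compose a chain with the LCEL pipe operator.
chain = (
    ChatPromptTemplate.from_template("tell me a joke about {topic}")
    | ChatOpenAI()
    | StrOutputParser()
)

print(chain.invoke({"topic": "bears"}))  # sync

for chunk in chain.stream({"topic": "bears"}):  # streaming
    print(chunk, end="", flush=True)

print(asyncio.run(chain.ainvoke({"topic": "bears"})))  # async
```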
@@ -5,14 +5,14 @@ sidebar_position: 0
# Introduction

**LangChain** is a framework for developing applications powered by language models. It enables applications that:
- **Are context-aware**: connect a language model to sources of context (prompt instructions, few shot examples, content to ground its response in, etc.)
- **Reason**: rely on a language model to reason (about how to answer based on provided context, what actions to take, etc.)
- **Are context-aware**: connect a language model to other sources of context (prompt instructions, few shot examples, content to ground it's response in)
- **Reason**: rely on a language model to reason (about how to answer based on provided context, what actions to take, etc)

The main value props of LangChain are:
1. **Components**: abstractions for working with language models, along with a collection of implementations for each abstraction. Components are modular and easy-to-use, whether you are using the rest of the LangChain framework or not
2. **Off-the-shelf chains**: a structured assembly of components for accomplishing specific higher-level tasks

Off-the-shelf chains make it easy to get started. For complex applications, components make it easy to customize existing chains and build new ones.
Off-the-shelf chains make it easy to get started. For more complex applications and nuanced use-cases, components make it easy to customize existing chains or build new ones.

## Get started
@@ -1,9 +1,5 @@
{
  "redirects": [
    {
      "source": "/docs/expression_language/cookbook/routing",
      "destination": "/docs/expression_language/how_to/routing"
    },
    {
      "source": "/docs/integrations/providers/amazon_api_gateway",
      "destination": "/docs/integrations/platform/aws"

@@ -453,7 +453,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
@@ -5,7 +5,7 @@
   "id": "4b47436a",
   "metadata": {},
   "source": [
    "# Route between multiple Runnables\n",
    "# Routing\n",
    "\n",
    "This notebook covers how to do routing in the LangChain Expression Language\n",
    "\n",
@@ -224,7 +224,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
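
The renamed notebook above covers routing in LCEL, but its body is not reproduced in this diff. The following is only a rough sketch of the custom-function routing pattern, assuming `RunnableLambda` is available and an OpenAI key is configured; the classification prompt and chain names are invented for illustration:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda

model = ChatOpenAI()

physics_chain = (
    ChatPromptTemplate.from_template("You are a physics professor. Answer: {question}")
    | model
    | StrOutputParser()
)
general_chain = (
    ChatPromptTemplate.from_template("Answer the question: {question}")
    | model
    | StrOutputParser()
)
classifier = (
    ChatPromptTemplate.from_template("Classify this question as `physics` or `other`: {question}")
    | model
    | StrOutputParser()
)


def route(info: dict):
    # Returning a Runnable from the function makes it the next step in the sequence.
    return physics_chain if "physics" in info["topic"].lower() else general_chain


full_chain = {"topic": classifier, "question": lambda x: x["question"]} | RunnableLambda(route)
print(full_chain.invoke({"question": "Why is the sky blue?"}))
```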
@@ -1,194 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "711752cb-4f15-42a3-9838-a0c67f397771",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Bind runtime args\n",
|
||||
"\n",
|
||||
"Sometimes we want to invoke a Runnable within a Runnable sequence with constant arguments that are not part of the output of the preceding Runnable in the sequence, and which are not part of the user input. We can use `Runnable.bind()` to easily pass these arguments in.\n",
|
||||
"\n",
|
||||
"Suppose we have a simple prompt + model sequence:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "f3fdf86d-155f-4587-b7cd-52d363970c1d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"EQUATION: x^3 + 7 = 12\n",
|
||||
"\n",
|
||||
"SOLUTION:\n",
|
||||
"Subtracting 7 from both sides of the equation, we get:\n",
|
||||
"x^3 = 12 - 7\n",
|
||||
"x^3 = 5\n",
|
||||
"\n",
|
||||
"Taking the cube root of both sides, we get:\n",
|
||||
"x = ∛5\n",
|
||||
"\n",
|
||||
"Therefore, the solution to the equation x^3 + 7 = 12 is x = ∛5.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.schema import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"Write out the following equation using algebraic symbols then solve it. Use the format\\n\\nEQUATION:...\\nSOLUTION:...\\n\\n\"),\n",
|
||||
" (\"human\", \"{equation_statement}\")\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"model = ChatOpenAI(temperature=0)\n",
|
||||
"runnable = {\"equation_statement\": RunnablePassthrough()} | prompt | model | StrOutputParser()\n",
|
||||
"\n",
|
||||
"print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "929c9aba-a4a0-462c-adac-2cfc2156e117",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"and want to call the model with certain `stop` words:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "32e0484a-78c5-4570-a00b-20d597245a96",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"EQUATION: x^3 + 7 = 12\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"runnable = (\n",
|
||||
" {\"equation_statement\": RunnablePassthrough()} \n",
|
||||
" | prompt \n",
|
||||
" | model.bind(stop=\"SOLUTION\") \n",
|
||||
" | StrOutputParser()\n",
|
||||
")\n",
|
||||
"print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f4bd641f-6b58-4ca9-a544-f69095428f16",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Attaching OpenAI functions\n",
|
||||
"\n",
|
||||
"One particularly useful application of binding is to attach OpenAI functions to a compatible OpenAI model:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "f66a0fe4-fde0-4706-8863-d60253f211c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"functions = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"solver\",\n",
|
||||
" \"description\": \"Formulates and solves an equation\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"equation\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The algebraic expression of the equation\"\n",
|
||||
" },\n",
|
||||
" \"solution\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The solution to the equation\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"equation\", \"solution\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" ]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "f381f969-df8e-48a3-bf5c-d0397cfecde0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='', additional_kwargs={'function_call': {'name': 'solver', 'arguments': '{\\n\"equation\": \"x^3 + 7 = 12\",\\n\"solution\": \"x = ∛5\"\\n}'}}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Need gpt-4 to solve this one correctly\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"Write out the following equation using algebraic symbols then solve it.\"),\n",
|
||||
" (\"human\", \"{equation_statement}\")\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"model = ChatOpenAI(model=\"gpt-4\", temperature=0).bind(function_call={\"name\": \"solver\"}, functions=functions)\n",
|
||||
"runnable = (\n",
|
||||
" {\"equation_statement\": RunnablePassthrough()} \n",
|
||||
" | prompt \n",
|
||||
" | model\n",
|
||||
")\n",
|
||||
"runnable.invoke(\"x raised to the third plus seven equals 12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2cdeeb4c-0c1f-43da-bd58-4f591d9e0671",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "poetry-venv",
|
||||
"language": "python",
|
||||
"name": "poetry-venv"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
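
The deleted notebook above is self-contained, but the core `.bind()` idea can be restated as a short standalone script. This is a condensed recap of the cells shown above, assuming an OpenAI API key is configured:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Write out the following equation using algebraic symbols then solve it. "
            "Use the format\n\nEQUATION:...\nSOLUTION:...\n\n",
        ),
        ("human", "{equation_statement}"),
    ]
)
model = ChatOpenAI(temperature=0)

# bind() attaches constant call-time kwargs (here `stop`) to the model inside the sequence.
chain = (
    {"equation_statement": RunnablePassthrough()}
    | prompt
    | model.bind(stop="SOLUTION")
    | StrOutputParser()
)
print(chain.invoke("x raised to the third plus seven equals 12"))
```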
@@ -1,285 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "19c9cbd6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Add fallbacks\n",
|
||||
"\n",
|
||||
"There are many possible points of failure in an LLM application, whether that be issues with LLM APIs, poor model outputs, issues with other integrations, etc. Fallbacks help you gracefully handle and isolate these issues.\n",
|
||||
"\n",
|
||||
"Crucially, fallbacks can be applied not only on the LLM level but on the whole runnable level."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a6bb9ba9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Handling LLM API Errors\n",
|
||||
"\n",
|
||||
"This is maybe the most common use case for fallbacks. A request to an LLM API can fail for a variety of reasons - the API could be down, you could have hit rate limits, and so on. Using fallbacks can help protect against these kinds of failures.\n",
|
||||
"\n",
|
||||
"IMPORTANT: By default, a lot of the LLM wrappers catch errors and retry. You will most likely want to turn those off when working with fallbacks. Otherwise the first wrapper will keep on retrying and not failing."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d3e893bf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI, ChatAnthropic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4847c82d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, let's mock out what happens if we hit a RateLimitError from OpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "dfdd8bf5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from unittest.mock import patch\n",
|
||||
"from openai.error import RateLimitError"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "e6fdffc1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Note that we set max_retries = 0 to avoid retrying on RateLimits, etc\n",
|
||||
"openai_llm = ChatOpenAI(max_retries=0)\n",
|
||||
"anthropic_llm = ChatAnthropic()\n",
|
||||
"llm = openai_llm.with_fallbacks([anthropic_llm])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"id": "584461ab",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hit error\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Let's use just the OpenAI LLM first, to show that we run into an error\n",
|
||||
"with patch('openai.ChatCompletion.create', side_effect=RateLimitError()):\n",
|
||||
" try:\n",
|
||||
" print(openai_llm.invoke(\"Why did the chicken cross the road?\"))\n",
|
||||
" except:\n",
|
||||
" print(\"Hit error\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"id": "4fc1e673",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"content=' I don\\'t actually know why the chicken crossed the road, but here are some possible humorous answers:\\n\\n- To get to the other side!\\n\\n- It was too chicken to just stand there. \\n\\n- It wanted a change of scenery.\\n\\n- It wanted to show the possum it could be done.\\n\\n- It was on its way to a poultry farmers\\' convention.\\n\\nThe joke plays on the double meaning of \"the other side\" - literally crossing the road to the other side, or the \"other side\" meaning the afterlife. So it\\'s an anti-joke, with a silly or unexpected pun as the answer.' additional_kwargs={} example=False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now let's try with fallbacks to Anthropic\n",
|
||||
"with patch('openai.ChatCompletion.create', side_effect=RateLimitError()):\n",
|
||||
" try:\n",
|
||||
" print(llm.invoke(\"Why did the chicken cross the road?\"))\n",
|
||||
" except:\n",
|
||||
" print(\"Hit error\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f00bea25",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can use our \"LLM with Fallbacks\" as we would a normal LLM."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "4f8eaaa0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"content=\" I don't actually know why the kangaroo crossed the road, but I'm happy to take a guess! Maybe the kangaroo was trying to get to the other side to find some tasty grass to eat. Or maybe it was trying to get away from a predator or other danger. Kangaroos do need to cross roads and other open areas sometimes as part of their normal activities. Whatever the reason, I'm sure the kangaroo looked both ways before hopping across!\" additional_kwargs={} example=False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You're a nice assistant who always includes a compliment in your response\"),\n",
|
||||
" (\"human\", \"Why did the {animal} cross the road\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"chain = prompt | llm\n",
|
||||
"with patch('openai.ChatCompletion.create', side_effect=RateLimitError()):\n",
|
||||
" try:\n",
|
||||
" print(chain.invoke({\"animal\": \"kangaroo\"}))\n",
|
||||
" except:\n",
|
||||
" print(\"Hit error\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ef9f0f39-0b9f-4723-a394-f61c98c75d41",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Specifying errors to handle\n",
|
||||
"\n",
|
||||
"We can also specify the errors to handle if we want to be more specific about when the fallback is invoked:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e4069ca4-1c16-4915-9a8c-b2732869ae27",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hit error\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = openai_llm.with_fallbacks([anthropic_llm], exceptions_to_handle=(KeyboardInterrupt,))\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"with patch('openai.ChatCompletion.create', side_effect=RateLimitError()):\n",
|
||||
" try:\n",
|
||||
" print(chain.invoke({\"animal\": \"kangaroo\"}))\n",
|
||||
" except:\n",
|
||||
" print(\"Hit error\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8d62241b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fallbacks for Sequences\n",
|
||||
"\n",
|
||||
"We can also create fallbacks for sequences, where the fallbacks are themselves sequences. Here we do that with two different models: ChatOpenAI and then normal OpenAI (which does not use a chat model). Because OpenAI is NOT a chat model, you likely want a different prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "6d0b8056",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# First let's create a chain with a ChatModel\n",
|
||||
"# We add in a string output parser here so the outputs between the two are the same type\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"\n",
|
||||
"chat_prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You're a nice assistant who always includes a compliment in your response\"),\n",
|
||||
" (\"human\", \"Why did the {animal} cross the road\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"# Here we're going to use a bad model name to easily create a chain that will error\n",
|
||||
"chat_model = ChatOpenAI(model_name=\"gpt-fake\")\n",
|
||||
"bad_chain = chat_prompt | chat_model | StrOutputParser()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "8d1fc2a5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Now let's create a chain with the normal OpenAI model\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"prompt_template = \"\"\"Instructions: You should always include a compliment in your response.\n",
|
||||
"\n",
|
||||
"Question: Why did the {animal} cross the road?\"\"\"\n",
|
||||
"prompt = PromptTemplate.from_template(prompt_template)\n",
|
||||
"llm = OpenAI()\n",
|
||||
"good_chain = prompt | llm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "283bfa44",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nAnswer: The turtle crossed the road to get to the other side, and I have to say he had some impressive determination.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# We can now create a final chain which combines the two\n",
|
||||
"chain = bad_chain.with_fallbacks([good_chain])\n",
|
||||
"chain.invoke({\"animal\": \"turtle\"})"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
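
A condensed recap of the fallback pattern from the notebook above, assuming OpenAI and Anthropic API keys are configured:

```python
from langchain.chat_models import ChatAnthropic, ChatOpenAI

# Turn retries off so the fallback is reached instead of the primary wrapper retrying forever.
openai_llm = ChatOpenAI(max_retries=0)
anthropic_llm = ChatAnthropic()

# Fall back to Anthropic on any error ...
llm = openai_llm.with_fallbacks([anthropic_llm])

# ... or only when specific exception types are raised.
picky_llm = openai_llm.with_fallbacks(
    [anthropic_llm], exceptions_to_handle=(KeyboardInterrupt,)
)

print(llm.invoke("Why did the chicken cross the road?"))
```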
@@ -1,199 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b022ab74-794d-4c54-ad47-ff9549ddb9d2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Use RunnableMaps\n",
|
||||
"\n",
|
||||
"RunnableMaps make it easy to execute multiple Runnables in parallel, and to return the output of these Runnables as a map."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "7e1873d6-d4b6-43ac-96a1-edcf178201e0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'joke': AIMessage(content=\"Why don't bears wear shoes? \\nBecause they have bear feet!\", additional_kwargs={}, example=False),\n",
|
||||
" 'poem': AIMessage(content=\"In twilight's embrace, a bear's gentle lumber,\\nSilent strength, nature's awe, a humble slumber.\", additional_kwargs={}, example=False)}"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.schema.runnable import RunnableMap\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI()\n",
|
||||
"joke_chain = ChatPromptTemplate.from_template(\"tell me a joke about {topic}\") | model\n",
|
||||
"poem_chain = ChatPromptTemplate.from_template(\"write a 2-line poem about {topic}\") | model\n",
|
||||
"\n",
|
||||
"map_chain = RunnableMap({\"joke\": joke_chain, \"poem\": poem_chain})\n",
|
||||
"\n",
|
||||
"map_chain.invoke({\"topic\": \"bear\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "df867ae9-1cec-4c9e-9fef-21969b206af5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Manipulating outputs/inputs\n",
|
||||
"Maps can be useful for manipulating the output of one Runnable to match the input format of the next Runnable in a sequence."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "267d1460-53c1-4fdb-b2c3-b6a1eb7fccff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Harrison worked at Kensho.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"\n",
|
||||
"vectorstore = FAISS.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n",
|
||||
"retriever = vectorstore.as_retriever()\n",
|
||||
"template = \"\"\"Answer the question based only on the following context:\n",
|
||||
"{context}\n",
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"\"\"\"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"retrieval_chain = (\n",
|
||||
" {\"context\": retriever, \"question\": RunnablePassthrough()} \n",
|
||||
" | prompt \n",
|
||||
" | model \n",
|
||||
" | StrOutputParser()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"retrieval_chain.invoke(\"where did harrison work?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "392cd4c4-e7ed-4ab8-934d-f7a4eca55ee1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here the input to prompt is expected to be a map with keys \"context\" and \"question\". The user input is just the question. So we need to get the context using our retriever and passthrough the user input under the \"question\" key.\n",
|
||||
"\n",
|
||||
"Note that when composing a RunnableMap with another Runnable we don't even need to wrap our dictionary in the RunnableMap class — the type conversion is handled for us."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "833da249-c0d4-4e5b-b3f8-cab549f0f7e1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Parallelism\n",
|
||||
"\n",
|
||||
"RunnableMaps are also useful for running independent processes in parallel, since each Runnable in the map is executed in parallel. For example, we can see our earlier `joke_chain`, `poem_chain` and `map_chain` all have about the same runtime, even though `map_chain` executes both of the other two."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "38e47834-45af-4281-991f-86f150001510",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"958 ms ± 402 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%timeit\n",
|
||||
"\n",
|
||||
"joke_chain.invoke({\"topic\": \"bear\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "d0cd40de-b37e-41fa-a2f6-8aaa49f368d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.22 s ± 508 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%timeit\n",
|
||||
"\n",
|
||||
"poem_chain.invoke({\"topic\": \"bear\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "799894e1-8e18-4a73-b466-f6aea6af3920",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.15 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%timeit\n",
|
||||
"\n",
|
||||
"map_chain.invoke({\"topic\": \"bear\"})"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
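
A condensed recap of the `RunnableMap` pattern from the notebook above (using the `joke_chain`/`poem_chain` names consistently), assuming an OpenAI API key is configured:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableMap

model = ChatOpenAI()
joke_chain = ChatPromptTemplate.from_template("tell me a joke about {topic}") | model
poem_chain = ChatPromptTemplate.from_template("write a 2-line poem about {topic}") | model

# Both branches run in parallel; the output is a dict keyed like the map.
map_chain = RunnableMap({"joke": joke_chain, "poem": poem_chain})
print(map_chain.invoke({"topic": "bear"}))
```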
@@ -1,165 +0,0 @@
|
||||
# Anthropic
|
||||
|
||||
All functionality related to Anthropic models.
|
||||
|
||||
[Anthropic](https://www.anthropic.com/) is an AI safety and research company, and is the creator of Claude.
|
||||
This page covers all integrations between Anthropic models and LangChain.
|
||||
|
||||
## Prompting Overview
|
||||
|
||||
Claude is a chat-based model, meaning it is trained on conversation data.
|
||||
However, it is a text-based API, meaning it takes in a single string.
|
||||
It expects this string to be in a particular format.
|
||||
This means that it is up to the user to ensure that this is the case.
|
||||
LangChain provides several utilities and helper functions to make sure prompts that you write -
|
||||
whether formatted as a string or as a list of messages - end up formatted correctly.
|
||||
|
||||
Specifically, Claude is trained to fill in text for the Assistant role as part of an ongoing dialogue
|
||||
between a human user (`Human:`) and an AI assistant (`Assistant:`). Prompts sent via the API must contain
|
||||
`\n\nHuman:` and `\n\nAssistant:` as the signals of who's speaking.
|
||||
The final turn must always be `\n\nAssistant:` - the input string cannot have `\n\nHuman:` as the final role.
|
||||
|
||||
Because Claude is chat-based but accepts a string as input, it can be treated as either a LangChain `ChatModel` or `LLM`.
|
||||
This means there are two wrappers in LangChain - `ChatAnthropic` and `Anthropic`.
|
||||
It is generally recommended to use the `ChatAnthropic` wrapper, and format your prompts as `ChatMessage`s (we will show examples of this below).
|
||||
This is because it keeps your prompt in a general format that you can easily then also use with other models (should you want to).
|
||||
However, if you want more fine-grained control over the prompt, you can use the `Anthropic` wrapper - we will show an example of this as well.
|
||||
The `Anthropic` wrapper however is deprecated, as all functionality can be achieved in a more generic way using `ChatAnthropic`.
|
||||
|
||||
## Prompting Best Practices
|
||||
|
||||
Anthropic models have several prompting best practices that differ from those for OpenAI models.
|
||||
|
||||
**No System Messages**
|
||||
|
||||
Anthropic models are not trained on the concept of a "system message".
|
||||
We have worked with the Anthropic team to handle them somewhat appropriately (a Human message with an `admin` tag)
|
||||
but this is largely a hack and it is recommended that you do not use system messages.
|
||||
|
||||
**AI Messages Can Continue**
|
||||
|
||||
A completion from Claude is a continuation of the last text in the string, which allows you further control over Claude's output.
|
||||
For example, putting words in Claude's mouth in a prompt like this:
|
||||
|
||||
`\n\nHuman: Tell me a joke about bears\n\nAssistant: What do you call a bear with no teeth?`
|
||||
|
||||
This will return a completion like this `A gummy bear!` instead of a whole new assistant message with a different random bear joke.
|
||||
|
||||
|
||||
## `ChatAnthropic`
|
||||
|
||||
`ChatAnthropic` is a subclass of LangChain's `ChatModel`, meaning it works best with `ChatPromptTemplate`.
|
||||
You can import this wrapper with the following code:
|
||||
|
||||
```
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
model = ChatAnthropic()
|
||||
```
|
||||
|
||||
When working with ChatModels, it is preferred that you design your prompts as `ChatPromptTemplate`s.
|
||||
Here is an example below of doing that:
|
||||
|
||||
```
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You are a helpful chatbot"),
|
||||
("human", "Tell me a joke about {topic}"),
|
||||
])
|
||||
```
|
||||
|
||||
You can then use this in a chain as follows:
|
||||
|
||||
```
|
||||
chain = prompt | model
|
||||
chain.invoke({"topic": "bears"})
|
||||
```
|
||||
|
||||
How is the prompt actually being formatted under the hood? We can see that by running the following code
|
||||
|
||||
```
|
||||
prompt_value = prompt.format_prompt(topic="bears")
|
||||
model.convert_prompt(prompt_value)
|
||||
```
|
||||
|
||||
This produces the following formatted string:
|
||||
|
||||
```
|
||||
'\n\nHuman: <admin>You are a helpful chatbot</admin>\n\nHuman: Tell me a joke about bears\n\nAssistant:'
|
||||
```
|
||||
|
||||
We can see that under the hood LangChain is representing `SystemMessage`s with `Human: <admin>...</admin>`,
|
||||
and is appending an assistant message to the end IF the last message is NOT already an assistant message.
|
||||
|
||||
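
Because an assistant turn is only appended when the last message is not already an assistant message, the "putting words in Claude's mouth" trick from earlier can also be expressed through `ChatAnthropic` by ending the message list with an `AIMessage`. This is an inference from the behavior described above rather than an example taken from this page; a rough sketch:

```
from langchain.chat_models import ChatAnthropic
from langchain.schema import AIMessage, HumanMessage

model = ChatAnthropic()
messages = [
    HumanMessage(content="Tell me a joke about bears"),
    # Ending on an AI message means no fresh `Assistant:` turn is appended,
    # so Claude continues this text rather than starting a new answer.
    AIMessage(content="What do you call a bear with no teeth?"),
]
print(model.invoke(messages))
```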
If you decide instead to use a normal PromptTemplate (one that just works on a single string) let's take a look at
|
||||
what happens:
|
||||
|
||||
```
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
|
||||
prompt_value = prompt.format_prompt(topic="bears")
|
||||
model.convert_prompt(prompt_value)
|
||||
```
|
||||
|
||||
This produces the following formatted string:
|
||||
|
||||
```
|
||||
'\n\nHuman: Tell me a joke about bears\n\nAssistant:'
|
||||
```
|
||||
|
||||
We can see that it automatically adds the Human and Assistant tags.
|
||||
What is happening under the hood?
|
||||
First: the string gets converted to a single human message. This happens generically (because we are using a subclass of `ChatModel`).
|
||||
Then, similarly to the above example, an empty Assistant message is getting appended.
|
||||
This is Anthropic specific.
|
||||
|
||||
## [Deprecated] `Anthropic`
|
||||
|
||||
This `Anthropic` wrapper is subclassed from `LLM`.
|
||||
We can import it with:
|
||||
|
||||
```
|
||||
from langchain.llms import Anthropic
|
||||
model = Anthropic()
|
||||
```
|
||||
|
||||
This model class is designed to work with normal PromptTemplates. An example of that is below:
|
||||
|
||||
```
|
||||
prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
|
||||
chain = prompt | model
|
||||
chain.invoke({"topic": "bears"})
|
||||
```
|
||||
|
||||
Let's see what is going on with the prompt templating under the hood!
|
||||
|
||||
```
|
||||
prompt_value = prompt.format_prompt(topic="bears")
|
||||
model.convert_prompt(prompt_value)
|
||||
```
|
||||
|
||||
This outputs the following
|
||||
|
||||
```
|
||||
'\n\nHuman: Tell me a joke about bears\n\nAssistant: Sure, here you go:\n'
|
||||
```
|
||||
|
||||
Notice that it adds the Human tag at the start of the string, and then finishes it with `\n\nAssistant: Sure, here you go:`.
|
||||
The extra `Sure, here you go` was added on purpose by the Anthropic team.
|
||||
|
||||
What happens if we have those symbols in the prompt directly?
|
||||
|
||||
```
|
||||
prompt = PromptTemplate.from_template("Human: Tell me a joke about {topic}")
|
||||
prompt_value = prompt.format_prompt(topic="bears")
|
||||
model.convert_prompt(prompt_value)
|
||||
```
|
||||
|
||||
This outputs:
|
||||
|
||||
```
|
||||
'\n\nHuman: Tell me a joke about bears'
|
||||
```
|
||||
|
||||
We can see that we detect that the user is trying to use the special tokens, and so we don't do any formatting.
|
||||
@@ -93,10 +93,10 @@ llm(
|
||||
### Usage
|
||||
|
||||
For more information and detailed examples, refer to the
|
||||
[example for xinference LLMs](/docs/integrations/llms/xinference.html)
|
||||
[example notebook for xinference](../modules/models/llms/integrations/xinference.ipynb)
|
||||
|
||||
### Embeddings
|
||||
|
||||
Xinference also supports embedding queries and documents. See
|
||||
[example for xinference embeddings](/docs/integrations/text_embedding/xinference.html)
|
||||
[example notebook for xinference embeddings](../modules/data_connection/text_embedding/integrations/xinference.ipynb)
|
||||
for a more detailed demo.
|
||||
@@ -6,9 +6,9 @@ from typing import Any, Dict, List, Optional, Union
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, BasePromptTemplate
|
||||
from langchain.schema.base import Embeddings
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain.utilities.sql_database import SQLDatabase
|
||||
|
||||
@@ -77,7 +77,6 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain -
|
||||
|
||||
lint lint_diff:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/check_imports.sh
|
||||
poetry run ruff .
|
||||
poetry run black $(PYTHON_FILES) --check
|
||||
poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
@@ -1,57 +1,7 @@
|
||||
# ruff: noqa: E402
|
||||
"""Main entrypoint into package."""
|
||||
from importlib import metadata
|
||||
from typing import Optional
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMBashChain,
|
||||
LLMChain,
|
||||
LLMCheckerChain,
|
||||
LLMMathChain,
|
||||
QAWithSourcesChain,
|
||||
VectorDBQA,
|
||||
VectorDBQAWithSourcesChain,
|
||||
)
|
||||
from langchain.docstore import InMemoryDocstore, Wikipedia
|
||||
from langchain.llms import (
|
||||
Anthropic,
|
||||
Banana,
|
||||
CerebriumAI,
|
||||
Cohere,
|
||||
ForefrontAI,
|
||||
GooseAI,
|
||||
HuggingFaceHub,
|
||||
HuggingFaceTextGenInference,
|
||||
LlamaCpp,
|
||||
Modal,
|
||||
OpenAI,
|
||||
Petals,
|
||||
PipelineAI,
|
||||
SagemakerEndpoint,
|
||||
StochasticAI,
|
||||
Writer,
|
||||
)
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.prompts import (
|
||||
FewShotPromptTemplate,
|
||||
Prompt,
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain.schema.cache import BaseCache
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
|
||||
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||
from langchain.utilities.powerbi import PowerBIDataset
|
||||
from langchain.utilities.searx_search import SearxSearchWrapper
|
||||
from langchain.utilities.serpapi import SerpAPIWrapper
|
||||
from langchain.utilities.sql_database import SQLDatabase
|
||||
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
|
||||
from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
@@ -62,58 +12,4 @@ del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
verbose: bool = False
|
||||
debug: bool = False
|
||||
llm_cache: Optional[BaseCache] = None
|
||||
|
||||
# For backwards compatibility
|
||||
SerpAPIChain = SerpAPIWrapper
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
"LLMBashChain",
|
||||
"LLMCheckerChain",
|
||||
"LLMMathChain",
|
||||
"ArxivAPIWrapper",
|
||||
"GoldenQueryAPIWrapper",
|
||||
"SelfAskWithSearchChain",
|
||||
"SerpAPIWrapper",
|
||||
"SerpAPIChain",
|
||||
"SearxSearchWrapper",
|
||||
"GoogleSearchAPIWrapper",
|
||||
"GoogleSerperAPIWrapper",
|
||||
"WolframAlphaAPIWrapper",
|
||||
"WikipediaAPIWrapper",
|
||||
"Anthropic",
|
||||
"Banana",
|
||||
"CerebriumAI",
|
||||
"Cohere",
|
||||
"ForefrontAI",
|
||||
"GooseAI",
|
||||
"Modal",
|
||||
"OpenAI",
|
||||
"Petals",
|
||||
"PipelineAI",
|
||||
"StochasticAI",
|
||||
"Writer",
|
||||
"BasePromptTemplate",
|
||||
"Prompt",
|
||||
"FewShotPromptTemplate",
|
||||
"PromptTemplate",
|
||||
"ReActChain",
|
||||
"Wikipedia",
|
||||
"HuggingFaceHub",
|
||||
"SagemakerEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
"SQLDatabase",
|
||||
"PowerBIDataset",
|
||||
"FAISS",
|
||||
"MRKLChain",
|
||||
"VectorDBQA",
|
||||
"ElasticVectorSearch",
|
||||
"InMemoryDocstore",
|
||||
"ConversationChain",
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"QAWithSourcesChain",
|
||||
"LlamaCpp",
|
||||
"HuggingFaceTextGenInference",
|
||||
]
|
||||
llm_cache: Optional[Any] = None
|
||||
|
||||
@@ -282,7 +282,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
|
||||
return self._call_next()
|
||||
except StopIteration:
|
||||
raise
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.run_manager:
|
||||
self.run_manager.on_chain_error(e)
|
||||
raise
|
||||
@@ -304,7 +304,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
|
||||
await self.timeout_manager.__aexit__(None, None, None)
|
||||
self.timeout_manager = None
|
||||
return await self._astop()
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.run_manager:
|
||||
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||
await self.run_manager.on_chain_error(e)
|
||||
|
||||
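
The hunk above swaps between `except BaseException` and `except (KeyboardInterrupt, Exception)`. The practical difference is which signals reach the error callback: `Exception` alone does not cover `KeyboardInterrupt` or `SystemExit`, while `BaseException` covers everything. A small self-contained illustration (plain Python, not LangChain code):

```python
def catch_with(handler_exc):
    """Return whether `handler_exc` catches a KeyboardInterrupt."""
    try:
        raise KeyboardInterrupt()
    except handler_exc:
        return "caught"
    except BaseException:
        return "missed"


print(catch_with(Exception))                       # "missed": KeyboardInterrupt is not an Exception
print(catch_with((KeyboardInterrupt, Exception)))  # "caught"
print(catch_with(BaseException))                   # "caught"
```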
@@ -1,6 +0,0 @@
|
||||
def format_react_single_input(intermediate_steps, observation_prefix:str = "Observation", llm_prefix: str = "Thought"):
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
||||
return thoughts
|
||||
@@ -1,2 +0,0 @@
|
||||
def format_tools_with_description(tools):
|
||||
return "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
@@ -1,13 +1,72 @@
|
||||
from langchain.agents.output_parsers.react_single_input import ReActSingleInputOutputParser
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
||||
"Invalid Format: Missing 'Action:' after 'Thought:"
|
||||
)
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
|
||||
"Invalid Format: Missing 'Action Input:' after 'Action:'"
|
||||
)
|
||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
||||
"Parsing LLM output produced both a final answer and a parse-able action:"
|
||||
)
|
||||
|
||||
|
||||
class MRKLOutputParser(ReActSingleInputOutputParser):
|
||||
class MRKLOutputParser(AgentOutputParser):
|
||||
"""MRKL Output parser for the chat agent."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return FORMAT_INSTRUCTIONS
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
if action_match:
|
||||
if includes_answer:
|
||||
raise OutputParserException(
|
||||
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
|
||||
)
|
||||
action = action_match.group(1).strip()
|
||||
action_input = action_match.group(2)
|
||||
tool_input = action_input.strip(" ")
|
||||
# ensure if its a well formed SQL query we don't remove any trailing " chars
|
||||
if tool_input.startswith("SELECT ") is False:
|
||||
tool_input = tool_input.strip('"')
|
||||
|
||||
return AgentAction(action, tool_input, text)
|
||||
|
||||
elif includes_answer:
|
||||
return AgentFinish(
|
||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||
)
|
||||
|
||||
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
|
||||
raise OutputParserException(
|
||||
f"Could not parse LLM output: `{text}`",
|
||||
observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
elif not re.search(
|
||||
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
|
||||
):
|
||||
raise OutputParserException(
|
||||
f"Could not parse LLM output: `{text}`",
|
||||
observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse LLM output: `{text}`")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "mrkl"
|
||||
|
||||
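
To make the regexes above concrete, here is a rough usage sketch of the parser. The import path and the exact `AgentAction`/`AgentFinish` field layout are assumptions based on the code shown in this diff:

```python
from langchain.agents.mrkl.output_parser import MRKLOutputParser

parser = MRKLOutputParser()

# An action step: the regex extracts the tool name and the tool input.
step = parser.parse(
    "Thought: I should look this up\n"
    "Action: Search\n"
    'Action Input: "LangChain Expression Language"'
)
# -> AgentAction("Search", "LangChain Expression Language", <full text>)

# A final step: everything after "Final Answer:" becomes the output.
done = parser.parse("Thought: I know the answer\nFinal Answer: 42")
# -> AgentFinish({"output": "42"}, <full text>)
```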
@@ -1,72 +0,0 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
||||
"Invalid Format: Missing 'Action:' after 'Thought:"
|
||||
)
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
|
||||
"Invalid Format: Missing 'Action Input:' after 'Action:'"
|
||||
)
|
||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
||||
"Parsing LLM output produced both a final answer and a parse-able action:"
|
||||
)
|
||||
|
||||
|
||||
class ReActSingleInputOutputParser(AgentOutputParser):
|
||||
"""Parser for ReAct format ."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
if action_match:
|
||||
if includes_answer:
|
||||
raise OutputParserException(
|
||||
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
|
||||
)
|
||||
action = action_match.group(1).strip()
|
||||
action_input = action_match.group(2)
|
||||
tool_input = action_input.strip(" ")
|
||||
# ensure if it's a well-formed SQL query we don't remove any trailing " chars
|
||||
if tool_input.startswith("SELECT ") is False:
|
||||
tool_input = tool_input.strip('"')
|
||||
|
||||
return AgentAction(action, tool_input, text)
|
||||
|
||||
elif includes_answer:
|
||||
return AgentFinish(
|
||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||
)
|
||||
|
||||
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
|
||||
raise OutputParserException(
|
||||
f"Could not parse LLM output: `{text}`",
|
||||
observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
elif not re.search(
|
||||
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
|
||||
):
|
||||
raise OutputParserException(
|
||||
f"Could not parse LLM output: `{text}`",
|
||||
observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse LLM output: `{text}`")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "react-single-input"
|
||||
|
||||
@@ -51,12 +51,12 @@ except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.base import LLM, get_prompts
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.schema import ChatGeneration, Generation
|
||||
from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_env
|
||||
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
@@ -255,7 +255,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -294,7 +296,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -325,7 +329,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from packaging.version import parse
|
||||
|
||||
@@ -236,7 +236,9 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -311,7 +313,9 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -338,7 +342,9 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import import_pandas
|
||||
@@ -163,7 +163,9 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
else:
|
||||
print(f'❌ Logging failed "{response_from_arize.text}"')
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
@@ -176,7 +178,9 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
@@ -201,7 +205,9 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
|
||||
@@ -6,7 +6,7 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -257,13 +257,17 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""On chain end, do nothing."""
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""On new token, pass."""
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
|
||||
def on_tool_start(
|
||||
@@ -286,7 +290,9 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
|
||||
@@ -18,7 +18,7 @@ class RetrieverManagerMixin:
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -69,7 +69,7 @@ class LLMManagerMixin:
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -93,7 +93,7 @@ class ChainManagerMixin:
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -137,7 +137,7 @@ class ToolManagerMixin:
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -344,7 +344,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -379,7 +379,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -414,7 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -492,7 +492,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
@@ -155,7 +155,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(generation_resp)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -208,7 +210,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -246,7 +250,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
@@ -223,7 +223,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self._log_text_metrics(output_complexity_metrics, step=self.step)
|
||||
self._log_text_metrics(output_custom_metrics, step=self.step)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -278,7 +280,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
f"Output Value for {chain_output_key} will not be logged"
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -316,7 +320,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
resp.update({"output": output})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
@@ -128,7 +128,9 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
||||
callbacks."""
|
||||
)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -142,7 +144,9 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when chain ends."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -169,7 +173,9 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
@@ -221,7 +221,9 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
)
|
||||
self.deck.append(self.markdown_renderer().to_html(generation.text))
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -264,7 +266,9 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -302,7 +306,9 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
@@ -113,7 +113,9 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
||||
for generation in generations:
|
||||
self._send_to_infino("prompt_response", generation.text, is_ts=False)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Set the error flag."""
|
||||
self.error = 1
|
||||
|
||||
@@ -127,7 +129,9 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when LLM chain ends."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Need to log the error."""
|
||||
pass
|
||||
|
||||
@@ -154,7 +158,9 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -334,7 +334,9 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
# Pop current run from `self.runs`
|
||||
self.payload.pop(run_id)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -346,7 +348,9 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
@@ -373,7 +377,9 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -406,7 +406,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
@@ -423,7 +423,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
@@ -440,7 +440,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
|
||||
@@ -213,20 +213,17 @@ def trace_as_chain_group(
    group_name: str,
    callback_manager: Optional[CallbackManager] = None,
    *,
    inputs: Optional[Dict[str, Any]] = None,
    project_name: Optional[str] = None,
    example_id: Optional[Union[str, UUID]] = None,
    run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
) -> Generator[CallbackManagerForChainGroup, None, None]:
) -> Generator[CallbackManager, None, None]:
    """Get a callback manager for a chain group in a context manager.
    Useful for grouping different calls together as a single run even if
    they aren't composed in a single chain.

    Args:
        group_name (str): The name of the chain group.
        callback_manager (CallbackManager, optional): The callback manager to use.
        inputs (Dict[str, Any], optional): The inputs to the chain group.
        project_name (str, optional): The name of the project.
            Defaults to None.
        example_id (str or UUID, optional): The ID of the example.
@@ -236,17 +233,13 @@ def trace_as_chain_group(
            Defaults to None.

    Returns:
        CallbackManagerForChainGroup: The callback manager for the chain group.
        CallbackManager: The callback manager for the chain group.

    Example:
        .. code-block:: python

            llm_input = "Foo"
            with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
                # Use the callback manager for the chain group
                res = llm.predict(llm_input, callbacks=manager)
                manager.on_chain_end({"output": res})
    """ # noqa: E501
        >>> with trace_as_chain_group("group_name") as manager:
        ...     # Use the callback manager for the chain group
        ...     llm.predict("Foo", callbacks=manager)
    """
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
@@ -263,27 +256,9 @@ def trace_as_chain_group(
|
||||
inheritable_tags=tags,
|
||||
)
|
||||
|
||||
run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
|
||||
child_cm = run_manager.get_child()
|
||||
group_cm = CallbackManagerForChainGroup(
|
||||
child_cm.handlers,
|
||||
child_cm.inheritable_handlers,
|
||||
child_cm.parent_run_id,
|
||||
parent_run_manager=run_manager,
|
||||
tags=child_cm.tags,
|
||||
inheritable_tags=child_cm.inheritable_tags,
|
||||
metadata=child_cm.metadata,
|
||||
inheritable_metadata=child_cm.inheritable_metadata,
|
||||
)
|
||||
try:
|
||||
yield group_cm
|
||||
except Exception as e:
|
||||
if not group_cm.ended:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
if not group_cm.ended:
|
||||
run_manager.on_chain_end({})
|
||||
run_manager = cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
|
||||
yield run_manager.get_child()
|
||||
run_manager.on_chain_end({})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -291,20 +266,17 @@ async def atrace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[AsyncCallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
|
||||
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
||||
"""Get an async callback manager for a chain group in a context manager.
|
||||
Useful for grouping different async calls together as a single run even if
|
||||
they aren't composed in a single chain.
|
||||
|
||||
Args:
|
||||
group_name (str): The name of the chain group.
|
||||
callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
|
||||
which manages tracing and other callback behavior.
|
||||
project_name (str, optional): The name of the project.
|
||||
Defaults to None.
|
||||
example_id (str or UUID, optional): The ID of the example.
|
||||
@@ -316,14 +288,10 @@ async def atrace_as_chain_group(
|
||||
AsyncCallbackManager: The async callback manager for the chain group.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
llm_input = "Foo"
|
||||
async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
|
||||
# Use the async callback manager for the chain group
|
||||
res = await llm.apredict(llm_input, callbacks=manager)
|
||||
await manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
>>> async with atrace_as_chain_group("group_name") as manager:
|
||||
... # Use the async callback manager for the chain group
|
||||
... await llm.apredict("Foo", callbacks=manager)
|
||||
"""
|
||||
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
@@ -337,29 +305,11 @@ async def atrace_as_chain_group(
|
||||
)
|
||||
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
|
||||
|
||||
run_manager = await cm.on_chain_start(
|
||||
{"name": group_name}, inputs or {}, run_id=run_id
|
||||
)
|
||||
child_cm = run_manager.get_child()
|
||||
group_cm = AsyncCallbackManagerForChainGroup(
|
||||
child_cm.handlers,
|
||||
child_cm.inheritable_handlers,
|
||||
child_cm.parent_run_id,
|
||||
parent_run_manager=run_manager,
|
||||
tags=child_cm.tags,
|
||||
inheritable_tags=child_cm.inheritable_tags,
|
||||
metadata=child_cm.metadata,
|
||||
inheritable_metadata=child_cm.inheritable_metadata,
|
||||
)
|
||||
run_manager = await cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
|
||||
try:
|
||||
yield group_cm
|
||||
except Exception as e:
|
||||
if not group_cm.ended:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
if not group_cm.ended:
|
||||
await run_manager.on_chain_end({})
|
||||
yield run_manager.get_child()
|
||||
finally:
|
||||
await run_manager.on_chain_end({})
|
||||
|
||||
|
||||
def _handle_event(
|
||||
@@ -707,7 +657,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors.
|
||||
@@ -773,7 +723,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors.
|
||||
@@ -985,7 +935,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors.
|
||||
@@ -1027,7 +977,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors.
|
||||
@@ -1069,7 +1019,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
@@ -1108,7 +1058,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
@@ -1392,48 +1342,6 @@ class CallbackManager(BaseCallbackManager):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainGroup(CallbackManager):
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler] | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
*,
|
||||
parent_run_manager: CallbackManagerForChainRun,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
handlers,
|
||||
inheritable_handlers,
|
||||
parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that handles callbacks from LangChain."""
|
||||
|
||||
@@ -1726,50 +1634,6 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler] | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
*,
|
||||
parent_run_manager: AsyncCallbackManagerForChainRun,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
handlers,
|
||||
inheritable_handlers,
|
||||
parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||
"""
|
||||
self.ended = True
|
||||
await self.parent_run_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
self.ended = True
|
||||
await self.parent_run_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||
|
||||
|
||||
|
||||
@@ -384,7 +384,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
|
||||
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
@@ -432,7 +434,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
@@ -476,7 +480,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
@@ -121,7 +121,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
f"llm_end_{llm_ends}_generation_{idx}",
|
||||
)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
@@ -162,7 +164,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
@@ -198,7 +202,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
@@ -27,7 +27,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
@@ -42,7 +44,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Print out that we finished a chain."""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
@@ -76,7 +80,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
if llm_prefix is not None:
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -37,7 +37,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.done.set()
|
||||
|
||||
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
self.done.set()
|
||||
|
||||
# TODO implement the other methods
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Callback Handler streams to stdout on new llm token."""
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
@@ -31,7 +31,9 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
def on_chain_start(
|
||||
@@ -42,7 +44,9 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_tool_start(
|
||||
@@ -57,7 +61,9 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
|
||||
@@ -163,7 +163,9 @@ class LLMThought:
|
||||
# data is redundant
|
||||
self._reset_llm_token_stream()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
self._container.markdown("**LLM encountered an error...**")
|
||||
self._container.exception(error)
|
||||
|
||||
@@ -189,7 +191,9 @@ class LLMThought:
|
||||
) -> None:
|
||||
self._container.markdown(f"**{output}**")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
self._container.markdown("**Tool encountered an error...**")
|
||||
self._container.exception(error)
|
||||
|
||||
@@ -349,7 +353,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
self._require_current_thought().on_llm_end(response, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
self._require_current_thought().on_llm_error(error, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
@@ -372,7 +378,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
)
|
||||
self._complete_current_thought()
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
self._require_current_thought().on_tool_error(error, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
@@ -393,7 +401,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
|
||||
@@ -211,7 +211,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
@@ -294,7 +294,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
run_id: UUID,
|
||||
@@ -365,7 +365,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
@@ -420,7 +420,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -282,7 +282,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.run.log(generation_resp)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -335,7 +337,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
@@ -373,7 +377,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
@@ -287,7 +287,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
                if new_arg_supported
                else self._call(inputs)
            )
        except BaseException as e:
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise e
        run_manager.on_chain_end(outputs)
@@ -356,7 +356,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
                if new_arg_supported
                else await self._acall(inputs)
            )
        except BaseException as e:
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise e
        await run_manager.on_chain_end(outputs)

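The Chain hunks above change only the except clause, but they make the surrounding
control flow easy to miss. The pattern is: run the wrapped call, forward any raised
error to the run manager's on_chain_error and re-raise it, otherwise forward the
outputs to on_chain_end. A minimal standalone sketch of that flow (the run_step
function and manager object here are illustrative, not part of the diff):

    from typing import Any, Callable, Dict


    def run_step(
        call: Callable[[], Dict[str, Any]], manager: Any
    ) -> Dict[str, Any]:
        """Hypothetical helper mirroring Chain.__call__'s error handling."""
        try:
            outputs = call()
        except (KeyboardInterrupt, Exception) as e:
            # Surface the failure to callbacks, then propagate it unchanged.
            manager.on_chain_error(e)
            raise e
        # Only successful runs reach on_chain_end.
        manager.on_chain_end(outputs)
        return outputs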
@@ -12,8 +12,8 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ class LLMChain(Chain):
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
@@ -206,7 +206,7 @@ class LLMChain(Chain):
|
||||
)
|
||||
try:
|
||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from langchain.schema.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain.schema.prompt import PromptValue
|
||||
|
||||
|
||||
def _convert_one_message_to_text(
|
||||
@@ -113,9 +112,6 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
prompt_params["ai_prompt"] = self.AI_PROMPT
|
||||
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
|
||||
|
||||
def convert_prompt(self, prompt: PromptValue) -> str:
|
||||
return self._convert_messages_to_prompt(prompt.to_messages())
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
||||
@@ -186,7 +186,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
else:
|
||||
@@ -233,7 +233,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
else:
|
||||
@@ -303,7 +303,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except BaseException as e:
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if run_managers:
|
||||
run_managers[i].on_llm_error(e)
|
||||
raise e
|
||||
@@ -364,7 +364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
)
|
||||
exceptions = []
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, BaseException):
|
||||
if isinstance(res, Exception):
|
||||
if run_managers:
|
||||
await run_managers[i].on_llm_error(res)
|
||||
exceptions.append(res)
|
||||
|
||||
@@ -3,9 +3,9 @@ from typing import Any, Callable, List, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.schema import BaseDocumentTransformer, Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class AwaEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,8 +4,8 @@ import os
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -14,8 +14,8 @@ import uuid
|
||||
from functools import partial
|
||||
from typing import Callable, List, Sequence, Union, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseStore
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
|
||||
NAMESPACE_UUID = uuid.UUID(int=1985)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.requests import Requests
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.client import MlClient
|
||||
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class ElasticsearchEmbeddings(Embeddings):
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
# Currently supported maximum batch size for embedding requests
|
||||
|
||||
@@ -6,8 +6,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings, BaseModel):
|
||||
|
||||
@@ -11,8 +11,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class GPT4AllEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -2,8 +2,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class LlamaCppEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -24,8 +24,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,8 +11,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class ModelScopeEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
|
||||
|
||||
@@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -25,8 +25,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms import SelfHostedPipeline
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[float]]:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import importlib.util
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
class SpacyEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.vertexai import _VertexAICommon
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utilities.vertexai import raise_vertex_import_error
|
||||
|
||||
|
||||
|
||||
@@ -1,54 +1,44 @@
"""Wrapper around Xinference embedding models."""
from typing import Any, List, Optional

from langchain.schema.embeddings import Embeddings
from langchain.embeddings.base import Embeddings


class XinferenceEmbeddings(Embeddings):

"""Wrapper around xinference embedding models.
|
||||
To use, you should have the xinference library installed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install xinference
|
||||
|
||||
Check out: https://github.com/xorbitsai/inference
|
||||
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers.
|
||||
|
||||
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers
|
||||
Example:
|
||||
To start a local instance of Xinference, run
|
||||
.. code-block:: bash
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference
|
||||
|
||||
$ xinference
|
||||
You can also deploy Xinference in a distributed cluster. Here are the steps:
|
||||
|
||||
Starting the supervisor:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-supervisor
|
||||
|
||||
$ xinference-supervisor
|
||||
Starting the worker:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference-worker
|
||||
$ xinference-worker
|
||||
|
||||
Then, launch a model using command line interface (CLI).
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
$ xinference launch -n orca -s 3 -q q4_0
|
||||
|
||||
It will return a model UID. Then you can use Xinference Embedding with LangChain.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
@@ -10,11 +10,11 @@ from langchain.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.schema import RUN_KEY
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@
from __future__ import annotations

import re
import string
from typing import Any, List, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence

from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain

@@ -15,32 +14,13 @@ from langchain.schema import RUN_KEY
from langchain.schema.language_model import BaseLanguageModel


def _get_score(text: str) -> Optional[Tuple[str, int]]:
match = re.search(r"grade:\s*(correct|incorrect)", text.strip(), re.IGNORECASE)
def _get_score(verdict: str) -> Optional[int]:
match = re.search(r"(?i)(?:grade:\s*)?(correct|incorrect)", verdict)
if match:
if match.group(1).upper() == "CORRECT":
return "CORRECT", 1
return 1
elif match.group(1).upper() == "INCORRECT":
return "INCORRECT", 0
try:
first_word = (
text.strip().split()[0].translate(str.maketrans("", "", string.punctuation))
)
if first_word.upper() == "CORRECT":
return "CORRECT", 1
elif first_word.upper() == "INCORRECT":
return "INCORRECT", 0
last_word = (
text.strip()
.split()[-1]
.translate(str.maketrans("", "", string.punctuation))
)
if last_word.upper() == "CORRECT":
return "CORRECT", 1
elif last_word.upper() == "INCORRECT":
return "INCORRECT", 0
except IndexError:
pass
return 0
return None
@@ -53,15 +33,17 @@ def _parse_string_eval_output(text: str) -> dict:
Returns:
Any: The parsed output.
"""
reasoning = text.strip()
parsed_scores = _get_score(reasoning)
if parsed_scores is None:
value, score = None, None
splits = text.strip().rsplit("\n", maxsplit=1)
if len(splits) == 1:
verdict = splits[0]
reasoning = None
else:
value, score = parsed_scores
reasoning, verdict = splits
reasoning = reasoning.strip()
score = _get_score(verdict)
return {
"reasoning": reasoning,
"value": value,
"value": verdict,
"score": score,
}
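Taken together, the two hunks above appear to swap a verbose grader parser for a compact one: `_get_score` shrinks from 32 to 13 lines per its @@ header, keeping only the single regex match, and `_parse_string_eval_output` splits off the last line as the verdict instead of scanning the whole text and falling back to the first and last words. A standalone sketch of the compact variant's behaviour (re-implemented for illustration, not copied verbatim from either side of the diff):

.. code-block:: python

    import re
    from typing import Optional


    def get_score(verdict: str) -> Optional[int]:
        # The "grade:" prefix is optional and matching is case-insensitive.
        match = re.search(r"(?i)(?:grade:\s*)?(correct|incorrect)", verdict)
        if match:
            return 1 if match.group(1).upper() == "CORRECT" else 0
        return None


    def parse_string_eval_output(text: str) -> dict:
        # The last line is treated as the verdict; everything before it is the reasoning.
        splits = text.strip().rsplit("\n", maxsplit=1)
        if len(splits) == 1:
            reasoning, verdict = None, splits[0]
        else:
            reasoning, verdict = splits
            reasoning = reasoning.strip()
        return {"reasoning": reasoning, "value": verdict, "score": get_score(verdict)}


    print(parse_string_eval_output("The answer matches the reference.\nGRADE: CORRECT"))
    # {'reasoning': 'The answer matches the reference.', 'value': 'GRADE: CORRECT', 'score': 1}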
@@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, Type
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.pydantic_v1 import BaseModel, Extra, Field
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.language_model import BaseLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.vectorstores.base import VectorStore
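These imports appear to come from the vectorstore index-creator module, which wires a document loader, text splitter, embeddings, and vector store into a ready-made retrieval QA setup. A brief usage sketch (assumes an OpenAI API key is configured, since OpenAIEmbeddings and OpenAI are the defaults visible in the imports; the file path is a placeholder):

.. code-block:: python

    from langchain.document_loaders import TextLoader
    from langchain.indexes import VectorstoreIndexCreator

    loader = TextLoader("my_document.txt")  # any BaseLoader works here
    index = VectorstoreIndexCreator().from_loaders([loader])
    print(index.query("What is this document about?"))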
@@ -10,7 +10,6 @@ from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.schema.prompt import PromptValue
from langchain.utils import (
check_package_version,
get_from_dict_or_env,

@@ -235,9 +234,6 @@ class Anthropic(LLM, _AnthropicCommon):
)
return response.completion

def convert_prompt(self, prompt: PromptValue) -> str:
return self._wrap_prompt(prompt.to_string())

async def _acall(
self,
prompt: str,
@@ -388,7 +388,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else:
generation += chunk
assert generation is not None
except BaseException as e:
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
else:

@@ -435,7 +435,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else:
generation += chunk
assert generation is not None
except BaseException as e:
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
else:

@@ -523,7 +523,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except BaseException as e:
except (KeyboardInterrupt, Exception) as e:
for run_manager in run_managers:
run_manager.on_llm_error(e)
raise e

@@ -674,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except BaseException as e:
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers]
)
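All four hunks swap a single error-handling clause between `except BaseException` and `except (KeyboardInterrupt, Exception)`. The practical difference is which interruptions reach the `on_llm_error` callback: `KeyboardInterrupt`, `SystemExit`, and `GeneratorExit` derive from `BaseException` but not from `Exception`, so the explicit tuple still catches Ctrl-C while letting the remaining `BaseException` subclasses propagate untouched. A tiny sketch of the hierarchy behind that choice:

.. code-block:: python

    # KeyboardInterrupt is a BaseException but not an Exception,
    # which is why the two clauses in the diff behave differently.
    print(issubclass(KeyboardInterrupt, Exception))      # False
    print(issubclass(KeyboardInterrupt, BaseException))  # True
    print(issubclass(SystemExit, Exception))             # False

    try:
        raise KeyboardInterrupt()
    except (KeyboardInterrupt, Exception) as exc:
        # Caught here because the tuple names KeyboardInterrupt explicitly.
        print(type(exc).__name__)  # KeyboardInterrupt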
@@ -196,9 +196,7 @@ class BaseOpenAI(BaseLLM):
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]:  # type: ignore
"""Initialize the OpenAI object."""
model_name = data.get("model_name", "")
if (
model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4")
) and not model_name.endswith("-instruct"):
if model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4"):
warnings.warn(
"You are trying to use a chat model. This way of initializing it is "
"no longer supported. Instead, please use: "
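The condition being toggled here decides when constructing `OpenAI(...)` warns that a chat model was requested. The longer form additionally exempts names ending in `-instruct` (for example `gpt-3.5-turbo-instruct`), which are completion-style models despite the `gpt-3.5-turbo` prefix. The predicate in isolation, for reference (the helper name is made up for this sketch):

.. code-block:: python

    def looks_like_chat_model(model_name: str) -> bool:
        # Mirrors the stricter condition from the diff: chat-style prefixes,
        # unless the name ends in "-instruct" (a completion-style model).
        return (
            model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4")
        ) and not model_name.endswith("-instruct")

    print(looks_like_chat_model("gpt-3.5-turbo"))           # True  -> warn
    print(looks_like_chat_model("gpt-3.5-turbo-instruct"))  # False -> no warning
    print(looks_like_chat_model("text-davinci-003"))        # False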
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Generator, List, Mapping, Optional, Union

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

@@ -11,65 +11,55 @@ if TYPE_CHECKING:
class Xinference(LLM):
"""Wrapper for accessing Xinference's large-scale model inference service.
To use, you should have the xinference library installed:

.. code-block:: bash

pip install "xinference[all]"
pip install "xinference[all]"

Check out: https://github.com/xorbitsai/inference
To run, you need to start a Xinference supervisor on one server and Xinference workers on the other servers

Example:
To start a local instance of Xinference, run
.. code-block:: bash

.. code-block:: bash

$ xinference
$ xinference

You can also deploy Xinference in a distributed cluster. Here are the steps:

Starting the supervisor:

.. code-block:: bash

$ xinference-supervisor
$ xinference-supervisor

Starting the worker:

.. code-block:: bash

$ xinference-worker
$ xinference-worker

Then, launch a model using command line interface (CLI).

Example:

.. code-block:: bash

$ xinference launch -n orca -s 3 -q q4_0
$ xinference launch -n orca -s 3 -q q4_0

It will return a model UID. Then, you can use Xinference with LangChain.

Example:
.. code-block:: python

.. code-block:: python
from langchain.llms import Xinference

from langchain.llms import Xinference
llm = Xinference(
server_url="http://0.0.0.0:9997",
model_uid = {model_uid} # replace model_uid with the model UID return from launching the model
)

llm = Xinference(
server_url="http://0.0.0.0:9997",
model_uid = {model_uid} # replace model_uid with the model UID return from launching the model
)

llm(
prompt="Q: where can we visit in the capital of France? A:",
generate_config={"max_tokens": 1024, "stream": True},
)
llm(
prompt="Q: where can we visit in the capital of France? A:",
generate_config={"max_tokens": 1024, "stream": True},
)

To view all the supported builtin models, run:

.. code-block:: bash

$ xinference list --all

"""  # noqa: E501
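Beyond the direct call shown in the docstring, the wrapper plugs into ordinary LangChain chains; a brief sketch (the model UID is a placeholder and the prompt is just an example):

.. code-block:: python

    from langchain.chains import LLMChain
    from langchain.llms import Xinference
    from langchain.prompts import PromptTemplate

    llm = Xinference(server_url="http://0.0.0.0:9997", model_uid="<model_uid>")
    prompt = PromptTemplate.from_template("Answer briefly: {question}")
    chain = LLMChain(llm=llm, prompt=prompt)
    print(chain.run(question="Where can we visit in the capital of France?"))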
@@ -79,14 +69,9 @@ class Xinference(LLM):
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
model_kwargs: Dict[str, Any]
"""Key word arguments to be passed to xinference.LLM"""

def __init__(
self,
server_url: Optional[str] = None,
model_uid: Optional[str] = None,
**model_kwargs: Any,
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
try:
from xinference.client import RESTfulClient

@@ -96,13 +81,10 @@ class Xinference(LLM):
" with `pip install xinference`."
) from e

model_kwargs = model_kwargs or {}

super().__init__(
**{
"server_url": server_url,
"model_uid": model_uid,
"model_kwargs": model_kwargs,
}
)

@@ -125,7 +107,6 @@ class Xinference(LLM):
return {
**{"server_url": self.server_url},
**{"model_uid": self.model_uid},
**{"model_kwargs": self.model_kwargs},
}

def _call(

@@ -150,8 +131,6 @@ class Xinference(LLM):

generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

generate_config = {**self.model_kwargs, **generate_config}

if stop:
generate_config["stop"] = stop
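One side of these hunks captures `**model_kwargs` at construction time and merges it under the per-call `generate_config` before every generation, so call-time keys win over constructor defaults. The merge line from the diff has ordinary dict-merge semantics:

.. code-block:: python

    from typing import Any, Dict

    model_kwargs: Dict[str, Any] = {"temperature": 0.3, "max_tokens": 256}   # set once at construction
    generate_config: Dict[str, Any] = {"max_tokens": 1024, "stream": True}   # passed per call

    merged = {**model_kwargs, **generate_config}
    print(merged)  # {'temperature': 0.3, 'max_tokens': 1024, 'stream': True}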
@@ -3,9 +3,9 @@ from __future__ import annotations

from typing import Any, Dict, List, Optional, Type

from langchain.embeddings.base import Embeddings
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.pydantic_v1 import BaseModel, Extra
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.base import VectorStore

@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.utils import maximal_marginal_relevance

@@ -7,12 +7,12 @@ from langchain.document_transformers.embeddings_redundant_filter import (
_get_embeddings_from_stateful_docs,
get_stateful_documents,
)
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import root_validator
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.utils.math import cosine_similarity

@@ -10,8 +10,8 @@ from typing import Any, List, Optional
import numpy as np

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.schema.embeddings import Embeddings


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
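For context on the `create_index` signature above: building an in-memory retrieval index from an `Embeddings` object typically just embeds every context string and stacks the vectors into a matrix that the retriever can score against a query vector. A hedged sketch of that idea (not the library function verbatim):

.. code-block:: python

    from typing import List

    import numpy as np

    from langchain.schema.embeddings import Embeddings

    def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
        # One row per context: shape (n_contexts, embedding_dim).
        return np.array(embeddings.embed_documents(contexts))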
@@ -3,9 +3,9 @@ import warnings
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import root_validator
from langchain.schema import BaseRetriever, Document
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.milvus import Milvus

# TODO: Update to MilvusClient + Hybrid Search when available
Some files were not shown because too many files have changed in this diff.