Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-01 00:49:25 +00:00)
changes to llm chain (#6328)

- return raw and full output (but keep run shortcut method functional)
- change output parser to take in generations (good for working with messages)
- add output parser to base class, always run (default to same as current)

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

This commit is contained in:
parent d3c2eab0b3
commit 6a4a950a3c
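Before the hunks, a minimal sketch of the surface this commit adds, pieced together from the llm.py and schema.py changes in the diff below (the `output_parser` and `return_final_only` fields on `LLMChain`, the "full_generation" output key, and the `NoOpOutputParser` default). The model and prompt here are stand-ins, not part of the commit:

# Hedged sketch based only on this diff, not on released docs.
from langchain.chains import LLMChain
from langchain.llms import OpenAI  # placeholder; any BaseLanguageModel works
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}.")
# return_final_only=False keeps the raw generations alongside the parsed text.
chain = LLMChain(llm=OpenAI(), prompt=prompt, return_final_only=False)

out = chain({"topic": "parrots"})
parsed = out[chain.output_key]  # parsed by the default NoOpOutputParser (text as-is)
raw = out["full_generation"]    # the raw List[Generation] for this prompt

# The `run` shortcut keeps working despite the extra output key, via the new
# _run_output_key property:
print(chain.run(topic="parrots"))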
@@ -17,7 +17,16 @@
 "execution_count": 1,
 "id": "34f04daf",
 "metadata": {},
-"outputs": [],
+"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+],
 "source": [
 "from langchain.chat_models import ChatOpenAI\n",
 "from langchain.chains import create_extraction_chain, create_extraction_chain_pydantic\n",
@@ -71,7 +80,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 4,
 "id": "640bd005",
 "metadata": {},
 "outputs": [],
@@ -84,7 +93,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 5,
 "id": "64313214",
 "metadata": {},
 "outputs": [],
@@ -102,7 +111,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 6,
 "id": "cc5436ed",
 "metadata": {},
 "outputs": [
@@ -119,7 +128,7 @@
 " 'person_hair_color': 'brunette'}]"
 ]
 },
-"execution_count": 8,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -150,7 +159,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 7,
 "id": "6792866b",
 "metadata": {},
 "outputs": [],
@@ -161,7 +170,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 8,
 "id": "36a63761",
 "metadata": {},
 "outputs": [],
@@ -176,7 +185,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 9,
 "id": "8ffd1e57",
 "metadata": {},
 "outputs": [],
@@ -186,7 +195,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 10,
 "id": "24baa954",
 "metadata": {
 "scrolled": false
@@ -220,7 +229,7 @@
 " Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
 ]
 },
-"execution_count": 13,
+"execution_count": 10,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -228,13 +237,21 @@
 "source": [
 "chain.run(inp)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "0df61283",
+"metadata": {},
+"outputs": [],
+"source": []
+}
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "general",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
-"name": "general"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
@@ -246,7 +263,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.6"
+"version": "3.9.1"
 }
 },
 "nbformat": 4,

docs/extras/modules/chains/additional/qa_citations.ipynb (new file, 181 lines)
@@ -0,0 +1,181 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"id": "9b5c258f",
+"metadata": {},
+"source": [
+"# Question-Answering Citations\n",
+"\n",
+"This notebook shows how to use OpenAI functions ability to extract citations from text."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 1,
+"id": "eae4ca3e",
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
+" warnings.warn(\n"
+]
+}
+],
+"source": [
+"from langchain.chains import create_citation_fuzzy_match_chain\n",
+"from langchain.chat_models import ChatOpenAI"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 2,
+"id": "2c6e62ee",
+"metadata": {},
+"outputs": [],
+"source": [
+"question = \"What did the author do during college?\"\n",
+"context = \"\"\"\n",
+"My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
+"I went to an arts highschool but in university I studied Computational Mathematics and physics. \n",
+"As part of coop I worked at many companies including Stitchfix, Facebook.\n",
+"I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
+"\"\"\""
+]
+},
+{
+"cell_type": "code",
+"execution_count": 3,
+"id": "078e0300",
+"metadata": {},
+"outputs": [],
+"source": [
+"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"id": "02cad6d0",
+"metadata": {},
+"outputs": [],
+"source": [
+"chain = create_citation_fuzzy_match_chain(llm)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"id": "e3c6e7ba",
+"metadata": {},
+"outputs": [],
+"source": [
+"result = chain.run(question=question, context=context)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "6f7615f2",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
+]
+}
+],
+"source": [
+"print(result)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"id": "3be6f366",
+"metadata": {},
+"outputs": [],
+"source": [
+"def highlight(text, span):\n",
+" return (\n",
+" \"...\"\n",
+" + text[span[0] - 20 : span[0]]\n",
+" + \"*\"\n",
+" + \"\\033[91m\"\n",
+" + text[span[0] : span[1]]\n",
+" + \"\\033[0m\"\n",
+" + \"*\"\n",
+" + text[span[1] : span[1] + 20]\n",
+" + \"...\"\n",
+" )"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "636c4528",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Statement: The author studied Computational Mathematics and physics in university.\n",
+"Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
+"As part of coop I...\n",
+"\n",
+"Statement: The author started the Data Science club at the University of Waterloo.\n",
+"Citation: ...titchfix, Facebook.\n",
+"*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
+"\n",
+"Statement: The author was the president of the Data Science club for 2 years.\n",
+"Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
+"...\n",
+"\n"
+]
+}
+],
+"source": [
+"for fact in result.answer:\n",
+" print(\"Statement:\", fact.fact)\n",
+" for span in fact.get_spans(context):\n",
+" print(\"Citation:\", highlight(context, span))\n",
+" print()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8409cab0",
+"metadata": {},
+"outputs": [],
+"source": []
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3 (ipykernel)",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.9.1"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 5
+}

@@ -17,7 +17,16 @@
 "execution_count": 1,
 "id": "bafb496a",
 "metadata": {},
-"outputs": [],
+"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+],
 "source": [
 "from langchain.chat_models import ChatOpenAI\n",
 "from langchain.chains import create_tagging_chain, create_tagging_chain_pydantic\n",
@@ -52,7 +61,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
 "id": "8329f943",
 "metadata": {},
 "outputs": [],
@@ -68,7 +77,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 4,
 "id": "6146ae70",
 "metadata": {},
 "outputs": [],
@@ -88,7 +97,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 59,
+"execution_count": 5,
 "id": "5509b6a6",
 "metadata": {},
 "outputs": [
@@ -98,7 +107,7 @@
 "{'sentiment': 'positive', 'language': 'Spanish'}"
 ]
 },
-"execution_count": 59,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -110,17 +119,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": 60,
+"execution_count": 6,
 "id": "9154474c",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
+"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}"
 ]
 },
-"execution_count": 60,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -132,7 +141,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 61,
+"execution_count": 7,
 "id": "aae85b27",
 "metadata": {},
 "outputs": [
@@ -142,7 +151,7 @@
 "{'sentiment': 'positive', 'aggressiveness': 0, 'language': 'English'}"
 ]
 },
-"execution_count": 61,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -176,7 +185,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 8,
 "id": "6a5f7961",
 "metadata": {},
 "outputs": [],
@@ -200,7 +209,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 9,
 "id": "e5a5881f",
 "metadata": {},
 "outputs": [],
@@ -218,7 +227,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 10,
 "id": "d9b9d53d",
 "metadata": {},
 "outputs": [
@@ -228,7 +237,7 @@
 "{'sentiment': 'happy', 'aggressiveness': 0, 'language': 'spanish'}"
 ]
 },
-"execution_count": 13,
+"execution_count": 10,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -240,7 +249,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": 11,
 "id": "1c12fa00",
 "metadata": {},
 "outputs": [
@@ -250,7 +259,7 @@
 "{'sentiment': 'sad', 'aggressiveness': 10, 'language': 'spanish'}"
 ]
 },
-"execution_count": 14,
+"execution_count": 11,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -262,7 +271,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": 12,
 "id": "0bdfcb05",
 "metadata": {},
 "outputs": [
@@ -272,7 +281,7 @@
 "{'sentiment': 'neutral', 'aggressiveness': 0, 'language': 'english'}"
 ]
 },
-"execution_count": 15,
+"execution_count": 12,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -304,7 +313,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 13,
 "id": "bf1f367e",
 "metadata": {},
 "outputs": [],
@@ -315,7 +324,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 23,
+"execution_count": 14,
 "id": "83a2e826",
 "metadata": {},
 "outputs": [],
@@ -334,7 +343,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 24,
+"execution_count": 15,
 "id": "6e404892",
 "metadata": {},
 "outputs": [],
@@ -344,7 +353,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 25,
+"execution_count": 16,
 "id": "b5fc43c4",
 "metadata": {},
 "outputs": [],
@@ -355,7 +364,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": 17,
 "id": "5074bcc3",
 "metadata": {},
 "outputs": [
@@ -365,7 +374,7 @@
 "Tags(sentiment='sad', aggressiveness=10, language='spanish')"
 ]
 },
-"execution_count": 26,
+"execution_count": 17,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -377,9 +386,9 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "general",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
-"name": "general"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
@@ -391,7 +400,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.6"
+"version": "3.9.1"
 }
 },
 "nbformat": 4,

@@ -24,6 +24,7 @@ from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
 from langchain.chains.natbot.base import NatBotChain
 from langchain.chains.openai_functions import (
+    create_citation_fuzzy_match_chain,
     create_extraction_chain,
     create_extraction_chain_pydantic,
     create_tagging_chain,
@@ -93,4 +94,5 @@ __all__ = [
     "create_tagging_chain",
     "create_tagging_chain_pydantic",
     "load_chain",
+    "create_citation_fuzzy_match_chain",
 ]

@@ -247,6 +247,15 @@ class Chain(Serializable, ABC):
         """Call the chain on all inputs in the list."""
         return [self(inputs, callbacks=callbacks) for inputs in input_list]

+    @property
+    def _run_output_key(self) -> str:
+        if len(self.output_keys) != 1:
+            raise ValueError(
+                f"`run` not supported when there is not exactly "
+                f"one output key. Got {self.output_keys}."
+            )
+        return self.output_keys[0]
+
     def run(
         self,
         *args: Any,
@@ -255,19 +264,16 @@ class Chain(Serializable, ABC):
         **kwargs: Any,
     ) -> str:
         """Run the chain as text in, text out or multiple variables, text out."""
-        if len(self.output_keys) != 1:
-            raise ValueError(
-                f"`run` not supported when there is not exactly "
-                f"one output key. Got {self.output_keys}."
-            )
+        # Run at start to make sure this is possible/defined
+        _output_key = self._run_output_key

         if args and not kwargs:
             if len(args) != 1:
                 raise ValueError("`run` supports only one positional argument.")
-            return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]]
+            return self(args[0], callbacks=callbacks, tags=tags)[_output_key]

         if kwargs and not args:
-            return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]]
+            return self(kwargs, callbacks=callbacks, tags=tags)[_output_key]

         if not kwargs and not args:
             raise ValueError(

@@ -1,9 +1,10 @@
 """Chain that just formats a prompt and calls an LLM."""
 from __future__ import annotations

+import warnings
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

-from pydantic import Extra
+from pydantic import Extra, Field

 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
@@ -18,7 +19,12 @@ from langchain.input import get_colored_text
 from langchain.load.dump import dumpd
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import LLMResult, PromptValue
+from langchain.schema import (
+    BaseLLMOutputParser,
+    LLMResult,
+    NoOpOutputParser,
+    PromptValue,
+)


 class LLMChain(Chain):
@@ -42,7 +48,16 @@ class LLMChain(Chain):
     prompt: BasePromptTemplate
     """Prompt object to use."""
     llm: BaseLanguageModel
+    """Language model to call."""
     output_key: str = "text"  #: :meta private:
+    output_parser: BaseLLMOutputParser = Field(default_factory=NoOpOutputParser)
+    """Output parser to use.
+    Defaults to one that takes the most likely string but does not change it
+    otherwise."""
+    return_final_only: bool = True
+    """Whether to return only the final parsed result. Defaults to True.
+    If false, will return a bunch of extra information about the generation."""
+    llm_kwargs: dict = Field(default_factory=dict)

     class Config:
         """Configuration for this pydantic object."""
@@ -64,7 +79,10 @@ class LLMChain(Chain):

         :meta private:
         """
-        return [self.output_key]
+        if self.return_final_only:
+            return [self.output_key]
+        else:
+            return [self.output_key, "full_generation"]

     def _call(
         self,
@@ -82,7 +100,10 @@ class LLMChain(Chain):
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
         return self.llm.generate_prompt(
-            prompts, stop, callbacks=run_manager.get_child() if run_manager else None
+            prompts,
+            stop,
+            callbacks=run_manager.get_child() if run_manager else None,
+            **self.llm_kwargs,
         )

     async def agenerate(
@@ -93,7 +114,10 @@ class LLMChain(Chain):
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
         return await self.llm.agenerate_prompt(
-            prompts, stop, callbacks=run_manager.get_child() if run_manager else None
+            prompts,
+            stop,
+            callbacks=run_manager.get_child() if run_manager else None,
+            **self.llm_kwargs,
         )

     def prep_prompts(
@@ -184,13 +208,23 @@ class LLMChain(Chain):
         await run_manager.on_chain_end({"outputs": outputs})
         return outputs

-    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
+    @property
+    def _run_output_key(self) -> str:
+        return self.output_key
+
+    def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
         """Create outputs from response."""
-        return [
-            # Get the text of the top generated string.
-            {self.output_key: generation[0].text}
-            for generation in response.generations
+        result = [
+            {
+                self.output_key: self.output_parser.parse_result(generation),
+                "full_generation": generation,
+            }
+            for generation in llm_result.generations
         ]
+        if self.return_final_only:
+            result = [{self.output_key: r[self.output_key]} for r in result]
+        return result

     async def _acall(
         self,
@@ -238,6 +272,10 @@ class LLMChain(Chain):
         self, callbacks: Callbacks = None, **kwargs: Any
     ) -> Union[str, List[str], Dict[str, Any]]:
         """Call predict and then parse the results."""
+        warnings.warn(
+            "The predict_and_parse method is deprecated, "
+            "instead pass an output parser directly to LLMChain."
+        )
         result = self.predict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
@@ -248,6 +286,10 @@ class LLMChain(Chain):
         self, callbacks: Callbacks = None, **kwargs: Any
     ) -> Union[str, List[str], Dict[str, str]]:
         """Call apredict and then parse the results."""
+        warnings.warn(
+            "The apredict_and_parse method is deprecated, "
+            "instead pass an output parser directly to LLMChain."
+        )
         result = await self.apredict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
@@ -258,25 +300,34 @@ class LLMChain(Chain):
         self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
     ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         """Call apply and then parse the results."""
+        warnings.warn(
+            "The apply_and_parse method is deprecated, "
+            "instead pass an output parser directly to LLMChain."
+        )
         result = self.apply(input_list, callbacks=callbacks)
-        return self._parse_result(result)
+        return self._parse_generation(result)

-    def _parse_result(
-        self, result: List[Dict[str, str]]
+    def _parse_generation(
+        self, generation: List[Dict[str, str]]
     ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         if self.prompt.output_parser is not None:
             return [
-                self.prompt.output_parser.parse(res[self.output_key]) for res in result
+                self.prompt.output_parser.parse(res[self.output_key])
+                for res in generation
             ]
         else:
-            return result
+            return generation

     async def aapply_and_parse(
         self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
     ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         """Call apply and then parse the results."""
+        warnings.warn(
+            "The aapply_and_parse method is deprecated, "
+            "instead pass an output parser directly to LLMChain."
+        )
         result = await self.aapply(input_list, callbacks=callbacks)
-        return self._parse_result(result)
+        return self._parse_generation(result)

     @property
     def _chain_type(self) -> str:

@@ -24,7 +24,11 @@ from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
 from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
 from langchain.chains.sql_database.base import SQLDatabaseChain
 from langchain.llms.loading import load_llm, load_llm_from_config
-from langchain.prompts.loading import load_prompt, load_prompt_from_config
+from langchain.prompts.loading import (
+    _load_output_parser,
+    load_prompt,
+    load_prompt_from_config,
+)
 from langchain.utilities.loading import try_load_from_hub

 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
@@ -47,6 +51,7 @@ def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain:
         prompt = load_prompt(config.pop("prompt_path"))
     else:
         raise ValueError("One of `prompt` or `prompt_path` must be present.")
+    _load_output_parser(config)

     return LLMChain(llm=llm, prompt=prompt, **config)

@@ -1,233 +0,0 @@
-import json
-from functools import partial
-from typing import Any, Dict, List, Optional
-
-from pydantic import BaseModel, Field
-
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-)
-from langchain.chains.base import Chain
-from langchain.chains.sequential import SimpleSequentialChain
-from langchain.chains.transform import TransformChain
-from langchain.prompts.base import BasePromptTemplate
-from langchain.prompts.chat import ChatPromptTemplate
-
-EXTRACTION_NAME = "information_extraction"
-EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
-
-
-def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
-    """
-    Resolves the $ref keys in a JSON schema object using the provided definitions.
-    """
-    if isinstance(schema, list):
-        for i, item in enumerate(schema):
-            schema[i] = _resolve_schema_references(item, definitions)
-    elif isinstance(schema, dict):
-        if "$ref" in schema:
-            ref_key = schema.pop("$ref").split("/")[-1]
-            ref = definitions.get(ref_key, {})
-            schema.update(ref)
-        else:
-            for key, value in schema.items():
-                schema[key] = _resolve_schema_references(value, definitions)
-    return schema
-
-
-def _get_function_arguments(inputs: dict) -> str:
-    message = inputs["input"]
-    try:
-        func_call = message.additional_kwargs["function_call"]
-    except ValueError as exc:
-        raise ValueError(f"Could not parse function call: {exc}")
-
-    return func_call["arguments"]
-
-
-def _parse_tag(inputs: dict) -> dict:
-    args = _get_function_arguments(inputs)
-    return {"output": json.loads(args)}
-
-
-def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
-    args = _get_function_arguments(inputs)
-    args = pydantic_schema.parse_raw(args)
-    return {"output": args}
-
-
-def _parse_entities(inputs: dict) -> dict:
-    args = _get_function_arguments(inputs)
-    return {"output": json.loads(args)["info"]}
-
-
-def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
-    args = _get_function_arguments(inputs)
-    pydantic_args = pydantic_schema.parse_raw(args)
-    return {"output": pydantic_args.info}
-
-
-class OpenAIFunctionsChain(Chain):
-    prompt: BasePromptTemplate
-    llm: BaseLanguageModel
-    functions: List[Dict]
-    kwargs: Dict = Field(default_factory=dict)
-
-    @property
-    def input_keys(self) -> List[str]:
-        return self.prompt.input_variables
-
-    @property
-    def output_keys(self) -> List[str]:
-        return ["output"]
-
-    def _call(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
-        prompt = self.prompt.format_prompt(**_inputs)
-        messages = prompt.to_messages()
-        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        callbacks = _run_manager.get_child()
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
-        )
-        return {"output": predicted_message}
-
-    async def _acall(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
-        prompt = self.prompt.format_prompt(**_inputs)
-        messages = prompt.to_messages()
-        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
-        callbacks = _run_manager.get_child()
-        predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
-        )
-        return {"output": predicted_message}
-
-
-def _convert_schema(schema: dict) -> dict:
-    props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
-    return {
-        "type": "object",
-        "properties": props,
-        "required": schema.get("required", []),
-    }
-
-
-def _get_extraction_functions(entity_schema: dict) -> List[dict]:
-    return [
-        {
-            "name": EXTRACTION_NAME,
-            "description": "Extracts the relevant information from the passage.",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "info": {"type": "array", "items": _convert_schema(entity_schema)}
-                },
-                "required": ["info"],
-            },
-        }
-    ]
-
-
-def _get_tagging_functions(schema: dict) -> List[dict]:
-    return [
-        {
-            "name": EXTRACTION_NAME,
-            "description": "Extracts the relevant information from the passage.",
-            "parameters": _convert_schema(schema),
-        }
-    ]
-
-
-_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
-in the following passage together with their properties.
-
-Passage:
-{input}
-"""
-
-
-def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
-    functions = _get_extraction_functions(schema)
-    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
-    chain = OpenAIFunctionsChain(
-        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
-    )
-    parsing_chain = TransformChain(
-        transform=_parse_entities,
-        input_variables=["input"],
-        output_variables=["output"],
-    )
-    return SimpleSequentialChain(chains=[chain, parsing_chain])
-
-
-def create_extraction_chain_pydantic(
-    pydantic_schema: Any, llm: BaseLanguageModel
-) -> Chain:
-    class PydanticSchema(BaseModel):
-        info: List[pydantic_schema]  # type: ignore
-
-    openai_schema = PydanticSchema.schema()
-    openai_schema = _resolve_schema_references(
-        openai_schema, openai_schema["definitions"]
-    )
-
-    functions = _get_extraction_functions(openai_schema)
-    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
-    chain = OpenAIFunctionsChain(
-        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
-    )
-    pydantic_parsing_chain = TransformChain(
-        transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
-        input_variables=["input"],
-        output_variables=["output"],
-    )
-    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
-
-
-_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
-
-Passage:
-{input}
-"""
-
-
-def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
-    functions = _get_tagging_functions(schema)
-    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
-    chain = OpenAIFunctionsChain(
-        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
-    )
-    parsing_chain = TransformChain(
-        transform=_parse_tag, input_variables=["input"], output_variables=["output"]
-    )
-    return SimpleSequentialChain(chains=[chain, parsing_chain])
-
-
-def create_tagging_chain_pydantic(
-    pydantic_schema: Any, llm: BaseLanguageModel
-) -> Chain:
-    openai_schema = pydantic_schema.schema()
-
-    functions = _get_tagging_functions(openai_schema)
-    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
-    chain = OpenAIFunctionsChain(
-        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
-    )
-    pydantic_parsing_chain = TransformChain(
-        transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
-        input_variables=["input"],
-        output_variables=["output"],
-    )
-
-    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])

langchain/chains/openai_functions/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from langchain.chains.openai_functions.citation_fuzzy_match import (
+    create_citation_fuzzy_match_chain,
+)
+from langchain.chains.openai_functions.extraction import (
+    create_extraction_chain,
+    create_extraction_chain_pydantic,
+)
+from langchain.chains.openai_functions.tagging import (
+    create_tagging_chain,
+    create_tagging_chain_pydantic,
+)
+
+__all__ = [
+    "create_tagging_chain",
+    "create_tagging_chain_pydantic",
+    "create_extraction_chain_pydantic",
+    "create_extraction_chain",
+    "create_citation_fuzzy_match_chain",
+]

langchain/chains/openai_functions/citation_fuzzy_match.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+from typing import Iterator, List
+
+from pydantic import BaseModel, Field
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.llm import LLMChain
+from langchain.output_parsers.openai_functions import (
+    PydanticOutputFunctionsParser,
+)
+from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
+from langchain.schema import HumanMessage, SystemMessage
+
+
+class FactWithEvidence(BaseModel):
+    """Class representing single statement.
+
+    Each fact has a body and a list of sources.
+    If there are multiple facts make sure to break them apart
+    such that each one only uses a set of sources that are relevant to it.
+    """
+
+    fact: str = Field(..., description="Body of the sentence, as part of a response")
+    substring_quote: List[str] = Field(
+        ...,
+        description=(
+            "Each source should be a direct quote from the context, "
+            "as a substring of the original content"
+        ),
+    )
+
+    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
+        import regex
+
+        minor = quote
+        major = context
+
+        errs_ = 0
+        s = regex.search(f"({minor}){{e<={errs_}}}", major)
+        while s is None and errs_ <= errs:
+            errs_ += 1
+            s = regex.search(f"({minor}){{e<={errs_}}}", major)
+
+        if s is not None:
+            yield from s.spans()
+
+    def get_spans(self, context: str) -> Iterator[str]:
+        for quote in self.substring_quote:
+            yield from self._get_span(quote, context)
+
+
+class QuestionAnswer(BaseModel):
+    """A question and its answer as a list of facts each one should have a source.
+    each sentence contains a body and a list of sources."""
+
+    question: str = Field(..., description="Question that was asked")
+    answer: List[FactWithEvidence] = Field(
+        ...,
+        description=(
+            "Body of the answer, each fact should be "
+            "its separate object with a body and a list of sources"
+        ),
+    )
+
+
+def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
+    output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
+    schema = QuestionAnswer.schema()
+    functions = [
+        {
+            "name": schema["title"],
+            "description": schema["description"],
+            "parameters": schema,
+        }
+    ]
+    kwargs = {"function_call": {"name": schema["title"]}}
+    messages = [
+        SystemMessage(
+            content=(
+                "You are a world class algorithm to answer "
+                "questions with correct and exact citations."
+            )
+        ),
+        HumanMessage(content="Answer question using the following context"),
+        HumanMessagePromptTemplate.from_template("{context}"),
+        HumanMessagePromptTemplate.from_template("Question: {question}"),
+        HumanMessage(
+            content=(
+                "Tips: Make sure to cite your sources, "
+                "and use the exact words from the context."
+            )
+        ),
+    ]
+    prompt = ChatPromptTemplate(messages=messages)
+
+    chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        llm_kwargs={**{"functions": functions}, **kwargs},
+        output_parser=output_parser,
+    )
+    return chain

langchain/chains/openai_functions/extraction.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+from typing import Any, List
+
+from pydantic import BaseModel
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.chains.openai_functions.utils import (
+    _convert_schema,
+    _resolve_schema_references,
+)
+from langchain.output_parsers.openai_functions import (
+    JsonKeyOutputFunctionsParser,
+    PydanticAttrOutputFunctionsParser,
+)
+from langchain.prompts import ChatPromptTemplate
+
+EXTRACTION_NAME = "information_extraction"
+EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
+
+
+def _get_extraction_functions(entity_schema: dict) -> List[dict]:
+    return [
+        {
+            "name": EXTRACTION_NAME,
+            "description": "Extracts the relevant information from the passage.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "info": {"type": "array", "items": _convert_schema(entity_schema)}
+                },
+                "required": ["info"],
+            },
+        }
+    ]
+
+
+_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
+in the following passage together with their properties.
+
+Passage:
+{input}
+"""
+
+
+def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
+    functions = _get_extraction_functions(schema)
+    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
+    output_parser = JsonKeyOutputFunctionsParser(key_name="info")
+    chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        output_parser=output_parser,
+    )
+    return chain
+
+
+def create_extraction_chain_pydantic(
+    pydantic_schema: Any, llm: BaseLanguageModel
+) -> Chain:
+    class PydanticSchema(BaseModel):
+        info: List[pydantic_schema]  # type: ignore
+
+    openai_schema = PydanticSchema.schema()
+    openai_schema = _resolve_schema_references(
+        openai_schema, openai_schema["definitions"]
+    )
+
+    functions = _get_extraction_functions(openai_schema)
+    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
+    output_parser = PydanticAttrOutputFunctionsParser(
+        pydantic_schema=PydanticSchema, attr_name="info"
+    )
+    chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        output_parser=output_parser,
+    )
+    return chain

langchain/chains/openai_functions/tagging.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+from typing import Any, List
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.chains.openai_functions.utils import _convert_schema
+from langchain.output_parsers.openai_functions import (
+    JsonOutputFunctionsParser,
+    PydanticOutputFunctionsParser,
+)
+from langchain.prompts import ChatPromptTemplate
+
+EXTRACTION_NAME = "information_extraction"
+EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
+
+
+def _get_tagging_functions(schema: dict) -> List[dict]:
+    return [
+        {
+            "name": EXTRACTION_NAME,
+            "description": "Extracts the relevant information from the passage.",
+            "parameters": _convert_schema(schema),
+        }
+    ]
+
+
+_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
+
+Passage:
+{input}
+"""
+
+
+def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
+    functions = _get_tagging_functions(schema)
+    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
+    output_parser = JsonOutputFunctionsParser()
+    chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        output_parser=output_parser,
+    )
+    return chain
+
+
+def create_tagging_chain_pydantic(
+    pydantic_schema: Any, llm: BaseLanguageModel
+) -> Chain:
+    openai_schema = pydantic_schema.schema()
+
+    functions = _get_tagging_functions(openai_schema)
+    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
+    output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
+    chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        output_parser=output_parser,
+    )
+    return chain

langchain/chains/openai_functions/utils.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from typing import Any, Dict
+
+
+def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
+    """
+    Resolves the $ref keys in a JSON schema object using the provided definitions.
+    """
+    if isinstance(schema, list):
+        for i, item in enumerate(schema):
+            schema[i] = _resolve_schema_references(item, definitions)
+    elif isinstance(schema, dict):
+        if "$ref" in schema:
+            ref_key = schema.pop("$ref").split("/")[-1]
+            ref = definitions.get(ref_key, {})
+            schema.update(ref)
+        else:
+            for key, value in schema.items():
+                schema[key] = _resolve_schema_references(value, definitions)
+    return schema
+
+
+def _convert_schema(schema: dict) -> dict:
+    props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
+    return {
+        "type": "object",
+        "properties": props,
+        "required": schema.get("required", []),
+    }

langchain/output_parsers/openai_functions.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import json
+from typing import Any, List
+
+from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation
+
+
+class OutputFunctionsParser(BaseLLMOutputParser[Any]):
+    def parse_result(self, result: List[Generation]) -> Any:
+        generation = result[0]
+        if not isinstance(generation, ChatGeneration):
+            raise ValueError(
+                "This output parser can only be used with a chat generation."
+            )
+        message = generation.message
+        try:
+            func_call = message.additional_kwargs["function_call"]
+        except ValueError as exc:
+            raise ValueError(f"Could not parse function call: {exc}")
+
+        return func_call["arguments"]
+
+
+class JsonOutputFunctionsParser(OutputFunctionsParser):
+    def parse_result(self, result: List[Generation]) -> Any:
+        _args = super().parse_result(result)
+        return json.loads(_args)
+
+
+class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
+    key_name: str
+
+    def parse_result(self, result: List[Generation]) -> Any:
+        res = super().parse_result(result)
+        return res[self.key_name]
+
+
+class PydanticOutputFunctionsParser(OutputFunctionsParser):
+    pydantic_schema: Any
+
+    def parse_result(self, result: List[Generation]) -> Any:
+        _args = super().parse_result(result)
+        pydantic_args = self.pydantic_schema.parse_raw(_args)
+        return pydantic_args
+
+
+class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
+    attr_name: str
+
+    def parse_result(self, result: List[Generation]) -> Any:
+        result = super().parse_result(result)
+        return getattr(result, self.attr_name)

@@ -11,6 +11,7 @@ from langchain.output_parsers.regex import RegexParser
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BaseLLMOutputParser, NoOpOutputParser
 from langchain.utilities.loading import try_load_from_hub

 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -78,7 +79,9 @@ def _load_output_parser(config: dict) -> dict:
         _config = config.pop("output_parser")
         output_parser_type = _config.pop("_type")
         if output_parser_type == "regex_parser":
-            output_parser = RegexParser(**_config)
+            output_parser: BaseLLMOutputParser = RegexParser(**_config)
+        elif output_parser_type == "default":
+            output_parser = NoOpOutputParser(**_config)
         else:
             raise ValueError(f"Unsupported output parser {output_parser_type}")
         config["output_parser"] = output_parser

@@ -339,12 +339,21 @@ Memory = BaseMemory
 T = TypeVar("T")


-class BaseOutputParser(Serializable, ABC, Generic[T]):
+class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
+    @abstractmethod
+    def parse_result(self, result: List[Generation]) -> T:
+        """Parse LLM Result."""
+
+
+class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
     """Class to parse the output of an LLM call.

     Output parsers help structure language model responses.
     """

+    def parse_result(self, result: List[Generation]) -> T:
+        return self.parse(result[0].text)
+
     @abstractmethod
     def parse(self, text: str) -> T:
         """Parse the output of an LLM call.
@@ -394,6 +403,21 @@ class BaseOutputParser(Serializable, ABC, Generic[T]):
         return output_parser_dict


+class NoOpOutputParser(BaseOutputParser[str]):
+    """Output parser that just returns the text as is."""
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    @property
+    def _type(self) -> str:
+        return "default"
+
+    def parse(self, text: str) -> str:
+        return text
+
+
 class OutputParserException(ValueError):
     """Exception that output parsers should raise to signify a parsing error.
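As a closing illustration of the parser split in the schema.py hunk above: text-level parsers subclass BaseOutputParser and implement parse(), inheriting a parse_result() that reads the top generation's text, while the function-call parsers added in langchain/output_parsers/openai_functions.py override parse_result() to inspect the whole generation. A minimal sketch; CommaListParser is hypothetical and not part of this commit:

# Sketch only, assuming the post-commit langchain.schema shown above.
from langchain.schema import BaseOutputParser, Generation


class CommaListParser(BaseOutputParser):
    """Text-level parser: only parse() is implemented; parse_result() is
    inherited from BaseOutputParser and calls parse(result[0].text)."""

    def parse(self, text: str) -> list:
        return [part.strip() for part in text.split(",")]


parser = CommaListParser()
print(parser.parse_result([Generation(text="a, b, c")]))  # ['a', 'b', 'c']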