Compare commits

...

20 Commits

Author SHA1 Message Date
Harrison Chase  a41e91d650  cr  2023-03-02 14:44:07 -08:00
Harrison Chase  7e2ae5570a  Harrison/new prompt abstraction (#1405)  2023-03-02 12:04:12 -08:00
Harrison Chase  0abf4d4c7d  Harrison/new prompt abstraction (#1404)  2023-03-02 11:58:55 -08:00
Harrison Chase  6db04cfe65  Merge branch 'master' into harrison/memory-chat  2023-03-02 11:55:40 -08:00
Harrison Chase  34214f5fa2  Harrison/new prompt abstraction (#1399)  2023-03-02 11:41:18 -08:00
Harrison Chase  098a0ff568  cr  2023-03-01 23:08:09 -08:00
Harrison Chase  7d0502e964  cr  2023-03-01 22:47:15 -08:00
Harrison Chase  ae65e8c5f4  cr  2023-03-01 22:02:08 -08:00
Harrison Chase  d6584fde16  cr  2023-03-01 21:28:30 -08:00
Harrison Chase  6cfd0ca73a  cr  2023-03-01 17:53:38 -08:00
Harrison Chase  522452adae  cr  2023-03-01 17:46:43 -08:00
Harrison Chase  007278a358  cr  2023-03-01 17:44:25 -08:00
Harrison Chase  79964e6409  cr  2023-03-01 17:38:48 -08:00
Harrison Chase  c3046309fb  cr  2023-03-01 17:06:23 -08:00
Harrison Chase  95cfd002a7  cr  2023-03-01 17:06:06 -08:00
Harrison Chase  acaa2d3ee4  cr  2023-03-01 16:44:45 -08:00
Harrison Chase  f635a31992  memory chat  2023-03-01 15:27:20 -08:00
Harrison Chase  156bdb6590  memory stuff  2023-03-01 14:11:07 -08:00
Harrison Chase  04220be616  stash  2023-03-01 14:04:28 -08:00
Harrison Chase  12aacdbfb4  stash  2023-03-01 13:24:36 -08:00
50 changed files with 3390 additions and 483 deletions

View File

@@ -63,6 +63,8 @@ These modules are, in increasing order of complexity:
- `Memory <./modules/memory.html>`_: Memory is the concept of persisting state between calls of a chain/agent. LangChain provides a standard interface for memory, a collection of memory implementations, and examples of chains/agents that use memory.
- `Chat <./modules/chat.html>`_: WIP: how to work with chat models.
.. toctree::
:maxdepth: 1
@@ -76,6 +78,7 @@ These modules are, in increasing order of complexity:
./modules/utils.md
./modules/indexes.md
./modules/chains.md
./modules/chat.md
./modules/agents.md
./modules/memory.md
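The Memory module described above is exercised by the notebook changes later in this diff. As a minimal sketch of the pattern, assuming only the `ChatMemory` interface this changeset itself introduces (`add_user_message`, `add_ai_message`, and the `messages` attribute), it might look like:

```python
# Sketch only -- relies on the ChatMemory interface added in this changeset,
# not on a stable public API.
from langchain.memory.chat_memory import ChatMemory

memory = ChatMemory()

# Record one conversational turn so the next call can see prior state.
memory.add_user_message("What did the president say about Ketanji Brown Jackson")
memory.add_ai_message("He nominated her to the United States Supreme Court.")

# memory.messages now holds the running history and can be threaded into a
# chain call, e.g. chain.run(question=query, chat_history=memory.messages).
print(memory.messages)
```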

View File

@@ -59,6 +59,20 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a88b8e4f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.agents.chat.base import ChatAgent\n",
"\n",
"agent = ChatAgent.from_chat_model_and_tools(ChatOpenAI(temperature=0), toolkit.get_tools())\n",
"agent_executor.agent = agent"
]
},
{
"cell_type": "markdown",
"id": "36ae48c7-cb08-4fef-977e-c7d4b96a464b",
@@ -69,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "ff70e83d-5ad0-4fc7-bb96-27d82ac166d7",
"metadata": {
"tags": []
@@ -82,12 +96,26 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mArtist, Invoice, Playlist, Genre, Album, PlaylistTrack, Track, InvoiceLine, MediaType, Employee, Customer\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the playlisttrack table\n",
"Action: schema_sql_db\n",
"Action Input: \"PlaylistTrack\"\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to check the schema of the database to see if there is a table called \"play list track\"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"list_tables_sql_db\",\n",
" \"action_input\": \"\"\n",
"}\n",
"```\n",
"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mCustomer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThere is a table called \"PlaylistTrack\" in the database. I need to use the schema_sql_db tool to get the schema and sample rows for this table.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"schema_sql_db\",\n",
" \"action_input\": \"PlaylistTrack\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
@@ -98,12 +126,12 @@
")\n",
"\n",
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.\u001b[0m\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe PlaylistTrack table has two columns: PlaylistId and TrackId. It has a composite primary key consisting of both columns. There are foreign key constraints on both columns referencing the Playlist and Track tables respectively. The sample rows show the first three entries in the table.\n",
"Final Answer: The PlaylistTrack table has two columns: PlaylistId and TrackId. It has a composite primary key consisting of both columns. There are foreign key constraints on both columns referencing the Playlist and Track tables respectively. The sample rows show the first three entries in the table.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -111,16 +139,16 @@
{
"data": {
"text/plain": [
"'The PlaylistTrack table has two columns, PlaylistId and TrackId, and is linked to the Playlist and Track tables.'"
"'The PlaylistTrack table has two columns: PlaylistId and TrackId. It has a composite primary key consisting of both columns. There are foreign key constraints on both columns referencing the Playlist and Track tables respectively. The sample rows show the first three entries in the table.'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Describe the playlisttrack table\")"
"agent_executor.run(\"Describe the play list track table\")"
]
},
{
@@ -135,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 5,
"id": "bea76658-a65b-47e2-b294-6d52c5556246",
"metadata": {
"tags": []
@@ -148,36 +176,69 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mGenre, PlaylistTrack, MediaType, Invoice, InvoiceLine, Track, Playlist, Customer, Album, Employee, Artist\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the PlaylistSong table\n",
"Action: schema_sql_db\n",
"Action Input: \"PlaylistSong\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mError: table_names {'PlaylistSong'} not found in database\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should check the spelling of the table\n",
"Action: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mGenre, PlaylistTrack, MediaType, Invoice, InvoiceLine, Track, Playlist, Customer, Album, Employee, Artist\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m The table is called PlaylistTrack\n",
"Action: schema_sql_db\n",
"Action Input: \"PlaylistTrack\"\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to use the `schema_sql_db` tool to get the schema and sample rows for the table that has song information. But first, I need to know the name of the table.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"list_tables_sql_db\",\n",
" \"action_input\": \"\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mCustomer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe table that has song information is likely to be named \"Track\". I will use the `schema_sql_db` tool to get the schema and sample rows for this table.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"schema_sql_db\",\n",
" \"action_input\": \"Track\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"CREATE TABLE \"Track\" (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
"\t\"Name\" NVARCHAR(200) NOT NULL, \n",
"\t\"AlbumId\" INTEGER, \n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"GenreId\" INTEGER, \n",
"\t\"Composer\" NVARCHAR(220), \n",
"\t\"Milliseconds\" INTEGER NOT NULL, \n",
"\t\"Bytes\" INTEGER, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"TrackId\"), \n",
"\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n",
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and are used to link Playlist and Track tables.\u001b[0m\n",
"SELECT * FROM 'Track' LIMIT 3;\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\u001b[0m\n",
"Thought:"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/workplace/langchain/langchain/sql_database.py:141: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
" sample_rows = connection.execute(command)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3mThe table \"Track\" has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice. The sample rows show the first three tracks in the table. \n",
"\n",
"Final Answer: The table that has song information is named \"Track\" and has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -185,16 +246,16 @@
{
"data": {
"text/plain": [
"'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and are used to link Playlist and Track tables.'"
"'The table that has song information is named \"Track\" and has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice.'"
]
},
"execution_count": 15,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Describe the playlistsong table\")"
"agent_executor.run(\"Describe the table for that has song information\")"
]
},
{
@@ -207,7 +268,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"id": "17bea710-4a23-4de0-b48e-21d57be48293",
"metadata": {
"tags": []
@@ -220,63 +281,57 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mInvoice, MediaType, Artist, InvoiceLine, Genre, Playlist, Employee, Album, PlaylistTrack, Track, Customer\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the relevant tables to see what columns I can use.\n",
"Action: schema_sql_db\n",
"Action Input: \"Invoice, Customer\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\u001b[32;1m\u001b[1;3mQuestion: List the total sales per country. Which country's customers spent the most?\n",
"Thought: I need to retrieve the total sales per country and then find the country with the highest total sales.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1\"}\n",
"```\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to modify the query to retrieve the country with the highest total sales, not just the first one in the list.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1\"}\n",
"```\n",
"\n",
"SELECT * FROM 'Customer' LIMIT 3;\n",
"CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId\n",
"1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3\n",
"2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5\n",
"3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to modify the query to retrieve the country with the highest total sales, not just the first one in the list.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1\"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to modify the query to retrieve the country with the highest total sales, not just the first one in the list.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1\"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to modify the query to retrieve all the countries and their total sales, then find the country with the highest total sales.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC\"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan'), (800, 'Australia'), (700, 'UK'), (600, 'Germany'), (500, 'France'), (400, 'Brazil'), (300, 'Mexico'), (200, 'Canada'), (100, 'USA')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to select the first row of the result to get the country with the highest total sales.\n",
"Action:\n",
"```\n",
"{\"action\": \"query_sql_db\", \"action_input\": \"SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1\"}\n",
"```\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"SELECT * FROM 'Invoice' LIMIT 3;\n",
"InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total\n",
"1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98\n",
"2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96\n",
"3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should query the Invoice and Customer tables to get the total sales per country.\n",
"Action: query_sql_db\n",
"Action Input: SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i INNER JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The customers from the USA spent the most, with a total of $523.06.\u001b[0m\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(900, 'Japan')]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe country with the highest total sales is Japan.\n",
"Final Answer: Japan\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -284,10 +339,10 @@
{
"data": {
"text/plain": [
"'The customers from the USA spent the most, with a total of $523.06.'"
"'Japan'"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -311,12 +366,17 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mInvoice, MediaType, Artist, InvoiceLine, Genre, Playlist, Employee, Album, PlaylistTrack, Track, Customer\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the Playlist and PlaylistTrack tables to see what columns I can use.\n",
"Action: schema_sql_db\n",
"Action Input: \"Playlist, PlaylistTrack\"\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to retrieve data from multiple tables, so I need to check the schema of the tables and make sure I have the correct column names to join the tables. Then I can use a SQL query to get the total number of tracks in each playlist.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"schema_sql_db\",\n",
" \"action_input\": \"Playlist, PlaylistTrack, Track\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Playlist\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
@@ -325,10 +385,33 @@
")\n",
"\n",
"SELECT * FROM 'Playlist' LIMIT 3;\n",
"PlaylistId Name\n",
"1 Music\n",
"2 Movies\n",
"3 TV Shows\n",
"PlaylistId\tName\n",
"1\tMusic\n",
"2\tMovies\n",
"3\tTV Shows\n",
"\n",
"\n",
"CREATE TABLE \"Track\" (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(200) NOT NULL, \n",
"\t\"AlbumId\" INTEGER, \n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"GenreId\" INTEGER, \n",
"\t\"Composer\" NVARCHAR(220), \n",
"\t\"Milliseconds\" INTEGER NOT NULL, \n",
"\t\"Bytes\" INTEGER, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"TrackId\"), \n",
"\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n",
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"SELECT * FROM 'Track' LIMIT 3;\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n",
"\n",
"\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
@@ -340,22 +423,25 @@
")\n",
"\n",
"SELECT * FROM 'PlaylistTrack' LIMIT 3;\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I can use a SELECT statement to get the total number of tracks in each playlist.\n",
"Action: query_checker_sql_db\n",
"Action Input: SELECT Playlist.Name, COUNT(PlaylistTrack.TrackId) AS TotalTracks FROM Playlist INNER JOIN PlaylistTrack ON Playlist.PlaylistId = PlaylistTrack.PlaylistId GROUP BY Playlist.Name\u001b[0m\n",
"Observation: \u001b[31;1m\u001b[1;3m\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mNow that I have checked the schema of the tables, I can use a SQL query to get the total number of tracks in each playlist.\n",
"\n",
"SELECT Playlist.Name, COUNT(PlaylistTrack.TrackId) AS TotalTracks FROM Playlist INNER JOIN PlaylistTrack ON Playlist.PlaylistId = PlaylistTrack.PlaylistId GROUP BY Playlist.Name\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m The query looks correct, I can now execute it.\n",
"Action: query_sql_db\n",
"Action Input: SELECT Playlist.Name, COUNT(PlaylistTrack.TrackId) AS TotalTracks FROM Playlist INNER JOIN PlaylistTrack ON Playlist.PlaylistId = PlaylistTrack.PlaylistId GROUP BY Playlist.Name LIMIT 10\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[('90s Music', 1477), ('Brazilian Music', 39), ('Classical', 75), ('Classical 101 - Deep Cuts', 25), ('Classical 101 - Next Steps', 25), ('Classical 101 - The Basics', 25), ('Grunge', 15), ('Heavy Metal Classic', 26), ('Music', 6580), ('Music Videos', 1)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: The total number of tracks in each playlist are: '90s Music' (1477), 'Brazilian Music' (39), 'Classical' (75), 'Classical 101 - Deep Cuts' (25), 'Classical 101 - Next Steps' (25), 'Classical 101 - The Basics' (25), 'Grunge' (15), 'Heavy Metal Classic' (26), 'Music' (6580), 'Music Videos' (1).\u001b[0m\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"query_sql_db\",\n",
" \"action_input\": \"SELECT Playlist.Name, COUNT(PlaylistTrack.TrackId) AS 'Total Tracks' FROM Playlist JOIN PlaylistTrack ON Playlist.PlaylistId = PlaylistTrack.PlaylistId GROUP BY Playlist.Name\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[('90s Music', 1477), ('Brazilian Music', 39), ('Classical', 75), ('Classical 101 - Deep Cuts', 25), ('Classical 101 - Next Steps', 25), ('Classical 101 - The Basics', 25), ('Grunge', 15), ('Heavy Metal Classic', 26), ('Music', 6580), ('Music Videos', 1), ('On-The-Go 1', 1), ('TV Shows', 426)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"\n",
"Final Answer: The total number of tracks in each playlist is displayed in the following format: [('90s Music', 1477), ('Brazilian Music', 39), ('Classical', 75), ('Classical 101 - Deep Cuts', 25), ('Classical 101 - Next Steps', 25), ('Classical 101 - The Basics', 25), ('Grunge', 15), ('Heavy Metal Classic', 26), ('Music', 6580), ('Music Videos', 1), ('On-The-Go 1', 1), ('TV Shows', 426)].\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -363,7 +449,7 @@
{
"data": {
"text/plain": [
"\"The total number of tracks in each playlist are: '90s Music' (1477), 'Brazilian Music' (39), 'Classical' (75), 'Classical 101 - Deep Cuts' (25), 'Classical 101 - Next Steps' (25), 'Classical 101 - The Basics' (25), 'Grunge' (15), 'Heavy Metal Classic' (26), 'Music' (6580), 'Music Videos' (1).\""
"\"The total number of tracks in each playlist is displayed in the following format: [('90s Music', 1477), ('Brazilian Music', 39), ('Classical', 75), ('Classical 101 - Deep Cuts', 25), ('Classical 101 - Next Steps', 25), ('Classical 101 - The Basics', 25), ('Grunge', 15), ('Heavy Metal Classic', 26), ('Music', 6580), ('Music Videos', 1), ('On-The-Go 1', 1), ('TV Shows', 426)].\""
]
},
"execution_count": 7,
@@ -387,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"id": "9fe4901e-f9e1-4022-b6bc-80e2b2d6a3a4",
"metadata": {
"tags": []
@@ -400,89 +486,100 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
"Action Input: \"\"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mMediaType, Track, Invoice, Album, Playlist, Customer, Employee, InvoiceLine, PlaylistTrack, Genre, Artist\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should look at the schema of the Artist, InvoiceLine, and Track tables to see what columns I can use.\n",
"Action: schema_sql_db\n",
"Action Input: \"Artist, InvoiceLine, Track\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Artist\" (\n",
"\t\"ArtistId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"ArtistId\")\n",
")\n",
"\u001b[32;1m\u001b[1;3mThought: I need to query the database to get the best selling artists. I should use the `query_sql_db` tool for this.\n",
"\n",
"SELECT * FROM 'Artist' LIMIT 3;\n",
"ArtistId Name\n",
"1 AC/DC\n",
"2 Accept\n",
"3 Aerosmith\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"query_sql_db\",\n",
" \"action_input\": \"SELECT artist, SUM(sales) AS total_sales FROM sales GROUP BY artist ORDER BY total_sales DESC LIMIT 3\"\n",
"}\n",
"```\n",
"\n",
"\n",
"CREATE TABLE \"Track\" (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(200) NOT NULL, \n",
"\t\"AlbumId\" INTEGER, \n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"GenreId\" INTEGER, \n",
"\t\"Composer\" NVARCHAR(220), \n",
"\t\"Milliseconds\" INTEGER NOT NULL, \n",
"\t\"Bytes\" INTEGER, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"TrackId\"), \n",
"\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n",
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"SELECT * FROM 'Track' LIMIT 3;\n",
"TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n",
"3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99\n",
"\n",
"\n",
"CREATE TABLE \"InvoiceLine\" (\n",
"\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\t\"Quantity\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceLineId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
")\n",
"\n",
"SELECT * FROM 'InvoiceLine' LIMIT 3;\n",
"InvoiceLineId InvoiceId TrackId UnitPrice Quantity\n",
"1 1 2 0.99 1\n",
"2 1 4 0.99 1\n",
"3 2 6 0.99 1\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should query the database to get the top 3 best selling artists.\n",
"Action: query_sql_db\n",
"Action Input: SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist INNER JOIN Track ON Artist.ArtistId = Track.ArtistId INNER JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such column: Track.ArtistId\n",
"[SQL: SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist INNER JOIN Track ON Artist.ArtistId = Track.ArtistId INNER JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3]\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such table: sales\n",
"[SQL: SELECT artist, SUM(sales) AS total_sales FROM sales GROUP BY artist ORDER BY total_sales DESC LIMIT 3]\n",
"(Background on this error at: https://sqlalche.me/e/14/e3q8)\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should double check my query before executing it.\n",
"Action: query_checker_sql_db\n",
"Action Input: SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist INNER JOIN Track ON Artist.ArtistId = Track.ArtistId INNER JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3\u001b[0m\n",
"Observation: \u001b[31;1m\u001b[1;3m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to check if the table \"sales\" exists in the database before running the query. I should use the `list_tables_sql_db` tool for this.\n",
"\n",
"SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity \n",
"FROM Artist \n",
"INNER JOIN Track ON Artist.ArtistId = Track.ArtistId \n",
"INNER JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId \n",
"GROUP BY Artist.Name \n",
"ORDER BY TotalQuantity DESC \n",
"LIMIT 3;\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Action: query_sql_db\n",
"Action Input: SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist INNER JOIN Album ON Artist.ArtistId = Album.ArtistId INNER JOIN Track ON Album.AlbumId = Track.AlbumId INNER JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3\u001b[0m\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"list_tables_sql_db\",\n",
" \"action_input\": \"\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mCustomer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe table \"sales_table\" exists in the database. I can now use the `query_sql_db` tool to get the best selling artists.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"query_sql_db\",\n",
" \"action_input\": \"SELECT artist, SUM(sales) AS total_sales FROM sales_table GROUP BY artist ORDER BY total_sales DESC LIMIT 3\"\n",
"}\n",
"```\n",
"\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mError: (sqlite3.OperationalError) no such column: artist\n",
"[SQL: SELECT artist, SUM(sales) AS total_sales FROM sales_table GROUP BY artist ORDER BY total_sales DESC LIMIT 3]\n",
"(Background on this error at: https://sqlalche.me/e/14/e3q8)\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI need to check the schema of the `sales_table` to see what columns it has. I should use the `schema_sql_db` tool for this.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"schema_sql_db\",\n",
" \"action_input\": \"sales_table\"\n",
"}\n",
"```\n",
"\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE sales_table (\n",
"\tcountry TEXT, \n",
"\tsales INTEGER\n",
")\n",
"\n",
"SELECT * FROM 'sales_table' LIMIT 3;\n",
"country\tsales\n",
"USA\t100\n",
"Canada\t200\n",
"Mexico\t300\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe `sales_table` only has two columns: `country` and `sales`. I need to find the correct table that has the `artist` column. I should use the `list_tables_sql_db` tool for this.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"list_tables_sql_db\",\n",
" \"action_input\": \"\"\n",
"}\n",
"```\n",
"\n",
"\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mCustomer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe `Artist` table has the `artist` column. I can now use the `query_sql_db` tool to get the best selling artists.\n",
"\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"query_sql_db\",\n",
" \"action_input\": \"SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS total_sales FROM InvoiceLine JOIN Track ON InvoiceLine.TrackId = Track.TrackId JOIN Album ON Track.AlbumId = Album.AlbumId JOIN Artist ON Album.ArtistId = Artist.ArtistId GROUP BY Artist.Name ORDER BY total_sales DESC LIMIT 3\"\n",
"}\n",
"```\n",
"\n",
"\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[('Iron Maiden', 140), ('U2', 107), ('Metallica', 91)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: The top 3 best selling artists are Iron Maiden, U2, and Metallica.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe top 3 best selling artists are Iron Maiden, U2, and Metallica.\n",
"Final Answer: Iron Maiden, U2, Metallica.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -490,10 +587,10 @@
{
"data": {
"text/plain": [
"'The top 3 best selling artists are Iron Maiden, U2, and Metallica.'"
"'Iron Maiden, U2, Metallica.'"
]
},
"execution_count": 16,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -501,6 +598,14 @@
"source": [
"agent_executor.run(\"Who are the top 3 best selling artists?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "512180bf",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -519,7 +624,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,195 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ba5f8741",
"metadata": {},
"source": [
"# Chat Agent\n",
"\n",
"This notebook goes through how to create a ChatGPT based agent\n",
"\n",
"First, we set up the agent with tools as normal."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9af9734e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
"from langchain import OpenAI, SerpAPIWrapper, LLMChain, LLMMathChain, SQLDatabase, SQLDatabaseChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "becda2a1",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n",
" ),\n",
" Tool(\n",
" name=\"FooBar DB\",\n",
" func=db_chain.run,\n",
" description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n",
" )\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "1717e36a",
"metadata": {},
"source": [
"Now we create a ChatGPT based model and agent."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b1f12d80",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.agents.chat.base import ChatAgent\n",
"\n",
"agent = ChatAgent.from_chat_model_and_tools(ChatOpenAI(temperature=0), tools)"
]
},
{
"cell_type": "markdown",
"id": "5f57b076",
"metadata": {},
"source": [
"We can now use this as normal"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "490604e9",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "653b1617",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mQuestion: Who is Dua Lipa's boyfriend? What is his current age raised to the 0.43 power?\n",
"Thought: We need to find out the name of Dua Lipa's boyfriend and his current age.\n",
"Action: Search\n",
"Action Input: \"Dua Lipa boyfriend\"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mDua Lipa was seen heading home from the Saint Laurent Paris Fashion Week show with her new boyfriend Romain Gavras on Tuesday. Advertisement.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mWe found out the name of Dua Lipa's boyfriend, now we need to find out his age.\n",
"Action: Search\n",
"Action Input: \"Romain Gavras age\"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m41 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mWe have the age of Romain Gavras, now we need to calculate his age raised to the 0.43 power.\n",
"Action: Calculator\n",
"Action Input: 41^(0.43)\n",
"\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"41^(0.43)\n",
"\n",
"\u001b[32;1m\u001b[1;3m```python\n",
"print(41**(0.43))\n",
"```\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m4.9373857399466665\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 4.9373857399466665\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mWe have the answer to the second question, which is 4.9373857399466665.\n",
"Final Answer: Romain Gavras is Dua Lipa's boyfriend and his current age raised to the 0.43 power is 4.9373857399466665.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"Romain Gavras is Dua Lipa's boyfriend and his current age raised to the 0.43 power is 4.9373857399466665.\""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Who is Dua Lipa's boyfriend? What is his current age raised to the 0.43 power?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adefb4c2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "1aaba18c",
"metadata": {},
"outputs": [],
@@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "56ff7670",
"metadata": {},
"outputs": [],
@@ -287,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "8f15307d",
"metadata": {},
"outputs": [],
@@ -302,17 +302,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "0a23b91b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8700>, coroutine=None)"
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x105c0adc0>, func=<function search_api at 0x13ff17040>, coroutine=None)"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -331,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "28cdf04d",
"metadata": {},
"outputs": [],
@@ -344,17 +344,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "1085a4bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8670>, coroutine=None)"
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x105c0adc0>, func=<function search_api at 0x13ff17160>, coroutine=None)"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -375,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "79213f40",
"metadata": {},
"outputs": [],
@@ -385,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "e1067dcb",
"metadata": {},
"outputs": [],
@@ -395,7 +395,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "6c66ffe8",
"metadata": {},
"outputs": [],

View File

@@ -205,7 +205,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

docs/modules/chat.rst (new file, +15 lines)
View File

@@ -0,0 +1,15 @@
Chat
==========================
WARNING: extreme WIP
Chat models are a new kind of model that, rather than taking text in and returning text out, exchange a list of dictionaries, where each dictionary represents a chat utterance containing the text of the message and its "speaker".
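For illustration, such a list might look like the sketch below; the exact key names are an assumption for this example, not an interface this changeset pins down.

.. code-block:: python

    # Hypothetical utterance format: the "speaker"/"text" key names are an
    # assumption for illustration, not the finalized schema.
    chat_history = [
        {"speaker": "system", "text": "You are a helpful assistant."},
        {"speaker": "user", "text": "What tables are in the database?"},
        {"speaker": "assistant", "text": "Customer, Invoice, Track, ..."},
    ]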
.. toctree::
:maxdepth: 1
:caption: Chat
:name: Chat
:hidden:
./chat/how_to_guides.rst

View File

@@ -0,0 +1,272 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b7657104",
"metadata": {},
"source": [
"# Chat Vector DB"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4990086a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.docstore.document import Document\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.indexes.vectorstore import VectorstoreIndexCreator\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3a982715",
"metadata": {},
"outputs": [],
"source": [
"index_creator = VectorstoreIndexCreator()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1b2a6568",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../state_of_the_union.txt')\n",
"docsearch = index_creator.from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7d410fd6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.chat_vector_db import ChatVectorDBChain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"id": "e606d9e7",
"metadata": {},
"source": [
"## Memory outside the chain"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "71378b64",
"metadata": {},
"outputs": [],
"source": [
"chain = ChatVectorDBChain.from_llm(model = ChatOpenAI(temperature=0), llm=OpenAI(temperature=0), vectorstore=docsearch)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "aa1c1942",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory.chat_memory import ChatMemory"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6cbafc97",
"metadata": {},
"outputs": [],
"source": [
"memory = ChatMemory()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7bac7b99",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of our nations top legal minds who will continue Justice Breyers legacy of excellence. He also mentioned that she is a former top litigator in private practice, a former federal public defender, and comes from a family of public school educators and police officers.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"answer = chain.run(question=query, chat_history=memory.messages)\n",
"answer"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d9f3a746",
"metadata": {},
"outputs": [],
"source": [
"memory.add_user_message(query)\n",
"memory.add_ai_message(answer)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3c92b39a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Ketanji Brown Jackson has not yet been confirmed as a United States Supreme Court Justice. She has been nominated by President Biden to succeed Justice Stephen Breyer, who is retiring.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"Did the president say who she suceeded\"\n",
"chain.run(question=query, chat_history=memory.messages)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "41bc7676",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of our nations top legal minds who will continue Justice Breyers legacy of excellence. He also mentioned that she is a former top litigator in private practice, a former federal public defender, and comes from a family of public school educators and police officers.'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer"
]
},
{
"cell_type": "markdown",
"id": "439cc4be",
"metadata": {},
"source": [
"## Memory in the chain"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8a0a66ce",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.memory import ChatHistoryMemory\n",
"chain = ChatVectorDBChain.from_llm(\n",
" model = ChatOpenAI(temperature=0), \n",
" llm=OpenAI(temperature=0), \n",
" vectorstore=docsearch,\n",
" memory=ChatHistoryMemory(input_key=\"question\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d1d3a995",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of our nations top legal minds who will continue Justice Breyers legacy of excellence. He also mentioned that she is a former top litigator in private practice, a former federal public defender, and comes from a family of public school educators and police officers.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"answer = chain.run(question=query)\n",
"answer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b65288b",
"metadata": {},
"outputs": [],
"source": [
"query = \"Did the president say who she suceeded\"\n",
"chain.run(question=query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d2cb862",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,135 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0b12a588",
"metadata": {},
"source": [
"# Conversation Chain"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "319061e6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.conversation import ConversationChain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4dbbb98f",
"metadata": {},
"outputs": [],
"source": [
"model = ChatOpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "164af02b",
"metadata": {},
"outputs": [],
"source": [
"chain = ConversationChain.from_model(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "11830ac7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nHello there! How can I assist you today?'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"hi!\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8182f6c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"I'm sorry, as an AI language model I don't have access to your location. Can you please tell me your current location?\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run('where am i?')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ecc2fbbb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"I apologize for the confusion. As an AI language model, I don't have access to your location data. How can I assist you today?\""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run('what did you say?')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "614977d8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ca494454",
"metadata": {},
"source": [
"# Question Answering"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bed0dfef",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.docstore.document import Document\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.indexes.vectorstore import VectorstoreIndexCreator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "15a6d191",
"metadata": {},
"outputs": [],
"source": [
"index_creator = VectorstoreIndexCreator()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "483815ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../state_of_the_union.txt')\n",
"docsearch = index_creator.from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35fd98c0",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"docs = docsearch.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "86116c78",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.question_answering import QAChain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6b5d1a80",
"metadata": {},
"outputs": [],
"source": [
"chain = QAChain.from_model(model = ChatOpenAI(temperature=0))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5ff56c1d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The President honored Justice Stephen Breyer, an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court, for his dedicated service to the country. The President also mentioned that one of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court, and he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson to continue Justice Breyers legacy of excellence.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"chain.run(input_documents=docs, question=query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf32e6d6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,154 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "aa309a80",
"metadata": {},
"source": [
"# QA Eval"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9c01a3a5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa.chat_eval_chain import QAEvalChatChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"model = ChatOpenAI(temperature=0)\n",
"\n",
"eval_chain = QAEvalChatChain.from_model(model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c12568a6",
"metadata": {},
"outputs": [],
"source": [
"examples = [\n",
" {\n",
" \"question\": \"Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?\",\n",
" \"answer\": \"11\"\n",
" },\n",
" {\n",
" \"question\": 'Is the following sentence plausible? \"Joao Moutinho caught the screen pass in the NFC championship.\"',\n",
" \"answer\": \"No\"\n",
" }\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "207bb5b6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3d4b9cda",
"metadata": {},
"outputs": [],
"source": [
"prompt = PromptTemplate(template=\"Question: {question}\\nAnswer:\", input_variables=[\"question\"])\n",
"llm = OpenAI(model_name=\"text-davinci-003\", temperature=0)\n",
"chain = LLMChain(llm=llm, prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c03c4047",
"metadata": {},
"outputs": [],
"source": [
"predictions = chain.apply(examples)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3871729e",
"metadata": {},
"outputs": [],
"source": [
"graded_outputs = eval_chain.evaluate(examples, predictions, question_key=\"question\", prediction_key=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "788f841a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 0:\n",
"Question: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?\n",
"Real Answer: 11\n",
"Predicted Answer: 11 tennis balls\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 1:\n",
"Question: Is the following sentence plausible? \"Joao Moutinho caught the screen pass in the NFC championship.\"\n",
"Real Answer: No\n",
"Predicted Answer: No, this sentence is not plausible. Joao Moutinho is a professional soccer player, not an American football player, so it is not likely that he would be catching a screen pass in the NFC championship.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"Example {i}:\")\n",
" print(\"Question: \" + eg['question'])\n",
" print(\"Real Answer: \" + eg['answer'])\n",
" print(\"Predicted Answer: \" + predictions[i]['text'])\n",
" print(\"Predicted Grade: \" + graded_outputs[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a8d822c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,130 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "06b00f2b",
"metadata": {},
"source": [
"# Vector DB Question Answering"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4990086a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.docstore.document import Document\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.indexes.vectorstore import VectorstoreIndexCreator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3a982715",
"metadata": {},
"outputs": [],
"source": [
"index_creator = VectorstoreIndexCreator()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1b2a6568",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../state_of_the_union.txt')\n",
"docsearch = index_creator.from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7d410fd6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.vector_db_qa import VectorDBQA\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "89f2b56c",
"metadata": {},
"outputs": [],
"source": [
"chain = VectorDBQA.from_model(model = ChatOpenAI(temperature=0), vectorstore=docsearch)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7bac7b99",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The President honored Justice Stephen Breyer, an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court, for his dedicated service to the country. The President also mentioned that one of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court, and he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson to continue Justice Breyers legacy of excellence.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"chain.run(query=query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9f3a746",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,14 @@
Chat Models
=============

WARNING: extreme WIP

Chat models are new models that, rather than taking text in and returning text out, take in a list of dictionaries, each dictionary representing a chat utterance and including both the text of the message and the "speaker" of the message.

The examples here all highlight how to work with Chat Models.

.. toctree::
   :maxdepth: 1
   :glob:

   ./examples/*
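
For orientation, here is a minimal sketch of that message-list shape. The ``role``/``content`` key names below follow the OpenAI chat format and are an assumption for illustration; the abstraction in this branch may name them differently.

.. code-block:: python

    # A chat "prompt" is a list of utterances rather than a single string.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Translate 'bonjour' to English."},
    ]

    # Each dictionary carries the text of the utterance and its speaker,
    # which lets a chat model distinguish instructions from user input.
    for message in messages:
        print(f"{message['role']}: {message['content']}")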

View File

@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "291f0117",
"metadata": {},
"outputs": [
@@ -62,12 +62,12 @@
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../state_of_the_union.txt')\n",
"docsearch = index_creator.from_loaders([loader])"
"docsearch = index_creator.from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "d1eaf6e6",
"metadata": {},
"outputs": [],

View File

@@ -52,6 +52,17 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "43c7d116",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models.openai import ChatOpenAI\n",
"qa = VectorDBQA.from_chat_model(ChatOpenAI(temperature=0), vectorstore=docsearch)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3018f865",
"metadata": {},
"outputs": [],
@@ -61,17 +72,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "032a47f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice and federal public defender, from a family of public school educators and police officers, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}

View File

@@ -0,0 +1,686 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e78b7bb1",
"metadata": {},
"source": [
"# Data Augmented Question Answering Comparison\n",
"\n",
"This notebook uses some generic prompts/language models to evaluate an question answering system that uses other sources of data besides what is in the model. For example, this can be used to evaluate a question answering system over your propritary data.\n",
"\n",
"## Setup\n",
"Let's set up an example with our favorite example - the state of the union address."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ab4a6931",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain import OpenAI, VectorDBQA"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4fdc211d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../modules/state_of_the_union.txt')\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"docsearch = Chroma.from_documents(texts, embeddings)\n",
"qa = VectorDBQA.from_llm(llm=OpenAI(), vectorstore=docsearch)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6039aabb",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.vector_db_qa import VectorDBQA as ChatVectorDBQA\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cfbe541a",
"metadata": {},
"outputs": [],
"source": [
"chat_qa = ChatVectorDBQA.from_model(model = ChatOpenAI(temperature=0), vectorstore=docsearch)"
]
},
{
"cell_type": "markdown",
"id": "30fd72f2",
"metadata": {},
"source": [
"## Examples\n",
"Now we need some examples to evaluate. We can do this in two ways:\n",
"\n",
"1. Hard code some examples ourselves\n",
"2. Generate examples automatically, using a language model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3459b001",
"metadata": {},
"outputs": [],
"source": [
"# Hard-coded examples\n",
"examples = [\n",
" {\n",
" \"query\": \"What did the president say about Ketanji Brown Jackson\",\n",
" \"answer\": \"He praised her legal ability and said he nominated her for the supreme court.\"\n",
" },\n",
" {\n",
" \"query\": \"What did the president say about Michael Jackson\",\n",
" \"answer\": \"Nothing\"\n",
" }\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b9c3fa75",
"metadata": {},
"outputs": [],
"source": [
"# Generated examples\n",
"from langchain.evaluation.qa import QAGenerateChain\n",
"example_gen_chain = QAGenerateChain.from_llm(OpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c24543a9",
"metadata": {},
"outputs": [],
"source": [
"new_examples = example_gen_chain.apply_and_parse([{\"doc\": t} for t in texts[:5]])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a2d27560",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'query': 'Who did Vladimir Putin think he would meet when he entered Ukraine?',\n",
" 'answer': 'He thought he would meet the world and that it would roll over.'},\n",
" {'query': 'Who is the Ukrainian Ambassador to the United States?',\n",
" 'answer': 'The Ukrainian Ambassador to the United States is mentioned in the document.'},\n",
" {'query': 'How many countries have joined the coalition to confront Putin?',\n",
" 'answer': '27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.'},\n",
" {'query': 'What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?',\n",
" 'answer': 'The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.'},\n",
" {'query': 'What percentage of value has the Ruble lost due to the actions of the US and its allies?',\n",
" 'answer': '30%'}]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_examples"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "558da6f3",
"metadata": {},
"outputs": [],
"source": [
"# Combine examples\n",
"examples += new_examples"
]
},
{
"cell_type": "markdown",
"id": "443dc34e",
"metadata": {},
"source": [
"## Evaluate\n",
"Now that we have examples, we can use the question answering evaluator to evaluate our question answering chain."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1bb77416",
"metadata": {},
"outputs": [],
"source": [
"predictions = qa.apply(examples)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d7109ee4",
"metadata": {},
"outputs": [],
"source": [
"chat_predictions = chat_qa.apply(examples)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bcd0ad7f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa.chat_eval_chain import QAEvalChatChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"model = ChatOpenAI(temperature=0)\n",
"\n",
"eval_chain = QAEvalChatChain.from_model(model)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2e6af79a",
"metadata": {},
"outputs": [],
"source": [
"graded_outputs = eval_chain.evaluate(examples, predictions)\n",
"graded_chat_outputs = eval_chain.evaluate(examples, chat_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "32fac2dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 0:\n",
"Question: What did the president say about Ketanji Brown Jackson\n",
"Real Answer: He praised her legal ability and said he nominated her for the supreme court.\n",
"Predicted Answer: The president said that Ketanji Brown Jackson is \"one of our nation's top legal minds\" and that she will \"continue Justice Breyer's legacy of excellence.\"\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 1:\n",
"Question: What did the president say about Michael Jackson\n",
"Real Answer: Nothing\n",
"Predicted Answer: The president did not mention Michael Jackson in the given context.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 2:\n",
"Question: Who did Vladimir Putin think he would meet when he entered Ukraine?\n",
"Real Answer: He thought he would meet the world and that it would roll over.\n",
"Predicted Answer: Putin thought he would meet a divided West and NATO that would not respond to his aggression.\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 3:\n",
"Question: Who is the Ukrainian Ambassador to the United States?\n",
"Real Answer: The Ukrainian Ambassador to the United States is mentioned in the document.\n",
"Predicted Answer: I don't know.\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 4:\n",
"Question: How many countries have joined the coalition to confront Putin?\n",
"Real Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"Predicted Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, including Switzerland.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 5:\n",
"Question: What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?\n",
"Real Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.\n",
"Predicted Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. They are joining with European allies to find and seize yachts, luxury apartments, and private jets, and are coming for the ill-begotten gains.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 6:\n",
"Question: What percentage of value has the Ruble lost due to the actions of the US and its allies?\n",
"Real Answer: 30%\n",
"Predicted Answer: The Ruble has lost 30% of its value.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"Example {i}:\")\n",
" print(\"Question: \" + predictions[i]['query'])\n",
" print(\"Real Answer: \" + predictions[i]['answer'])\n",
" print(\"Predicted Answer: \" + predictions[i]['result'])\n",
" print(\"Predicted Grade: \" + graded_outputs[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "bd0b01dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 0:\n",
"Question: What did the president say about Ketanji Brown Jackson\n",
"Real Answer: He praised her legal ability and said he nominated her for the supreme court.\n",
"Predicted Answer: The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of the nation's top legal minds who will continue Justice Breyer's legacy of excellence. He also mentioned that she has received broad support from various groups, including the Fraternal Order of Police and former judges appointed by Democrats and Republicans.\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 1:\n",
"Question: What did the president say about Michael Jackson\n",
"Real Answer: Nothing\n",
"Predicted Answer: I'm sorry, I cannot provide an answer to that question as there is no mention of Michael Jackson in the given speech.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 2:\n",
"Question: Who did Vladimir Putin think he would meet when he entered Ukraine?\n",
"Real Answer: He thought he would meet the world and that it would roll over.\n",
"Predicted Answer: Vladimir Putin likely thought he would meet a weak and divided Ukraine, and a world that would not stand up to his aggression. He may have also believed that he could exploit divisions within the West and NATO to achieve his goals. However, he was proven wrong on all counts. The Ukrainian people have shown remarkable courage and resilience in the face of Russian aggression, and the world has come together to condemn Putin's actions and impose consequences.\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 3:\n",
"Question: Who is the Ukrainian Ambassador to the United States?\n",
"Real Answer: The Ukrainian Ambassador to the United States is mentioned in the document.\n",
"Predicted Answer: I'm sorry, I do not have access to real-time information. As an AI language model, my database only contains general knowledge and not current events or specific details.\n",
"Predicted Grade: GRADE: INCORRECT\n",
"\n",
"Example 4:\n",
"Question: How many countries have joined the coalition to confront Putin?\n",
"Real Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"Predicted Answer: According to the statement, \"We spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin.\" The statement also mentions that \"Along with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland\" have joined the coalition. Therefore, it can be inferred that more than 27 countries have joined the coalition to confront Putin.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 5:\n",
"Question: What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?\n",
"Real Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.\n",
"Predicted Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. This task force will work to find and seize the ill-gotten gains of corrupt Russian leaders who have bilked billions of dollars off this violent regime. The U.S. government is joining with its European allies to track down and seize the yachts, luxury apartments, and private jets of these oligarchs. The goal is to hold these individuals accountable for their crimes and to prevent them from benefiting from their corrupt activities.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n",
"Example 6:\n",
"Question: What percentage of value has the Ruble lost due to the actions of the US and its allies?\n",
"Real Answer: 30%\n",
"Predicted Answer: In the speech, it is mentioned that the Ruble has lost 30% of its value.\n",
"Predicted Grade: GRADE: CORRECT\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"Example {i}:\")\n",
" print(\"Question: \" + chat_predictions[i]['query'])\n",
" print(\"Real Answer: \" + chat_predictions[i]['answer'])\n",
" print(\"Predicted Answer: \" + chat_predictions[i]['result'])\n",
" print(\"Predicted Grade: \" + graded_chat_outputs[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d61ff13f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa import QAEvalChain\n",
"predictions = qa.apply(examples)\n",
"llm = OpenAI(temperature=0)\n",
"eval_chain = QAEvalChain.from_llm(llm)\n",
"normal_graded_outputs = eval_chain.evaluate(examples, predictions)\n",
"\n",
"normal_graded_chat_outputs = eval_chain.evaluate(examples, chat_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4378937d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 0:\n",
"Question: What did the president say about Ketanji Brown Jackson\n",
"Real Answer: He praised her legal ability and said he nominated her for the supreme court.\n",
"Predicted Answer: The president said that she is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also mentioned that she has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 1:\n",
"Question: What did the president say about Michael Jackson\n",
"Real Answer: Nothing\n",
"Predicted Answer: The president did not mention Michael Jackson.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 2:\n",
"Question: Who did Vladimir Putin think he would meet when he entered Ukraine?\n",
"Real Answer: He thought he would meet the world and that it would roll over.\n",
"Predicted Answer: Putin thought he would meet people who would roll over and accept his aggression.\n",
"Predicted Grade: INCORRECT\n",
"\n",
"Example 3:\n",
"Question: Who is the Ukrainian Ambassador to the United States?\n",
"Real Answer: The Ukrainian Ambassador to the United States is mentioned in the document.\n",
"Predicted Answer: I don't know.\n",
"Predicted Grade: INCORRECT\n",
"\n",
"Example 4:\n",
"Question: How many countries have joined the coalition to confront Putin?\n",
"Real Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"Predicted Answer: 27 members of the European Union including France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 5:\n",
"Question: What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?\n",
"Real Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.\n",
"Predicted Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. They are joining with European allies to find and seize their yachts, luxury apartments, and private jets, and they are coming for their ill-begotten gains.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 6:\n",
"Question: What percentage of value has the Ruble lost due to the actions of the US and its allies?\n",
"Real Answer: 30%\n",
"Predicted Answer: The Ruble has lost 30% of its value.\n",
"Predicted Grade: CORRECT\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"Example {i}:\")\n",
" print(\"Question: \" + predictions[i]['query'])\n",
" print(\"Real Answer: \" + predictions[i]['answer'])\n",
" print(\"Predicted Answer: \" + predictions[i]['result'])\n",
" print(\"Predicted Grade: \" + normal_graded_outputs[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "203fdc73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 0:\n",
"Question: What did the president say about Ketanji Brown Jackson\n",
"Real Answer: He praised her legal ability and said he nominated her for the supreme court.\n",
"Predicted Answer: The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of the nation's top legal minds who will continue Justice Breyer's legacy of excellence. He also mentioned that she has received broad support from various groups, including the Fraternal Order of Police and former judges appointed by Democrats and Republicans.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 1:\n",
"Question: What did the president say about Michael Jackson\n",
"Real Answer: Nothing\n",
"Predicted Answer: I'm sorry, I cannot provide an answer to that question as there is no mention of Michael Jackson in the given speech.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 2:\n",
"Question: Who did Vladimir Putin think he would meet when he entered Ukraine?\n",
"Real Answer: He thought he would meet the world and that it would roll over.\n",
"Predicted Answer: Vladimir Putin likely thought he would meet a weak and divided Ukraine, and a world that would not stand up to his aggression. He may have also believed that he could exploit divisions within the West and NATO to achieve his goals. However, he was proven wrong on all counts. The Ukrainian people have shown remarkable courage and resilience in the face of Russian aggression, and the world has come together to condemn Putin's actions and impose consequences.\n",
"Predicted Grade: INCORRECT\n",
"\n",
"Example 3:\n",
"Question: Who is the Ukrainian Ambassador to the United States?\n",
"Real Answer: The Ukrainian Ambassador to the United States is mentioned in the document.\n",
"Predicted Answer: I'm sorry, I do not have access to real-time information. As an AI language model, my database only contains general knowledge and not current events or specific details.\n",
"Predicted Grade: INCORRECT\n",
"\n",
"Example 4:\n",
"Question: How many countries have joined the coalition to confront Putin?\n",
"Real Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"Predicted Answer: According to the statement, \"We spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin.\" The statement also mentions that \"Along with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland\" have joined the coalition. Therefore, it can be inferred that more than 27 countries have joined the coalition to confront Putin.\n",
"Predicted Grade: INCORRECT\n",
"\n",
"Example 5:\n",
"Question: What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?\n",
"Real Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.\n",
"Predicted Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. This task force will work to find and seize the ill-gotten gains of corrupt Russian leaders who have bilked billions of dollars off this violent regime. The U.S. government is joining with its European allies to track down and seize the yachts, luxury apartments, and private jets of these oligarchs. The goal is to hold these individuals accountable for their crimes and to prevent them from benefiting from their corrupt activities.\n",
"Predicted Grade: CORRECT\n",
"\n",
"Example 6:\n",
"Question: What percentage of value has the Ruble lost due to the actions of the US and its allies?\n",
"Real Answer: 30%\n",
"Predicted Answer: In the speech, it is mentioned that the Ruble has lost 30% of its value.\n",
"Predicted Grade: CORRECT\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"Example {i}:\")\n",
" print(\"Question: \" + chat_predictions[i]['query'])\n",
" print(\"Real Answer: \" + chat_predictions[i]['answer'])\n",
" print(\"Predicted Answer: \" + chat_predictions[i]['result'])\n",
" print(\"Predicted Grade: \" + normal_graded_chat_outputs[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7e3b0531",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa.chat_comp_chain import QACompChatChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ff307bf",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e96f871b",
"metadata": {},
"outputs": [],
"source": [
"comp_chain = QACompChatChain.from_model(ChatOpenAI(temperature=0))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f2628a11",
"metadata": {},
"outputs": [],
"source": [
"comps = comp_chain.evaluate(examples, predictions, chat_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a5953f3b",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"## Example 0:\n",
"\n",
"Question: What did the president say about Ketanji Brown Jackson\n",
"\n",
"Real Answer: He praised her legal ability and said he nominated her for the supreme court.\n",
"\n",
"Normal Answer: The president said that she is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also mentioned that she has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\n",
"\n",
"Chat Answer: The President said that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson for the United States Supreme Court, and that she is one of the nation's top legal minds who will continue Justice Breyer's legacy of excellence. He also mentioned that she has received broad support from various groups, including the Fraternal Order of Police and former judges appointed by Democrats and Republicans.\n",
"\n",
"Comparison: Student A's answer is more verbose, detailed, and includes more information than Student B's answer. However, Student B's answer is more concise and accurately summarizes the main points of the president's statement.\n",
"\n",
"## Example 1:\n",
"\n",
"Question: What did the president say about Michael Jackson\n",
"\n",
"Real Answer: Nothing\n",
"\n",
"Normal Answer: The president did not mention Michael Jackson.\n",
"\n",
"Chat Answer: I'm sorry, I cannot provide an answer to that question as there is no mention of Michael Jackson in the given speech.\n",
"\n",
"Comparison: Student A's answer is more verbose and less succint than Student B's answer. However, Student A's answer is still correct while Student B's answer is technically correct but not very helpful.\n",
"\n",
"## Example 2:\n",
"\n",
"Question: Who did Vladimir Putin think he would meet when he entered Ukraine?\n",
"\n",
"Real Answer: He thought he would meet the world and that it would roll over.\n",
"\n",
"Normal Answer: Putin thought he would meet people who would roll over and accept his aggression.\n",
"\n",
"Chat Answer: Vladimir Putin likely thought he would meet a weak and divided Ukraine, and a world that would not stand up to his aggression. He may have also believed that he could exploit divisions within the West and NATO to achieve his goals. However, he was proven wrong on all counts. The Ukrainian people have shown remarkable courage and resilience in the face of Russian aggression, and the world has come together to condemn Putin's actions and impose consequences.\n",
"\n",
"Comparison: Student A's answer is less precise and less accurate than Student B's answer. Student B's answer is more detailed, accurate, and informative than Student A's answer.\n",
"\n",
"## Example 3:\n",
"\n",
"Question: Who is the Ukrainian Ambassador to the United States?\n",
"\n",
"Real Answer: The Ukrainian Ambassador to the United States is mentioned in the document.\n",
"\n",
"Normal Answer: I don't know.\n",
"\n",
"Chat Answer: I'm sorry, I do not have access to real-time information. As an AI language model, my database only contains general knowledge and not current events or specific details.\n",
"\n",
"Comparison: Student A's answer is less informative, less correct. Student B's answer is more polite, more detailed, more informative.\n",
"\n",
"## Example 4:\n",
"\n",
"Question: How many countries have joined the coalition to confront Putin?\n",
"\n",
"Real Answer: 27 members of the European Union, France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"\n",
"Normal Answer: 27 members of the European Union including France, Germany, Italy, the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n",
"\n",
"Chat Answer: According to the statement, \"We spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin.\" The statement also mentions that \"Along with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland\" have joined the coalition. Therefore, it can be inferred that more than 27 countries have joined the coalition to confront Putin.\n",
"\n",
"Comparison: Student A's answer is less verbose and equally correct compared to Student B's answer.\n",
"\n",
"## Example 5:\n",
"\n",
"Question: What is the U.S. Department of Justice doing to go after the crimes of Russian oligarchs?\n",
"\n",
"Real Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs and join with European allies to find and seize yachts, luxury apartments, and private jets.\n",
"\n",
"Normal Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. They are joining with European allies to find and seize their yachts, luxury apartments, and private jets, and they are coming for their ill-begotten gains.\n",
"\n",
"Chat Answer: The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. This task force will work to find and seize the ill-gotten gains of corrupt Russian leaders who have bilked billions of dollars off this violent regime. The U.S. government is joining with its European allies to track down and seize the yachts, luxury apartments, and private jets of these oligarchs. The goal is to hold these individuals accountable for their crimes and to prevent them from benefiting from their corrupt activities.\n",
"\n",
"Comparison: Student A's answer is more succinct and less detailed than Student B's answer. Student B's answer is more descriptive and provides more information about the crimes of Russian oligarchs and the U.S. Department of Justice's efforts to go after them.\n",
"\n",
"## Example 6:\n",
"\n",
"Question: What percentage of value has the Ruble lost due to the actions of the US and its allies?\n",
"\n",
"Real Answer: 30%\n",
"\n",
"Normal Answer: The Ruble has lost 30% of its value.\n",
"\n",
"Chat Answer: In the speech, it is mentioned that the Ruble has lost 30% of its value.\n",
"\n",
"Comparison: Student A's answer is more concise and equally correct compared to Student B's answer, which is more verbose but still correct.\n",
"\n"
]
}
],
"source": [
"for i, eg in enumerate(examples):\n",
" print(f\"## Example {i}:\")\n",
" print()\n",
" print(\"Question: \" + chat_predictions[i]['query'])\n",
" print()\n",
" print(\"Real Answer: \" + chat_predictions[i]['answer'])\n",
" print()\n",
" print(\"Normal Answer: \" + predictions[i]['result'])\n",
" print()\n",
" print(\"Chat Answer: \" + chat_predictions[i]['result'])\n",
" print()\n",
" print(\"Comparison: \" + comps[i]['text'])\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "862420bc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -191,7 +191,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "782ae8c8",
"metadata": {},
@@ -316,7 +315,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -330,7 +329,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7 (default, Sep 16 2021, 08:50:36) \n[Clang 10.0.0 ]"
"version": "3.9.1"
},
"vscode": {
"interpreter": {

View File

@@ -13,7 +13,8 @@ from pydantic import BaseModel, root_validator
from langchain.agents.tools import InvalidTool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm import BaseLLMChain, ChatModelChain, LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.input import get_color_mapping
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
@@ -33,7 +34,7 @@ class Agent(BaseModel):
intermediary work.
"""
llm_chain: LLMChain
llm_chain: BaseLLMChain
allowed_tools: Optional[List[str]] = None
return_values: List[str] = ["output"]
@@ -205,6 +206,24 @@ class Agent(BaseModel):
tool_names = [tool.name for tool in tools]
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
@classmethod
def from_chat_model_and_tools(
cls,
model: BaseChatModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = ChatModelChain(
llm=model,
prompt=cls.create_prompt(tools),
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
def return_stopped_response(
self,
early_stopping_method: str,

View File

@@ -0,0 +1,64 @@
import json
from typing import List, Optional, Sequence, Tuple

from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction
from langchain.tools import BaseTool

FINAL_ANSWER_ACTION = "Final Answer:"


class ChatAgent(ZeroShotAgent):
    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        if agent_scratchpad:
            return (
                f"This was your previous work "
                f"(but I haven't seen any of it! I only see what "
                f"you return as final answer):\n{agent_scratchpad}"
            )
        else:
            return agent_scratchpad

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        if FINAL_ANSWER_ACTION in text:
            return "Final Answer", text.split(FINAL_ANSWER_ACTION)[-1].strip()
        # The model is instructed to emit its action as a json blob inside a
        # ``` fenced block; pull out the blob and parse it.
        _, action, _ = text.split("```")
        response = json.loads(action.strip())
        return response["action"], response["action_input"]

    @property
    def _stop(self) -> List[str]:
        return ["Observation:"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> BasePromptTemplate:
        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        messages = [
            ("user", PromptTemplate.from_template(template)),
            ("user", PromptTemplate.from_template("{input}\n\n{agent_scratchpad}")),
        ]
        return ChatPromptTemplate(
            input_variables=["input", "agent_scratchpad"], messages=messages
        )

    @property
    def _agent_type(self) -> str:
        raise ValueError("ChatAgent does not support serialization.")

View File

@@ -0,0 +1,18 @@
# flake8: noqa
PREFIX = """Answer the following questions as best you can. You have access to the following tools:"""
FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob.
Specifically, this json should have an `action` key (with the name of the tool to use) and an `action_input` key (with the input to the tool going here).

ALWAYS use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action:
```
$JSON_BLOB
```
Observation: the result of the action
... (this Thought/Action/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""
SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding."""

View File

@@ -10,7 +10,7 @@ from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.tools.base import BaseTool
FINAL_ANSWER_ACTION = "Final Answer:"
@@ -46,7 +46,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
raise ValueError(f"Could not parse LLM output: `{llm_output}`")
action = match.group(1).strip()
action_input = match.group(2)
return action, action_input.strip(" ").strip('"')
return action, action_input.strip().strip('"')
class ZeroShotAgent(Agent):
@@ -75,7 +75,7 @@ class ZeroShotAgent(Agent):
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
) -> BasePromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm import BaseLLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
@@ -18,7 +18,7 @@ def _get_default_document_prompt() -> PromptTemplate:
class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
"""Chain that combines documents by stuffing into context."""
llm_chain: LLMChain
llm_chain: BaseLLMChain
"""LLM wrapper to use after formatting documents."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
@@ -80,7 +80,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
"""Get the prompt length by formatting the prompt."""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
return self.llm_chain.get_num_tokens(prompt)
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""

View File

@@ -17,7 +17,10 @@ from langchain.graphs.networkx_graph import (
parse_triples,
)
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import ChatMemory
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import ChatMessage
def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
@@ -73,17 +76,53 @@ class CombinedMemory(Memory, BaseModel):
memory.clear()
class ConversationBufferMemory(Memory, BaseModel):
"""Buffer for storing conversation memory."""
class ChatMemoryMixin(Memory):
chat_memory: ChatMemory
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: str = ""
output_key: Optional[str] = None
input_key: Optional[str] = None
@root_validator(pre=True)
def add_chat_memory(cls, values: Dict) -> Dict:
"""Add chat memory data structure."""
human_prefix = values.get("human_prefix", "Human")
ai_prefix = values.get("ai_prefix", "AI")
values["chat_memory"] = ChatMemory(
human_prefix=human_prefix, ai_prefix=ai_prefix
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
self.chat_memory.add_user_message(inputs[prompt_input_key])
self.chat_memory.add_ai_message(outputs[output_key])
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
class ConversationBufferMemory(ChatMemoryMixin, BaseModel):
"""Buffer for storing conversation memory."""
memory_key: str = "history" #: :meta private:
@property
def buffer(self) -> str:
"""String buffer of memory."""
return get_buffer_string(self.chat_memory.messages)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
@@ -96,39 +135,18 @@ class ConversationBufferMemory(Memory, BaseModel):
"""Return history buffer."""
return {self.memory_key: self.buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer += "\n" + "\n".join([human, ai])
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""
class ConversationBufferWindowMemory(Memory, BaseModel):
class ConversationBufferWindowMemory(ChatMemoryMixin, BaseModel):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: List[str] = Field(default_factory=list)
memory_key: str = "history" #: :meta private:
output_key: Optional[str] = None
input_key: Optional[str] = None
k: int = 5
@property
def buffer(self) -> List[ChatMessage]:
"""String buffer of memory."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
@@ -139,45 +157,21 @@ class ConversationBufferWindowMemory(Memory, BaseModel):
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: "\n".join(self.buffer[-self.k :])}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer.append("\n".join([human, ai]))
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = []
return {self.memory_key: get_buffer_string(self.buffer[-self.k * 2 :])}
# For legacy naming reasons
ConversationalBufferWindowMemory = ConversationBufferWindowMemory
class ConversationSummaryMemory(Memory, BaseModel):
class ConversationSummaryMemory(ChatMemoryMixin, BaseModel):
"""Conversation summarizer to memory."""
buffer: str = ""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
llm: BaseLLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
memory_key: str = "history" #: :meta private:
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def memory_variables(self) -> List[str]:
@@ -205,44 +199,33 @@ class ConversationSummaryMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: {inputs[prompt_input_key]}"
ai = f"{self.ai_prefix}: {outputs[output_key]}"
new_lines = "\n".join([human, ai])
super().save_context(inputs, outputs)
new_lines = get_buffer_string(self.chat_memory.messages[-2:])
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""
class ConversationEntityMemory(Memory, BaseModel):
class ConversationEntityMemory(ChatMemoryMixin, BaseModel):
"""Entity extractor & summarizer to memory."""
buffer: List[str] = []
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
llm: BaseLLM
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
output_key: Optional[str] = None
input_key: Optional[str] = None
store: Dict[str, Optional[str]] = {}
entity_cache: List[str] = []
k: int = 3
chat_history_key: str = "history"
@property
def buffer(self) -> List[ChatMessage]:
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
@@ -259,7 +242,7 @@ class ConversationEntityMemory(Memory, BaseModel):
else:
prompt_input_key = self.input_key
output = chain.predict(
history="\n".join(self.buffer[-self.k :]),
history=get_buffer_string(self.buffer[-self.k * 2 :]),
input=inputs[prompt_input_key],
)
if output.strip() == "NONE":
@@ -271,58 +254,47 @@ class ConversationEntityMemory(Memory, BaseModel):
entity_summaries[entity] = self.store.get(entity, "")
self.entity_cache = entities
return {
self.chat_history_key: "\n".join(self.buffer[-self.k :]),
self.chat_history_key: get_buffer_string(self.buffer[-self.k * 2 :]),
"entities": entity_summaries,
}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
for entity in self.entity_cache:
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
# key value store for entity
existing_summary = self.store.get(entity, "")
output = chain.predict(
summary=existing_summary,
history="\n".join(self.buffer[-self.k :]),
history=get_buffer_string(self.buffer[-self.k * 2 :]),
input=inputs[prompt_input_key],
entity=entity,
)
self.store[entity] = output.strip()
new_lines = "\n".join([human, ai])
self.buffer.append(new_lines)
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = []
self.chat_memory.clear()
self.store = {}
class ConversationSummaryBufferMemory(Memory, BaseModel):
class ConversationSummaryBufferMemory(ChatMemoryMixin, BaseModel):
"""Buffer with summarizer for storing conversation memory."""
buffer: List[str] = Field(default_factory=list)
max_token_limit: int = 2000
moving_summary_buffer: str = ""
llm: BaseLLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
memory_key: str = "history"
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def buffer(self) -> List[ChatMessage]:
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
@@ -335,8 +307,8 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
if self.moving_summary_buffer == "":
return {self.memory_key: "\n".join(self.buffer)}
memory_val = self.moving_summary_buffer + "\n" + "\n".join(self.buffer)
return {self.memory_key: get_buffer_string(self.buffer)}
memory_val = self.moving_summary_buffer + "\n" + get_buffer_string(self.buffer)
return {self.memory_key: memory_val}
@root_validator()
@@ -351,45 +323,34 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
)
return values
def get_num_tokens_list(self, arr: List[str]) -> List[int]:
def get_num_tokens_list(self, arr: List[ChatMessage]) -> List[int]:
"""Get list of number of tokens in each string in the input array."""
return [self.llm.get_num_tokens(x) for x in arr]
return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: {inputs[prompt_input_key]}"
ai = f"{self.ai_prefix}: {outputs[output_key]}"
new_lines = "\n".join([human, ai])
self.buffer.append(new_lines)
super().save_context(inputs, outputs)
# Prune buffer if it exceeds max token limit
curr_buffer_length = sum(self.get_num_tokens_list(self.buffer))
buffer = self.chat_memory.messages
curr_buffer_length = sum(self.get_num_tokens_list(buffer))
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(self.buffer.pop(0))
curr_buffer_length = sum(self.get_num_tokens_list(self.buffer))
pruned_memory.append(buffer.pop(0))
curr_buffer_length = sum(self.get_num_tokens_list(buffer))
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.moving_summary_buffer = chain.predict(
summary=self.moving_summary_buffer, new_lines=("\n".join(pruned_memory))
summary=self.moving_summary_buffer,
new_lines=(get_buffer_string(pruned_memory)),
)
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = []
super().clear()
self.moving_summary_buffer = ""
class ConversationKGMemory(Memory, BaseModel):
class ConversationKGMemory(ChatMemoryMixin, BaseModel):
"""Knowledge graph memory for storing conversation memory.
Integrates with external knowledge graph to store and retrieve
@@ -397,17 +358,11 @@ class ConversationKGMemory(Memory, BaseModel):
"""
k: int = 2
buffer: List[str] = Field(default_factory=list)
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLLM
"""Number of previous utterances to include in the context."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
output_key: Optional[str] = None
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -454,7 +409,7 @@ class ConversationKGMemory(Memory, BaseModel):
prompt_input_key = self._get_prompt_input_key(inputs)
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
output = chain.predict(
history="\n".join(self.buffer[-self.k :]),
history=get_buffer_string(self.chat_memory.messages[-self.k :]),
input=inputs[prompt_input_key],
)
return get_entities(output)
@@ -464,7 +419,7 @@ class ConversationKGMemory(Memory, BaseModel):
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
prompt_input_key = self._get_prompt_input_key(inputs)
output = chain.predict(
history="\n".join(self.buffer[-self.k :]),
history=get_buffer_string(self.chat_memory.messages[-self.k :]),
input=inputs[prompt_input_key],
verbose=True,
)
@@ -474,14 +429,10 @@ class ConversationKGMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()

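Reviewer note: the pruning loop above pops the oldest messages until the transcript fits the token budget, then folds whatever was popped into moving_summary_buffer. A self-contained sketch of that same logic, with illustrative names (prune, count_tokens) that are not the library's API:

def prune(messages, max_token_limit, count_tokens):
    """Pop the oldest messages until the buffer fits the token budget."""
    pruned = []
    curr_length = sum(count_tokens(m) for m in messages)
    while curr_length > max_token_limit and messages:
        pruned.append(messages.pop(0))  # oldest first, mirroring buffer.pop(0)
        curr_length = sum(count_tokens(m) for m in messages)
    return pruned  # these are what get summarized into the moving summary

# Toy token counter: one token per word.
print(prune(["a b c", "d e", "f"], 3, lambda m: len(m.split())))  # -> ['a b c']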
View File

@@ -4,14 +4,15 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.input import get_colored_text
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ChatMessage, ChatResult, LLMResult
class BaseLLMChain(Chain, BaseModel):
"""Chain to run queries against LLMs.
Example:
@@ -27,8 +28,6 @@ class LLMChain(Chain, BaseModel):
prompt: BasePromptTemplate
"""Prompt object to use."""
output_key: str = "text" #: :meta private:
class Config:
@@ -53,6 +52,101 @@ class LLMChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
raise NotImplementedError
def get_num_tokens(self, prompt: str) -> int:
raise NotImplementedError
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return (await self.aapply([inputs]))[0]
def predict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
async def apredict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs))[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key]) for res in result
]
else:
return result
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = await self.aapply(input_list)
return self._parse_result(result)
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLLM, template: str) -> Chain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
class LLMChain(BaseLLMChain):
llm: BaseLLM
"""LLM wrapper to use."""
def get_num_tokens(self, prompt: str) -> int:
return self.llm.get_num_tokens(prompt)
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
@@ -130,82 +224,88 @@ class LLMChain(Chain, BaseModel):
for generation in response.generations
]
class ChatModelChain(BaseLLMChain):
    llm: BaseChatModel
    """LLM wrapper to use."""

    def generate(self, input_list: List[Dict[str, Any]]) -> List[ChatResult]:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list)
        results = []
        for prompt in prompts:
            results.append(self.llm.generate(prompt, stop=stop))
        return results

    async def agenerate(self, input_list: List[Dict[str, Any]]) -> List[ChatResult]:
        """Generate LLM result from inputs."""
        prompts, stop = await self.aprep_prompts(input_list)
        results = []
        for prompt in prompts:
            results.append(await self.llm.agenerate(prompt, stop=stop))
        return results

    def prep_prompts(
        self, input_list: List[Dict[str, Any]]
    ) -> Tuple[List[List[ChatMessage]], Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in input_list[0]:
            stop = input_list[0]["stop"]
        prompts = []
        for inputs in input_list:
            selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
            prompt = self.prompt.format_chat(**selected_inputs)
            _colored_text = get_colored_text(str(prompt), "green")
            _text = "Prompt after formatting:\n" + _colored_text
            self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
                )
            prompts.append(prompt)
        return prompts, stop

    async def aprep_prompts(
        self, input_list: List[Dict[str, Any]]
    ) -> Tuple[List[List[ChatMessage]], Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in input_list[0]:
            stop = input_list[0]["stop"]
        prompts = []
        for inputs in input_list:
            selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
            prompt = self.prompt.format_chat(**selected_inputs)
            _colored_text = get_colored_text(str(prompt), "green")
            _text = "Prompt after formatting:\n" + _colored_text
            if self.callback_manager.is_async:
                await self.callback_manager.on_text(
                    _text, end="\n", verbose=self.verbose
                )
            else:
                self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
                )
            prompts.append(prompt)
        return prompts, stop

    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        response = self.generate(input_list)
        return self.create_outputs(response)

    async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        response = await self.agenerate(input_list)
        return self.create_outputs(response)

    def create_outputs(self, response: List[ChatResult]) -> List[Dict[str, str]]:
        """Create outputs from response."""
        return [
            # Get the text of the top generated string.
            {self.output_key: res.generations[0].message.text}
            for res in response
        ]

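A hedged usage sketch for the split above: ChatModelChain pairs a chat model with a ChatPromptTemplate, while the predict/apply interface stays on BaseLLMChain. Assumes this branch's API and an OPENAI_API_KEY in the environment.

from langchain.chains.llm import ChatModelChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate

prompt = ChatPromptTemplate.from_strings(
    [("system", "You translate English to {language}."), ("user", "{text}")]
)
chain = ChatModelChain(llm=ChatOpenAI(temperature=0), prompt=prompt)
# predict() is inherited from BaseLLMChain; prep_prompts formats the messages.
print(chain.predict(language="French", text="Good morning"))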
View File

@@ -32,9 +32,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
self.combine_documents_chain, StuffDocumentsChain
):
tokens = [
self.combine_documents_chain.llm_chain.get_num_tokens(doc.page_content)
for doc in docs
]
token_count = sum(tokens[:num_docs])

View File

@@ -8,11 +8,12 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import ChatModelChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.vector_db_qa.prompt import CHAT_PROMPT, PROMPT
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain.vectorstores.base import VectorStore
@@ -118,6 +119,23 @@ class VectorDBQA(Chain, BaseModel):
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@classmethod
def from_chat_model(
cls, llm: BaseChatModel, prompt: BasePromptTemplate = CHAT_PROMPT, **kwargs: Any
) -> VectorDBQA:
"""Initialize from LLM."""
llm_chain = ChatModelChain(llm=llm, prompt=prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@classmethod
def from_chain_type(
cls,

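A usage sketch for from_chat_model above: it wires a ChatModelChain with CHAT_PROMPT into a StuffDocumentsChain. Assumes this branch and an OPENAI_API_KEY.

from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

vectordb = FAISS.from_texts(["Harrison worked at Kensho."], OpenAIEmbeddings())
qa = VectorDBQA.from_chat_model(llm=ChatOpenAI(temperature=0), vectorstore=vectordb)
print(qa.run("Where did Harrison work?"))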
View File

@@ -1,5 +1,9 @@
# flake8: noqa
from typing import Any, List
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import ChatMessage
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -10,3 +14,15 @@ Helpful Answer:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
chat_template_system = """Use the following pieces of context to answer any user questions. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}"""
chat_prompt = PromptTemplate.from_template(chat_template_system)
question_prompt = PromptTemplate.from_template("{question}")
CHAT_PROMPT = ChatPromptTemplate(
messages=[("system", chat_prompt), ("user", question_prompt)],
input_variables=["context", "question"],
)

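Sanity-check sketch of how CHAT_PROMPT above renders (continuing from this module): format_chat routes the context into the system message and the question into the user message.

msgs = CHAT_PROMPT.format_chat(
    context="Paris is the capital of France.",
    question="What is the capital of France?",
)
for m in msgs:
    print(f"{m.role}: {m.text}")
# system: Use the following pieces of context ... Paris is the capital of France.
# user: What is the capital of France?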
View File

langchain/chat/base.py Normal file
View File

@@ -0,0 +1,30 @@
from abc import ABC
from typing import Dict
from pydantic import root_validator
from langchain.chains.base import Chain
from langchain.memory.chat_memory import ChatMemory
class BaseChatChain(Chain, ABC):
human_prefix: str = "user"
ai_prefix: str = "assistant"
@root_validator()
def validate_memory_keys(cls, values: Dict) -> Dict:
"""Validate that the human and ai prefixes line up."""
if "memory" in values:
memory = values["memory"]
if isinstance(memory, ChatMemory):
if memory.human_prefix != values["human_prefix"]:
raise ValueError(
f"Memory human_prefix ({memory.human_prefix}) must "
f"match chain human_prefix ({values['human_prefix']})"
)
if memory.ai_prefix != values["ai_prefix"]:
raise ValueError(
f"Memory ai_prefix ({memory.ai_prefix}) must "
f"match chain ai_prefix ({values['ai_prefix']})"
)
return values

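A minimal sketch of the validator above failing fast. FakeChat and the langchain.chat.conversation import path are assumptions for illustration; the check itself is exactly what the root_validator enforces.

from typing import List, Optional

from langchain.chat.conversation import ConversationChain  # module path assumed
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import ChatMessage

class FakeChat(SimpleChatModel):
    """Stand-in chat model for the example."""
    def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str:
        return "ok"

# Chain defaults to human_prefix="user"; this memory says "Human", so pydantic raises:
# ValueError: Memory human_prefix (Human) must match chain human_prefix (user)
ConversationChain(model=FakeChat(), memory=SimpleChatMemory(human_prefix="Human"))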
View File

@@ -0,0 +1,95 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chat.base import BaseChatChain
from langchain.chat.question_answering import QAChain
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import ChatMessage
from langchain.vectorstores.base import VectorStore
class ChatVectorDBChain(BaseChatChain, BaseModel):
"""Chain for chatting with a vector database."""
vectorstore: VectorStore
qa_chain: QAChain
question_generator: LLMChain
output_key: str = "answer"
return_source_documents: bool = False
"""Return the source documents."""
top_k_docs_for_context: int = 4
@property
def _chain_type(self) -> str:
return "chat-vector-db"
@property
def input_keys(self) -> List[str]:
"""Input keys."""
return ["question", "chat_history"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys = _output_keys + ["source_documents"]
return _output_keys
@classmethod
def from_llm(
cls,
*,
llm: BaseLLM,
model: BaseChatModel,
vectorstore: VectorStore,
starter_messages: Optional[List[ChatMessage]] = None,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
**kwargs: Any,
) -> ChatVectorDBChain:
"""Load chain from LLM."""
qa_chain = QAChain.from_model(model, starter_messages=starter_messages)
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
return cls(
vectorstore=vectorstore,
qa_chain=qa_chain,
question_generator=condense_question_chain,
**kwargs,
)
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"]
chat_history_str = get_buffer_string(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str:
new_question = self.question_generator.run(
question=question, chat_history=chat_history_str
)
else:
new_question = question
docs = self.vectorstore.similarity_search(
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
)
args = {
self.qa_chain.documents_key: docs,
self.qa_chain.question_key: new_question,
}
result = self.qa_chain(args)
answer = result[self.qa_chain.output_key]
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}

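End-to-end sketch of the flow in _call above: the question_generator condenses (chat history + follow-up) into a standalone question, the vectorstore retrieves context, and the chat-model QA chain answers. ChatVectorDBChain is used as defined above (its module path is not shown in this diff); assumes an OPENAI_API_KEY.

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.schema import ChatMessage
from langchain.vectorstores.faiss import FAISS

vectordb = FAISS.from_texts(
    ["The Eiffel Tower is 324 metres tall."], OpenAIEmbeddings()
)
chain = ChatVectorDBChain.from_llm(
    llm=OpenAI(temperature=0), model=ChatOpenAI(temperature=0), vectorstore=vectordb
)
history = [
    ChatMessage(role="user", text="Tell me about the Eiffel Tower."),
    ChatMessage(role="assistant", text="It is a wrought-iron tower in Paris."),
]
print(chain({"question": "How tall is it?", "chat_history": history})["answer"])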
View File

@@ -0,0 +1,59 @@
"""Chain that carries on a conversation and calls an LLM."""
from __future__ import annotations
from typing import Any, Dict, List
from pydantic import BaseModel, Extra, Field
from langchain.chains.conversation.prompt import PROMPT
from langchain.chat.base import BaseChatChain
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import ChatMessage
class ConversationChain(BaseChatChain, BaseModel):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain.chat_models import ChatOpenAI
conversation = ConversationChain(model=ChatOpenAI())
"""
model: BaseChatModel
memory: SimpleChatMemory = Field(default_factory=SimpleChatMemory)
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
input_key: str = "input" #: :meta private:
output_key: str = "response" #: :meta private:
starter_messages: List[ChatMessage] = Field(default_factory=list)
@classmethod
def from_model(cls, model: BaseChatModel, **kwargs: Any) -> ConversationChain:
"""From model. Future proofing."""
return cls(model=model, **kwargs)
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
new_message = ChatMessage(text=inputs[self.input_key], role=self.human_prefix)
messages = self.starter_messages + self.memory.messages + [new_message]
output = self.model.run(messages)
return {self.output_key: output.text}
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]

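Quick usage sketch for the chat ConversationChain above: SimpleChatMemory records each turn via save_context, so the second call sees the first. ConversationChain is as defined above; assumes an OPENAI_API_KEY.

from langchain.chat_models import ChatOpenAI

conversation = ConversationChain(model=ChatOpenAI(temperature=0))
print(conversation.run(input="Hi, my name is Harrison."))
print(conversation.run(input="What is my name?"))  # memory supplies the first turn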
langchain/chat/memory.py Normal file
View File

@@ -0,0 +1,45 @@
from typing import Any, Dict, List, Optional
from langchain.chains.base import Memory
from langchain.memory.chat_memory import ChatMemory
def _get_prompt_input_key(inputs: Dict[str, Any], key: Optional[str]) -> str:
if key is not None:
return key
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(["stop"]))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
return prompt_input_keys[0]
class SimpleChatMemory(Memory, ChatMemory):
input_key: Optional[str] = None
output_key: Optional[str] = None
def clear(self) -> None:
self.messages = []
@property
def memory_variables(self) -> List[str]:
return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return {}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
self.add_user_message(inputs[_get_prompt_input_key(inputs, self.input_key)])
self.add_ai_message(outputs[_get_prompt_input_key(outputs, self.output_key)])
class ChatHistoryMemory(SimpleChatMemory):
chat_history_key: str = "chat_history"
@property
def memory_variables(self) -> List[str]:
return [self.chat_history_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {self.chat_history_key: self.messages}

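A sketch of the two memory classes above: SimpleChatMemory records turns without exposing them to the prompt; ChatHistoryMemory additionally surfaces the raw messages under chat_history.

from langchain.chat.memory import ChatHistoryMemory

memory = ChatHistoryMemory()
memory.save_context({"question": "Hi there"}, {"answer": "Hello!"})
for m in memory.load_memory_variables({})["chat_history"]:
    print(f"{m.role}: {m.text}")
# user: Hi there
# assistant: Hello!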
View File

@@ -0,0 +1,84 @@
"""Question Answering."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, Field
from langchain.chains.conversation.prompt import PROMPT
from langchain.chat.base import BaseChatChain
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import ChatMessage
def _get_default_starter_messages() -> List[ChatMessage]:
prompt = (
"You are chatbot optimized for question answering. "
"Your job is to answer the most recent user question based "
"ONLY on the information they have told you before. "
"Do NOT use other information than what is provided in previous "
"messages by the human."
)
return [ChatMessage(text=prompt, role="system")]
class QAChain(BaseChatChain, BaseModel):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain.chat.question_answering import QAChain
from langchain.chat_models import ChatOpenAI
qa_chain = QAChain.from_model(ChatOpenAI())
"""
model: BaseChatModel
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
question_key: str = "question" #: :meta private:
documents_key: str = "input_documents" #: :meta private:
output_key: str = "response" #: :meta private:
starter_messages: List[ChatMessage] = Field(default_factory=list)
@classmethod
def from_model(
cls,
model: BaseChatModel,
starter_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> QAChain:
"""From model. Future proofing."""
if starter_messages is not None:
_starter_messages = starter_messages
else:
_starter_messages = _get_default_starter_messages()
return cls(model=model, starter_messages=_starter_messages, **kwargs)
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
new_message = ChatMessage(
text=inputs[self.question_key], role=self.human_prefix
)
docs = inputs[self.documents_key]
doc_messages = [
ChatMessage(text=doc.page_content, role=self.human_prefix) for doc in docs
]
messages = self.starter_messages + doc_messages + [new_message]
output = self.model.run(messages)
return {self.output_key: output.text}
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.question_key, self.documents_key]

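Sketch of using QAChain above directly: retrieved documents are injected as user-role messages ahead of the question, per _call. The import path is grounded in this diff; assumes an OPENAI_API_KEY.

from langchain.chat.question_answering import QAChain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document

qa = QAChain.from_model(ChatOpenAI(temperature=0))
docs = [Document(page_content="The tower is 324 metres tall.")]
result = qa({"question": "How tall is the tower?", "input_documents": docs})
print(result["response"])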
View File

@@ -0,0 +1,121 @@
"""Chain for question-answering against a vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chat.base import BaseChatChain
from langchain.chat.question_answering import QAChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import ChatMessage
from langchain.vectorstores.base import VectorStore
class VectorDBQA(BaseChatChain, BaseModel):
"""Chain for question-answering against a vector database.
Example:
.. code-block:: python
from langchain.chat_models import ChatOpenAI
from langchain.faiss import FAISS
vectordb = FAISS(...)
vectordbQA = VectorDBQA.from_model(model=ChatOpenAI(), vectorstore=vectordb)
"""
vectorstore: VectorStore = Field(exclude=True)
"""Vector Database to connect to."""
k: int = 4
"""Number of documents to query for."""
qa_chain: QAChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys = _output_keys + ["source_documents"]
return _output_keys
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
@classmethod
def from_model(
cls,
model: BaseChatModel,
starter_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> VectorDBQA:
"""Initialize from LLM."""
qa_chain = QAChain.from_model(model, starter_messages=starter_messages)
return cls(qa_chain=qa_chain, **kwargs)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run similarity search and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = vectordbqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
question = inputs[self.input_key]
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
question, k=self.k, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
args = {self.qa_chain.documents_key: docs, self.qa_chain.question_key: question}
result = self.qa_chain(args)
answer = result[self.qa_chain.output_key]
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}

View File

@@ -0,0 +1,3 @@
from langchain.chat_models.openai import ChatOpenAI
__all__ = ["ChatOpenAI"]

View File

@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from langchain.schema import ChatGeneration, ChatMessage, ChatResult
class BaseChatModel(ABC):
def generate(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
"""Top Level call"""
# Nothing here now, but future proofing.
return self._generate(messages, stop=stop)
async def agenerate(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
raise NotImplementedError
@abstractmethod
def _generate(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
"""Top Level call"""
def run(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatMessage:
res = self.generate(messages, stop=stop)
return res.generations[0].message
class SimpleChatModel(BaseChatModel):
role: str = "assistant"
def _generate(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
output_str = self._call(messages, stop=stop)
message = ChatMessage(text=output_str, role=self.role)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> str:
"""Simpler interface."""

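A minimal concrete subclass of SimpleChatModel above, to show the contract: implement _call and you get _generate and run for free. EchoChat is illustrative only.

from typing import List, Optional

from langchain.chat_models.base import SimpleChatModel
from langchain.schema import ChatMessage

class EchoChat(SimpleChatModel):
    """Toy model that echoes the last message back."""
    def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str:
        return messages[-1].text

chat = EchoChat()
reply = chat.run([ChatMessage(role="user", text="hello")])
print(reply.role, reply.text)  # assistant hello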
View File

@@ -0,0 +1,148 @@
"""OpenAI chat wrapper."""
import logging
from typing import Any, Callable, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import ChatGeneration, ChatMessage, ChatResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__file__)
class ChatOpenAI(BaseChatModel, BaseModel):
"""Wrapper around OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the
environment variable ``OPENAI_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.chat_models import ChatOpenAI
chat = ChatOpenAI(model_name="gpt-3.5-turbo")
"""
client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo"
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None
max_retries: int = 6
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name not in all_required_field_names:
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
extra[field_name] = values.pop(field_name)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
try:
import openai
openai.api_key = openai_api_key
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return self.model_kwargs
def _create_retry_decorator(self) -> Callable[[Any], Any]:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _generate(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [{"role": m.role, "content": m.text} for m in messages]
response = self.completion_with_retry(messages=message_dicts, **params)
generations = []
for res in response["choices"]:
message = ChatMessage(
text=res["message"]["content"], role=res["message"]["role"]
)
gen = ChatGeneration(message=message)
generations.append(gen)
return ChatResult(generations=generations)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}

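Usage sketch for the wrapper above; assumes OPENAI_API_KEY is set. Extra kwargs like temperature are routed into model_kwargs by build_extra.

from langchain.chat_models import ChatOpenAI
from langchain.schema import ChatMessage

chat = ChatOpenAI(temperature=0)
messages = [
    ChatMessage(role="system", text="You are a helpful assistant."),
    ChatMessage(role="user", text="Say hi in French."),
]
print(chat.run(messages).text)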
View File

@@ -0,0 +1,50 @@
"""LLM Chain specifically for evaluating question answering."""
from __future__ import annotations
from typing import Any, List
from langchain.chains.llm import ChatModelChain
from langchain.chat_models.base import BaseChatModel
from langchain.evaluation.qa.eval_prompt import CHAT_COMP_PROMPT
from langchain.prompts.base import BasePromptTemplate
class QACompChatChain(ChatModelChain):
"""LLM Chain specifically for evaluating question answering."""
@classmethod
def from_llm(
cls,
llm: BaseChatModel,
prompt: BasePromptTemplate = CHAT_COMP_PROMPT,
**kwargs: Any,
) -> QACompChatChain:
expected_input_vars = {"query", "answer", "result"}
if expected_input_vars != set(prompt.input_variables):
raise ValueError(
f"Input variables should be {expected_input_vars}, "
f"but got {prompt.input_variables}"
)
return cls(llm=llm, prompt=prompt, **kwargs)
def evaluate(
self,
examples: List[dict],
predictions_a: List[dict],
predictions_b: List[dict],
question_key: str = "query",
answer_key: str = "answer",
prediction_key: str = "result",
) -> List[dict]:
"""Evaluate question answering examples and predictions."""
inputs = [
{
"query": example[question_key],
"answer": example[answer_key],
"student_a": predictions_a[i][prediction_key],
"student_b": predictions_b[i][prediction_key],
}
for i, example in enumerate(examples)
]
results = [self(inp) for inp in inputs]
return results

View File

@@ -0,0 +1,45 @@
"""LLM Chain specifically for evaluating question answering."""
from __future__ import annotations
from typing import Any, List
from langchain.chains.llm import ChatModelChain
from langchain.chat_models.base import BaseChatModel
from langchain.evaluation.qa.eval_prompt import CHAT_PROMPT
from langchain.prompts.base import BasePromptTemplate
class QAEvalChain(ChatModelChain):
"""LLM Chain specifically for evaluating question answering."""
@classmethod
def from_llm(
cls, llm: BaseChatModel, prompt: BasePromptTemplate = CHAT_PROMPT, **kwargs: Any
) -> QAEvalChain:
expected_input_vars = {"query", "answer", "result"}
if expected_input_vars != set(prompt.input_variables):
raise ValueError(
f"Input variables should be {expected_input_vars}, "
f"but got {prompt.input_variables}"
)
return cls(llm=llm, prompt=prompt, **kwargs)
def evaluate(
self,
examples: List[dict],
predictions: List[dict],
question_key: str = "query",
answer_key: str = "answer",
prediction_key: str = "result",
) -> List[dict]:
"""Evaluate question answering examples and predictions."""
inputs = [
{
"query": example[question_key],
"answer": example[answer_key],
"result": predictions[i][prediction_key],
}
for i, example in enumerate(examples)
]
return self.apply(inputs)

View File

@@ -1,5 +1,6 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score it as either CORRECT or INCORRECT.
@@ -19,3 +20,34 @@ GRADE:"""
PROMPT = PromptTemplate(
input_variables=["query", "result", "answer"], template=template
)
CHAT_INSTRUCTIONS = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score it as either CORRECT or INCORRECT.
The format of your response should be `GRADE: ${grade}` with ${grade} being either CORRECT or INCORRECT and nothing more."""
CHAT_RESPONSE_TEMPLATE = """QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}"""
CHAT_COMPARISON_INSTRUCTIONS = """You are a teacher grading a quiz.
You are given a question, the correct answer, Student A's answer and then Student B's answer.
Please describe how Student A's answer compares to Student B's.
Describe them with a comma-separated list of adjectives. Example adjectives may include: more verbose, less correct, more succinct, etc."""
CHAT_COMPARISON_RESPONSE_TEMPLATE = """QUESTION: {query}
TRUE ANSWER: {answer}
STUDENT A ANSWER: {student_a}
STUDENT B ANSWER: {student_b}"""
CHAT_PROMPT = ChatPromptTemplate.from_strings(
[("system", CHAT_INSTRUCTIONS), ("user", CHAT_RESPONSE_TEMPLATE)]
)
CHAT_COMP_PROMPT = ChatPromptTemplate.from_strings(
[
("system", CHAT_COMPARISON_INSTRUCTIONS),
("user", CHAT_COMPARISON_RESPONSE_TEMPLATE),
]
)

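Sketch tying the prompts above to QAEvalChain (eval_chain module path assumed from the file shown earlier): grading comes back as text rows from apply. Assumes an OPENAI_API_KEY; the model's grade may vary.

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.qa.eval_chain import QAEvalChain

eval_chain = QAEvalChain.from_llm(ChatOpenAI(temperature=0))
examples = [{"query": "What is 2 + 2?", "answer": "4"}]
predictions = [{"result": "four"}]
print(eval_chain.evaluate(examples, predictions))
# e.g. [{'text': 'GRADE: CORRECT'}]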
View File

@@ -0,0 +1,22 @@
from typing import List
from pydantic import BaseModel, Field
from langchain.schema import ChatMessage
class ChatMemory(BaseModel):
human_prefix: str = "user"
ai_prefix: str = "assistant"
messages: List[ChatMessage] = Field(default_factory=list)
def add_user_message(self, message: str) -> None:
gen = ChatMessage(text=message, role=self.human_prefix)
self.messages.append(gen)
def add_ai_message(self, message: str) -> None:
gen = ChatMessage(text=message, role=self.ai_prefix)
self.messages.append(gen)
def clear(self) -> None:
self.messages = []

View File

@@ -0,0 +1,8 @@
from typing import List
from langchain.schema import ChatMessage
def get_buffer_string(messages: List[ChatMessage]) -> str:
"""Get buffer string of messages."""
return "\n".join([f"{gen.role}: {gen.text}" for gen in messages])

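get_buffer_string above is the bridge back from message lists to the flat prompt strings that completion-style prompts expect:

from langchain.memory.utils import get_buffer_string
from langchain.schema import ChatMessage

msgs = [
    ChatMessage(role="user", text="hi"),
    ChatMessage(role="assistant", text="hello"),
]
print(get_buffer_string(msgs))
# user: hi
# assistant: hello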
View File

@@ -11,6 +11,7 @@ import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.formatting import formatter
from langchain.schema import ChatMessage
def jinja2_formatter(template: str, **kwargs: Any) -> str:
@@ -189,6 +190,10 @@ class BasePromptTemplate(BaseModel, ABC):
prompt.format(variable1="foo")
"""
def format_chat(self, **kwargs: Any) -> List[ChatMessage]:
"""Create Chat Messages."""
raise NotImplementedError
@property
@abstractmethod
def _prompt_type(self) -> str:

langchain/prompts/chat.py Normal file
View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from abc import ABC
from typing import Any, Callable, List, Tuple, Union
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ChatMessage
class ChatPromptTemplate(BasePromptTemplate, ABC):
input_variables: List[str]
messages: List[Tuple[str, BasePromptTemplate]]
@classmethod
def from_strings(cls, string_messages: List[Tuple[str, str]]) -> ChatPromptTemplate:
messages = [
(role, PromptTemplate.from_template(template))
for role, template in string_messages
]
input_vars = set()
for _, prompt_template in messages:
    input_vars.update(prompt_template.input_variables)
return cls(input_variables=list(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str:
return str(self.format_chat(**kwargs))
def format_chat(self, **kwargs: Any) -> List[ChatMessage]:
"""Format message templates."""
result = []
for role, prompt in self.messages:
rel_params = {
k: v for k, v in kwargs.items() if k in prompt.input_variables
}
message = prompt.format(**rel_params)
result.append(ChatMessage(text=message, role=role))
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
raise NotImplementedError
@property
def _prompt_type(self) -> str:
raise NotImplementedError

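Round-trip sketch for ChatPromptTemplate above (with the from_strings input-variable fix applied): one template per role, formatted into ChatMessages.

from langchain.prompts.chat import ChatPromptTemplate

template = ChatPromptTemplate.from_strings(
    [("system", "You are a {adjective} assistant."), ("user", "{question}")]
)
for m in template.format_chat(adjective="terse", question="Why is the sky blue?"):
    print(f"{m.role}: {m.text}")
# system: You are a terse assistant.
# user: Why is the sky blue?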
View File

@@ -1,9 +1,8 @@
"""Common schema objects."""
from typing import Any, Dict, List, NamedTuple, Optional
from pydantic import BaseModel
class AgentAction(NamedTuple):
@@ -21,9 +20,7 @@ class AgentFinish(NamedTuple):
log: str
class Generation(BaseModel):
"""Output of a single generation."""
text: str
@@ -35,9 +32,7 @@ class Generation:
# TODO: add log probs
class LLMResult(BaseModel):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
@@ -45,3 +40,33 @@ class LLMResult:
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class ChatMessage(BaseModel):
"""Message object."""
text: str
"""Generated text output."""
role: str
"""Role of the chatter."""
class ChatGeneration(BaseModel):
"""Output of a single generation."""
message: ChatMessage
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
class ChatResult(BaseModel):
"""Class that contains all relevant information for a Chat Result."""
generations: List[ChatGeneration]
"""List of the things generated."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

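The new chat schema mirrors Generation/LLMResult one level up: messages instead of strings. A quick construction sketch:

from langchain.schema import ChatGeneration, ChatMessage, ChatResult

message = ChatMessage(text="Bonjour!", role="assistant")
result = ChatResult(generations=[ChatGeneration(message=message)])
print(result.generations[0].message.text)  # Bonjour!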
View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.100"
version = "0.0.101rc0"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@@ -60,7 +60,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
execution_order=3,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
)
],
@@ -74,7 +74,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
execution_order=4,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
),
],
@@ -86,10 +86,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_end("test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_chain_end(outputs={})
@@ -209,7 +209,7 @@ def test_tracer_llm_run() -> None:
execution_order=1,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
@@ -217,7 +217,7 @@ def test_tracer_llm_run() -> None:
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@@ -237,7 +237,7 @@ def test_tracer_llm_run_errors_no_start() -> None:
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult(generations=[[]]))
@freeze_time("2023-01-01")
@@ -251,7 +251,7 @@ def test_tracer_multiple_llm_runs() -> None:
execution_order=1,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
@@ -261,7 +261,7 @@ def test_tracer_multiple_llm_runs() -> None:
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run] * num_runs
@@ -409,9 +409,9 @@ def test_tracer_nested_runs_on_error() -> None:
for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)

View File

@@ -16,14 +16,14 @@ def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == "Human: bar\nAssistant: foo"
def test_memory_human_prefix() -> None:
"""Test that human_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == "Friend: bar\nAI: foo"
def test_conversation_chain_works() -> None:

View File

@@ -31,7 +31,7 @@ def test_caching() -> None:
[Generation(text="fizz")],
]
expected_output = LLMResult(
generations=expected_generations,
llm_output=None,
)
assert output == expected_output
@@ -69,7 +69,7 @@ def test_custom_caching() -> None:
[Generation(text="fizz")],
]
expected_output = LLMResult(
generations=expected_generations,
llm_output=None,
)
assert output == expected_output