Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 13:54:48 +00:00)
Harrison/serialize llm chain (#671)
This commit is contained in:
parent 499e54edda · commit 0ffeabd14f
13	docs/modules/chains/generic/llm.json	Normal file
@@ -0,0 +1,13 @@
{
    "model_name": "text-davinci-003",
    "temperature": 0.0,
    "max_tokens": 256,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "n": 1,
    "best_of": 1,
    "request_timeout": null,
    "logit_bias": {},
    "_type": "openai"
}
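This file is the standalone serialization of just the OpenAI LLM. As a rough sketch of how such a file is produced and consumed (assuming an OpenAI API key is available in the environment), the LLM-level helpers mirror the chain-level ones introduced in this commit:

from langchain import OpenAI
from langchain.llms.loading import load_llm

llm = OpenAI(temperature=0)
llm.save("llm.json")        # writes the config shown above
llm = load_llm("llm.json")  # re-creates the LLM by dispatching on "_type"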
27	docs/modules/chains/generic/llm_chain.json	Normal file
@@ -0,0 +1,27 @@
{
    "memory": null,
    "verbose": true,
    "prompt": {
        "input_variables": [
            "question"
        ],
        "output_parser": null,
        "template": "Question: {question}\n\nAnswer: Let's think step by step.",
        "template_format": "f-string"
    },
    "llm": {
        "model_name": "text-davinci-003",
        "temperature": 0.0,
        "max_tokens": 256,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "n": 1,
        "best_of": 1,
        "request_timeout": null,
        "logit_bias": {},
        "_type": "openai"
    },
    "output_key": "text",
    "_type": "llm_chain"
}
8	docs/modules/chains/generic/llm_chain_separate.json	Normal file
@@ -0,0 +1,8 @@
{
    "memory": null,
    "verbose": true,
    "prompt_path": "prompt.json",
    "llm_path": "llm.json",
    "output_key": "text",
    "_type": "llm_chain"
}
8	docs/modules/chains/generic/prompt.json	Normal file
@@ -0,0 +1,8 @@
{
    "input_variables": [
        "question"
    ],
    "output_parser": null,
    "template": "Question: {question}\n\nAnswer: Let's think step by step.",
    "template_format": "f-string"
}
376	docs/modules/chains/generic/serialization.ipynb	Normal file
@@ -0,0 +1,376 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cbe47c3a",
   "metadata": {},
   "source": [
    "# Serialization\n",
    "This notebook covers how to serialize chains to and from disk. The serialization format we use is json or yaml. Currently, only some chains support this type of serialization. We will grow the number of supported chains over time.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a8a447",
   "metadata": {},
   "source": [
    "## Saving a chain to disk\n",
    "First, let's go over how to save a chain to disk. This can be done with the `.save` method, and specifying a file path with a json or yaml extension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "26e28451",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import PromptTemplate, OpenAI, LLMChain\n",
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
    "llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0), verbose=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bfa18e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.save(\"llm_chain.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea82665d",
   "metadata": {},
   "source": [
    "Let's now take a look at what's inside this saved file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0fd33328",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"memory\": null,\r\n",
      "    \"verbose\": true,\r\n",
      "    \"prompt\": {\r\n",
      "        \"input_variables\": [\r\n",
      "            \"question\"\r\n",
      "        ],\r\n",
      "        \"output_parser\": null,\r\n",
      "        \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n",
      "        \"template_format\": \"f-string\"\r\n",
      "    },\r\n",
      "    \"llm\": {\r\n",
      "        \"model_name\": \"text-davinci-003\",\r\n",
      "        \"temperature\": 0.0,\r\n",
      "        \"max_tokens\": 256,\r\n",
      "        \"top_p\": 1,\r\n",
      "        \"frequency_penalty\": 0,\r\n",
      "        \"presence_penalty\": 0,\r\n",
      "        \"n\": 1,\r\n",
      "        \"best_of\": 1,\r\n",
      "        \"request_timeout\": null,\r\n",
      "        \"logit_bias\": {},\r\n",
      "        \"_type\": \"openai\"\r\n",
      "    },\r\n",
      "    \"output_key\": \"text\",\r\n",
      "    \"_type\": \"llm_chain\"\r\n",
      "}"
     ]
    }
   ],
   "source": [
    "!cat llm_chain.json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2012c724",
   "metadata": {},
   "source": [
    "## Loading a chain from disk\n",
    "We can load a chain from disk by using the `load_chain` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "342a1974",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import load_chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "394b7da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = load_chain(\"llm_chain.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "20d99787",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n",
      "\n",
      "Answer: Let's think step by step.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "' 2 + 2 = 4'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run(\"whats 2 + 2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14449679",
   "metadata": {},
   "source": [
    "## Saving components separately\n",
    "In the above example, we can see that the prompt and llm configuration information is saved in the same json as the overall chain. Alternatively, we can split them up and save them separately. This is often useful to make the saved components more modular. In order to do this, we just need to specify `llm_path` instead of the `llm` component, and `prompt_path` instead of the `prompt` component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "50ec35ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.prompt.save(\"prompt.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c48b39aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"input_variables\": [\r\n",
      "        \"question\"\r\n",
      "    ],\r\n",
      "    \"output_parser\": null,\r\n",
      "    \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n",
      "    \"template_format\": \"f-string\"\r\n",
      "}"
     ]
    }
   ],
   "source": [
    "!cat prompt.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "13c92944",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.llm.save(\"llm.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1b815f89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"model_name\": \"text-davinci-003\",\r\n",
      "    \"temperature\": 0.0,\r\n",
      "    \"max_tokens\": 256,\r\n",
      "    \"top_p\": 1,\r\n",
      "    \"frequency_penalty\": 0,\r\n",
      "    \"presence_penalty\": 0,\r\n",
      "    \"n\": 1,\r\n",
      "    \"best_of\": 1,\r\n",
      "    \"request_timeout\": null,\r\n",
      "    \"logit_bias\": {},\r\n",
      "    \"_type\": \"openai\"\r\n",
      "}"
     ]
    }
   ],
   "source": [
    "!cat llm.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7e6aa9ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"memory\": None,\n",
    "    \"verbose\": True,\n",
    "    \"prompt_path\": \"prompt.json\",\n",
    "    \"llm_path\": \"llm.json\",\n",
    "    \"output_key\": \"text\",\n",
    "    \"_type\": \"llm_chain\"\n",
    "}\n",
    "import json\n",
    "with open(\"llm_chain_separate.json\", \"w\") as f:\n",
    "    json.dump(config, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8e959ca6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "    \"memory\": null,\r\n",
      "    \"verbose\": true,\r\n",
      "    \"prompt_path\": \"prompt.json\",\r\n",
      "    \"llm_path\": \"llm.json\",\r\n",
      "    \"output_key\": \"text\",\r\n",
      "    \"_type\": \"llm_chain\"\r\n",
      "}"
     ]
    }
   ],
   "source": [
    "!cat llm_chain_separate.json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "662731c0",
   "metadata": {},
   "source": [
    "We can then load it in the same way"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d69ceb93",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = load_chain(\"llm_chain_separate.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a99d61b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n",
      "\n",
      "Answer: Let's think step by step.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "' 2 + 2 = 4'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run(\"whats 2 + 2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822b7c12",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
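The notebook exercises only the json path, although the intro says yaml is supported as well. A minimal sketch of the same round-trip through yaml (assuming an OpenAI API key is set in the environment; both `save` and `load_chain` dispatch on the file suffix):

from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains import load_chain

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0), verbose=True)

llm_chain.save("llm_chain.yaml")      # .yaml suffix -> yaml.dump
chain = load_chain("llm_chain.yaml")  # .yaml suffix -> yaml.safe_load
chain.run("whats 2 + 2")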
langchain/chains/__init__.py
@@ -6,6 +6,7 @@ from langchain.chains.llm_bash.base import LLMBashChain
 from langchain.chains.llm_checker.base import LLMCheckerChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.llm_requests import LLMRequestsChain
+from langchain.chains.loading import load_chain
 from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
 from langchain.chains.pal.base import PALChain
@@ -39,4 +40,5 @@ __all__ = [
     "MapReduceChain",
     "OpenAIModerationChain",
     "SQLDatabaseSequentialChain",
+    "load_chain",
 ]
langchain/chains/base.py
@@ -1,7 +1,10 @@
 """Base interface that all chains should implement."""
+import json
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

+import yaml
 from pydantic import BaseModel, Extra, Field, validator

 import langchain
@@ -44,7 +47,9 @@ class Chain(BaseModel, ABC):
     """Base interface that all chains should implement."""

     memory: Optional[Memory] = None
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+    callback_manager: BaseCallbackManager = Field(
+        default_factory=get_callback_manager, exclude=True
+    )
     verbose: bool = Field(
         default_factory=_get_verbosity
     )  # Whether to print the response text
@@ -54,6 +59,10 @@ class Chain(BaseModel, ABC):

         arbitrary_types_allowed = True

+    @property
+    def _chain_type(self) -> str:
+        raise NotImplementedError("Saving not supported for this chain type.")
+
     @validator("callback_manager", pre=True, always=True)
     def set_callback_manager(
         cls, callback_manager: Optional[BaseCallbackManager]
@@ -177,3 +186,43 @@ class Chain(BaseModel, ABC):
                 f"`run` supported with either positional arguments or keyword arguments"
                 f" but not both. Got args: {args} and kwargs: {kwargs}."
             )
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of chain."""
+        if self.memory is not None:
+            raise ValueError("Saving of memory is not yet supported.")
+        _dict = super().dict()
+        _dict["_type"] = self._chain_type
+        return _dict
+
+    def save(self, file_path: Union[Path, str]) -> None:
+        """Save the chain.
+
+        Args:
+            file_path: Path to file to save the chain to.
+
+        Example:
+        .. code-block:: python
+
+            chain.save(file_path="path/chain.yaml")
+        """
+        # Convert file to Path object.
+        if isinstance(file_path, str):
+            save_path = Path(file_path)
+        else:
+            save_path = file_path
+
+        directory_path = save_path.parent
+        directory_path.mkdir(parents=True, exist_ok=True)
+
+        # Fetch dictionary to save
+        chain_dict = self.dict()
+
+        if save_path.suffix == ".json":
+            with open(file_path, "w") as f:
+                json.dump(chain_dict, f, indent=4)
+        elif save_path.suffix == ".yaml":
+            with open(file_path, "w") as f:
+                yaml.dump(chain_dict, f, default_flow_style=False)
+        else:
+            raise ValueError(f"{save_path} must be json or yaml")
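With this base implementation, any `Chain` subclass opts in to serialization simply by exposing a `_chain_type` string. A hypothetical sketch (this `EchoChain` is illustrative only, not part of the commit, and assumes the era's `Chain` interface of `input_keys`/`output_keys`/`_call`):

from typing import Dict, List

from langchain.chains.base import Chain


class EchoChain(Chain):
    """Hypothetical chain that echoes its input (illustration only)."""

    input_key: str = "text"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return ["echo"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        return {"echo": inputs[self.input_key]}

    @property
    def _chain_type(self) -> str:
        # Without this override, Chain.save() raises NotImplementedError.
        return "echo_chain"


EchoChain().save("echo.json")  # works because _chain_type is defined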
langchain/chains/llm.py
@@ -122,3 +122,7 @@ class LLMChain(Chain, BaseModel):
             return new_result
         else:
             return result
+
+    @property
+    def _chain_type(self) -> str:
+        return "llm_chain"
68	langchain/chains/loading.py	Normal file
@@ -0,0 +1,68 @@
"""Functionality for loading chains."""
import json
from pathlib import Path
from typing import Union

import yaml

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config


def _load_llm_chain(config: dict) -> LLMChain:
    """Load LLM chain from config dict."""
    if "llm" in config:
        llm_config = config.pop("llm")
        llm = load_llm_from_config(llm_config)
    elif "llm_path" in config:
        llm = load_llm(config.pop("llm_path"))
    else:
        raise ValueError("One of `llm` or `llm_path` must be present.")

    if "prompt" in config:
        prompt_config = config.pop("prompt")
        prompt = load_prompt_from_config(prompt_config)
    elif "prompt_path" in config:
        prompt = load_prompt(config.pop("prompt_path"))
    else:
        raise ValueError("One of `prompt` or `prompt_path` must be present.")

    return LLMChain(llm=llm, prompt=prompt, **config)


type_to_loader_dict = {"llm_chain": _load_llm_chain}


def load_chain_from_config(config: dict) -> Chain:
    """Load chain from config dict."""
    if "_type" not in config:
        raise ValueError("Must specify a chain type in config.")
    config_type = config.pop("_type")

    if config_type not in type_to_loader_dict:
        raise ValueError(f"Loading {config_type} chain not supported")

    chain_loader = type_to_loader_dict[config_type]
    return chain_loader(config)


def load_chain(file: Union[str, Path]) -> Chain:
    """Load chain from file."""
    # Convert file to Path object.
    if isinstance(file, str):
        file_path = Path(file)
    else:
        file_path = file
    # Load from either json or yaml.
    if file_path.suffix == ".json":
        with open(file_path) as f:
            config = json.load(f)
    elif file_path.suffix == ".yaml":
        with open(file_path, "r") as f:
            config = yaml.safe_load(f)
    else:
        raise ValueError("File type must be json or yaml")
    # Load the chain from the config now.
    return load_chain_from_config(config)
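The loader registry keeps the format extensible: supporting a new chain type is just a new entry in `type_to_loader_dict` keyed by its `_chain_type` string. A hypothetical sketch, continuing the `EchoChain` example from above:

from langchain.chains.loading import load_chain, type_to_loader_dict


def _load_echo_chain(config: dict) -> EchoChain:
    # load_chain_from_config pops "_type" before dispatching, so the
    # remaining keys map directly onto the chain's pydantic fields.
    return EchoChain(**config)


type_to_loader_dict["echo_chain"] = _load_echo_chain
chain = load_chain("echo.json")  # now dispatches to _load_echo_chain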
langchain/llms/base.py
@@ -79,7 +79,7 @@ class BaseLLM(BaseModel, ABC):
                 raise e
             self.callback_manager.on_llm_end(output, verbose=self.verbose)
             return output
-        params = self._llm_dict()
+        params = self.dict()
         params["stop"] = stop
         llm_string = str(sorted([(k, v) for k, v in params.items()]))
         missing_prompts = []
@@ -148,8 +148,8 @@ class BaseLLM(BaseModel, ABC):
     def _llm_type(self) -> str:
         """Return type of llm."""

-    def _llm_dict(self) -> Dict:
-        """Return a dictionary of the prompt."""
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return a dictionary of the LLM."""
         starter_dict = dict(self._identifying_params)
         starter_dict["_type"] = self._llm_type
         return starter_dict
@@ -175,7 +175,7 @@ class BaseLLM(BaseModel, ABC):
         directory_path.mkdir(parents=True, exist_ok=True)

         # Fetch dictionary to save
-        prompt_dict = self._llm_dict()
+        prompt_dict = self.dict()

         if save_path.suffix == ".json":
             with open(file_path, "w") as f:
langchain/prompts/base.py
@@ -135,10 +135,6 @@ class BasePromptTemplate(BaseModel, ABC):
             prompt.format(variable1="foo")
         """

-    def _prompt_dict(self) -> Dict:
-        """Return a dictionary of the prompt."""
-        return self.dict()
-
     def save(self, file_path: Union[Path, str]) -> None:
         """Save the prompt.

@@ -160,7 +156,7 @@ class BasePromptTemplate(BaseModel, ABC):
         directory_path.mkdir(parents=True, exist_ok=True)

         # Fetch dictionary to save
-        prompt_dict = self._prompt_dict()
+        prompt_dict = self.dict()

         if save_path.suffix == ".json":
             with open(file_path, "w") as f:
langchain/prompts/few_shot.py
@@ -109,11 +109,11 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
         # Format the template with the input variables.
         return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)

-    def _prompt_dict(self) -> Dict:
+    def dict(self, **kwargs: Any) -> Dict:
         """Return a dictionary of the prompt."""
         if self.example_selector:
             raise ValueError("Saving an example selector is not currently supported")

-        prompt_dict = self.dict()
+        prompt_dict = super().dict()
         prompt_dict["_type"] = "few_shot"
         return prompt_dict
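One subtlety in this hunk: once the method is renamed to `dict`, calling `self.dict()` inside it would recurse forever, which is why the body switches to `super().dict()`. A standalone illustration of the pattern (plain pydantic v1, not langchain code):

from pydantic import BaseModel


class Prompt(BaseModel):
    template: str = "hi"

    def dict(self, **kwargs) -> dict:
        d = super().dict(**kwargs)  # self.dict(**kwargs) would never return
        d["_type"] = "prompt"
        return d


print(Prompt().dict())  # {'template': 'hi', '_type': 'prompt'}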
tests/unit_tests/chains/test_llm.py
@@ -1,9 +1,12 @@
 """Test LLM chain."""
+from tempfile import TemporaryDirectory
 from typing import Dict, List, Union
+from unittest.mock import patch

 import pytest

 from langchain.chains.llm import LLMChain
+from langchain.chains.loading import load_chain
 from langchain.prompts.base import BaseOutputParser
 from langchain.prompts.prompt import PromptTemplate
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -24,6 +27,16 @@ def fake_llm_chain() -> LLMChain:
     return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")


+@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
+def test_serialization(fake_llm_chain: LLMChain) -> None:
+    """Test serialization."""
+    with TemporaryDirectory() as temp_dir:
+        file = temp_dir + "/llm.json"
+        fake_llm_chain.save(file)
+        loaded_chain = load_chain(file)
+        assert loaded_chain == fake_llm_chain
+
+
 def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
     """Test error is raised if inputs are missing."""
     with pytest.raises(ValueError):
tests/unit_tests/llms/test_cache.py
@@ -12,7 +12,7 @@ def test_caching() -> None:
     """Test caching behavior."""
     langchain.llm_cache = InMemoryCache()
     llm = FakeLLM()
-    params = llm._llm_dict()
+    params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
@@ -50,7 +50,7 @@ def test_custom_caching() -> None:
     engine = create_engine("sqlite://")
     langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
     llm = FakeLLM()
-    params = llm._llm_dict()
+    params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])