serializing api chains

This commit is contained in:
scadEfUr 2023-01-25 00:00:36 -08:00
parent e3df8ab6dc
commit 5f7e8196c6
2 changed files with 37 additions and 0 deletions

View File

@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
from langchain.chains import load_chain
class APIChain(Chain, BaseModel): class APIChain(Chain, BaseModel):
@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel):
api_docs=api_docs, api_docs=api_docs,
**kwargs, **kwargs,
) )
@property
def _chain_type(self) -> str:
return "api_chain"
@classmethod
def from_config(config: Dict) -> APIChain:
try:
api_request_chain_cfg = config.get("api_request_chain")
api_request_chain = load_chain(api_request_chain_cfg)
api_answer_chain_cfg = config.get("api_answer_chain")
api_answer_chain = load_chain(api_answer_chain_cfg)
request_headers = config.get("requests_wrapper").get("headers")
requests_wrapper = RequestsWrapper(headers=request_headers)
api_docs = config.get("api_docs")
question_key = config.get("question_key")
output_key = config.get("output_key")
except:
raise ValueError("Could not load API answer chain.")
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
question_key=question_key,
output_key=output_key,
)

View File

@ -226,3 +226,7 @@ class Chain(BaseModel, ABC):
yaml.dump(chain_dict, f, default_flow_style=False) yaml.dump(chain_dict, f, default_flow_style=False)
else: else:
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
@classmethod
def from_config(config: Dict) -> "Chain":
raise NotImplementedError("Abstract method.")