mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
serializing api chains
This commit is contained in:
parent
e3df8ab6dc
commit
5f7e8196c6
@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.chains import load_chain
|
||||
|
||||
|
||||
class APIChain(Chain, BaseModel):
|
||||
@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel):
|
||||
api_docs=api_docs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "api_chain"
|
||||
|
||||
@classmethod
|
||||
def from_config(config: Dict) -> APIChain:
|
||||
try:
|
||||
api_request_chain_cfg = config.get("api_request_chain")
|
||||
api_request_chain = load_chain(api_request_chain_cfg)
|
||||
|
||||
api_answer_chain_cfg = config.get("api_answer_chain")
|
||||
api_answer_chain = load_chain(api_answer_chain_cfg)
|
||||
|
||||
request_headers = config.get("requests_wrapper").get("headers")
|
||||
requests_wrapper = RequestsWrapper(headers=request_headers)
|
||||
|
||||
api_docs = config.get("api_docs")
|
||||
question_key = config.get("question_key")
|
||||
output_key = config.get("output_key")
|
||||
|
||||
except:
|
||||
raise ValueError("Could not load API answer chain.")
|
||||
|
||||
return APIChain(
|
||||
api_request_chain=api_request_chain,
|
||||
api_answer_chain=api_answer_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
api_docs=api_docs,
|
||||
question_key=question_key,
|
||||
output_key=output_key,
|
||||
)
|
||||
|
@ -226,3 +226,7 @@ class Chain(BaseModel, ABC):
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@classmethod
|
||||
def from_config(config: Dict) -> "Chain":
|
||||
raise NotImplementedError("Abstract method.")
|
||||
|
Loading…
Reference in New Issue
Block a user