mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
serializing api chains
This commit is contained in:
parent
e3df8ab6dc
commit
5f7e8196c6
@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
from langchain.requests import RequestsWrapper
|
from langchain.requests import RequestsWrapper
|
||||||
|
from langchain.chains import load_chain
|
||||||
|
|
||||||
|
|
||||||
class APIChain(Chain, BaseModel):
|
class APIChain(Chain, BaseModel):
|
||||||
@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel):
|
|||||||
api_docs=api_docs,
|
api_docs=api_docs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _chain_type(self) -> str:
|
||||||
|
return "api_chain"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(config: Dict) -> APIChain:
|
||||||
|
try:
|
||||||
|
api_request_chain_cfg = config.get("api_request_chain")
|
||||||
|
api_request_chain = load_chain(api_request_chain_cfg)
|
||||||
|
|
||||||
|
api_answer_chain_cfg = config.get("api_answer_chain")
|
||||||
|
api_answer_chain = load_chain(api_answer_chain_cfg)
|
||||||
|
|
||||||
|
request_headers = config.get("requests_wrapper").get("headers")
|
||||||
|
requests_wrapper = RequestsWrapper(headers=request_headers)
|
||||||
|
|
||||||
|
api_docs = config.get("api_docs")
|
||||||
|
question_key = config.get("question_key")
|
||||||
|
output_key = config.get("output_key")
|
||||||
|
|
||||||
|
except:
|
||||||
|
raise ValueError("Could not load API answer chain.")
|
||||||
|
|
||||||
|
return APIChain(
|
||||||
|
api_request_chain=api_request_chain,
|
||||||
|
api_answer_chain=api_answer_chain,
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
api_docs=api_docs,
|
||||||
|
question_key=question_key,
|
||||||
|
output_key=output_key,
|
||||||
|
)
|
||||||
|
@ -226,3 +226,7 @@ class Chain(BaseModel, ABC):
|
|||||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{save_path} must be json or yaml")
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(config: Dict) -> "Chain":
|
||||||
|
raise NotImplementedError("Abstract method.")
|
||||||
|
Loading…
Reference in New Issue
Block a user