Harrison/improve usability of api chain (#247)

improve usability of api chain
This commit is contained in:
Harrison Chase 2022-12-02 15:44:10 -08:00 committed by GitHub
parent c897bd6cbd
commit a9ce04201f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 7 deletions

View File

@ -1,4 +1,5 @@
"""Chains are easily reusable components which can be linked together.""" """Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
@ -22,4 +23,5 @@ __all__ = [
"QAWithSourcesChain", "QAWithSourcesChain",
"VectorDBQAWithSourcesChain", "VectorDBQAWithSourcesChain",
"PALChain", "PALChain",
"APIChain",
] ]

View File

@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional
import requests import requests
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain import LLMChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -27,7 +28,7 @@ class APIChain(Chain, BaseModel):
api_request_chain: LLMChain api_request_chain: LLMChain
api_answer_chain: LLMChain api_answer_chain: LLMChain
requests_chain: RequestsWrapper requests_wrapper: RequestsWrapper
api_docs: str api_docs: str
question_key: str = "question" #: :meta private: question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
@ -75,7 +76,11 @@ class APIChain(Chain, BaseModel):
api_url = self.api_request_chain.predict( api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs question=question, api_docs=self.api_docs
) )
api_response = self.requests_chain.run(api_url) if self.verbose:
print_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url)
if self.verbose:
print_text(api_url, color="yellow", end="\n")
answer = self.api_answer_chain.predict( answer = self.api_answer_chain.predict(
question=question, question=question,
api_docs=self.api_docs, api_docs=self.api_docs,
@ -93,8 +98,8 @@ class APIChain(Chain, BaseModel):
requests_wrapper = RequestsWrapper(headers=headers) requests_wrapper = RequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT) get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
return cls( return cls(
request_chain=get_request_chain, api_request_chain=get_request_chain,
answer_chain=get_answer_chain, api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
api_docs=api_docs, api_docs=api_docs,
**kwargs, **kwargs,

View File

@ -68,11 +68,11 @@ def fake_llm_api_chain(test_api_data: dict) -> APIChain:
fake_llm = FakeLLM(queries=queries) fake_llm = FakeLLM(queries=queries)
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT) api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT) api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
requests_chain = FakeRequestsChain(output=TEST_API_RESPONSE) requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
return APIChain( return APIChain(
api_request_chain=api_request_chain, api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain, api_answer_chain=api_answer_chain,
requests_chain=requests_chain, requests_wrapper=requests_wrapper,
api_docs=TEST_API_DOCS, api_docs=TEST_API_DOCS,
) )