diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index bc8e42d4f4b..d63839bc70d 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -1,4 +1,5 @@ """Chains are easily reusable components which can be linked together.""" +from langchain.chains.api.base import APIChain from langchain.chains.conversation.base import ConversationChain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.base import LLMMathChain @@ -22,4 +23,5 @@ __all__ = [ "QAWithSourcesChain", "VectorDBQAWithSourcesChain", "PALChain", + "APIChain", ] diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index f85c9b0aa30..b73c96bb6cf 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional import requests from pydantic import BaseModel, root_validator -from langchain import LLMChain from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.input import print_text from langchain.llms.base import LLM @@ -27,7 +28,7 @@ class APIChain(Chain, BaseModel): api_request_chain: LLMChain api_answer_chain: LLMChain - requests_chain: RequestsWrapper + requests_wrapper: RequestsWrapper api_docs: str question_key: str = "question" #: :meta private: output_key: str = "output" #: :meta private: @@ -75,7 +76,11 @@ class APIChain(Chain, BaseModel): api_url = self.api_request_chain.predict( question=question, api_docs=self.api_docs ) - api_response = self.requests_chain.run(api_url) + if self.verbose: + print_text(api_url, color="green", end="\n") + api_response = self.requests_wrapper.run(api_url) + if self.verbose: + print_text(api_url, color="yellow", end="\n") answer = self.api_answer_chain.predict( question=question, api_docs=self.api_docs, @@ -93,8 +98,8 @@ class APIChain(Chain, BaseModel): requests_wrapper = RequestsWrapper(headers=headers) get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT) return cls( - request_chain=get_request_chain, - answer_chain=get_answer_chain, + api_request_chain=get_request_chain, + api_answer_chain=get_answer_chain, requests_wrapper=requests_wrapper, api_docs=api_docs, **kwargs, diff --git a/tests/unit_tests/chains/test_api.py b/tests/unit_tests/chains/test_api.py index e64b9e1086c..9915ef22a1f 100644 --- a/tests/unit_tests/chains/test_api.py +++ b/tests/unit_tests/chains/test_api.py @@ -68,11 +68,11 @@ def fake_llm_api_chain(test_api_data: dict) -> APIChain: fake_llm = FakeLLM(queries=queries) api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT) api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT) - requests_chain = FakeRequestsChain(output=TEST_API_RESPONSE) + requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE) return APIChain( api_request_chain=api_request_chain, api_answer_chain=api_answer_chain, - requests_chain=requests_chain, + requests_wrapper=requests_wrapper, api_docs=TEST_API_DOCS, )