mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
Harrison/improve usability of api chain (#247)
improve usability of api chain
This commit is contained in:
parent
c897bd6cbd
commit
a9ce04201f
@ -1,4 +1,5 @@
|
||||
"""Chains are easily reusable components which can be linked together."""
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
@ -22,4 +23,5 @@ __all__ = [
|
||||
"QAWithSourcesChain",
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"PALChain",
|
||||
"APIChain",
|
||||
]
|
||||
|
@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
@ -27,7 +28,7 @@ class APIChain(Chain, BaseModel):
|
||||
|
||||
api_request_chain: LLMChain
|
||||
api_answer_chain: LLMChain
|
||||
requests_chain: RequestsWrapper
|
||||
requests_wrapper: RequestsWrapper
|
||||
api_docs: str
|
||||
question_key: str = "question" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
@ -75,7 +76,11 @@ class APIChain(Chain, BaseModel):
|
||||
api_url = self.api_request_chain.predict(
|
||||
question=question, api_docs=self.api_docs
|
||||
)
|
||||
api_response = self.requests_chain.run(api_url)
|
||||
if self.verbose:
|
||||
print_text(api_url, color="green", end="\n")
|
||||
api_response = self.requests_wrapper.run(api_url)
|
||||
if self.verbose:
|
||||
print_text(api_url, color="yellow", end="\n")
|
||||
answer = self.api_answer_chain.predict(
|
||||
question=question,
|
||||
api_docs=self.api_docs,
|
||||
@ -93,8 +98,8 @@ class APIChain(Chain, BaseModel):
|
||||
requests_wrapper = RequestsWrapper(headers=headers)
|
||||
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
|
||||
return cls(
|
||||
request_chain=get_request_chain,
|
||||
answer_chain=get_answer_chain,
|
||||
api_request_chain=get_request_chain,
|
||||
api_answer_chain=get_answer_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
api_docs=api_docs,
|
||||
**kwargs,
|
||||
|
@ -68,11 +68,11 @@ def fake_llm_api_chain(test_api_data: dict) -> APIChain:
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
|
||||
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
|
||||
requests_chain = FakeRequestsChain(output=TEST_API_RESPONSE)
|
||||
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
|
||||
return APIChain(
|
||||
api_request_chain=api_request_chain,
|
||||
api_answer_chain=api_answer_chain,
|
||||
requests_chain=requests_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
api_docs=TEST_API_DOCS,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user