Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-13 08:27:03 +00:00
langchain: Add ruff rule RET (#31875)

All changes are auto-fixes. See https://docs.astral.sh/ruff/rules/#flake8-return-ret

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
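For context, the flake8-return (RET) rules that drive the auto-fixes below are mainly RET505 ("unnecessary `elif`/`else` after `return`/`raise`") and RET504 ("unnecessary assignment immediately before `return`"); they are enabled by adding "RET" to ruff's lint rule selection. A minimal, hypothetical sketch of the rewrite shape — illustrative only, not code from this commit:

# Illustrative only. `before` shows the patterns ruff flags; `after` shows the
# equivalent code the RET auto-fixes produce.

def before(name: str) -> str:
    if name == "a":
        return "first"
    elif name == "b":  # RET505: `elif` after `return` can be a plain `if`
        return "second"
    else:  # RET505: `else` after `return` is unnecessary
        result = name.upper()  # RET504: assignment used only by the `return`
        return result


def after(name: str) -> str:
    if name == "a":
        return "first"
    if name == "b":
        return "second"
    return name.upper()

The diff below applies exactly these dedents and direct returns across the langchain package.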
This commit is contained in:
parent fceebbb387
commit 56bbfd9723
@@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any:
         _warn_on_import(name, replacement="langchain.agents.MRKLChain")

         return MRKLChain
-    elif name == "ReActChain":
+    if name == "ReActChain":
         from langchain.agents import ReActChain

         _warn_on_import(name, replacement="langchain.agents.ReActChain")

         return ReActChain
-    elif name == "SelfAskWithSearchChain":
+    if name == "SelfAskWithSearchChain":
         from langchain.agents import SelfAskWithSearchChain

         _warn_on_import(name, replacement="langchain.agents.SelfAskWithSearchChain")

         return SelfAskWithSearchChain
-    elif name == "ConversationChain":
+    if name == "ConversationChain":
         from langchain.chains import ConversationChain

         _warn_on_import(name, replacement="langchain.chains.ConversationChain")

         return ConversationChain
-    elif name == "LLMBashChain":
+    if name == "LLMBashChain":
         msg = (
             "This module has been moved to langchain-experimental. "
             "For more details: "
@@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any:
         )
         raise ImportError(msg)

-    elif name == "LLMChain":
+    if name == "LLMChain":
         from langchain.chains import LLMChain

         _warn_on_import(name, replacement="langchain.chains.LLMChain")

         return LLMChain
-    elif name == "LLMCheckerChain":
+    if name == "LLMCheckerChain":
         from langchain.chains import LLMCheckerChain

         _warn_on_import(name, replacement="langchain.chains.LLMCheckerChain")

         return LLMCheckerChain
-    elif name == "LLMMathChain":
+    if name == "LLMMathChain":
         from langchain.chains import LLMMathChain

         _warn_on_import(name, replacement="langchain.chains.LLMMathChain")

         return LLMMathChain
-    elif name == "QAWithSourcesChain":
+    if name == "QAWithSourcesChain":
         from langchain.chains import QAWithSourcesChain

         _warn_on_import(name, replacement="langchain.chains.QAWithSourcesChain")

         return QAWithSourcesChain
-    elif name == "VectorDBQA":
+    if name == "VectorDBQA":
         from langchain.chains import VectorDBQA

         _warn_on_import(name, replacement="langchain.chains.VectorDBQA")

         return VectorDBQA
-    elif name == "VectorDBQAWithSourcesChain":
+    if name == "VectorDBQAWithSourcesChain":
         from langchain.chains import VectorDBQAWithSourcesChain

         _warn_on_import(name, replacement="langchain.chains.VectorDBQAWithSourcesChain")

         return VectorDBQAWithSourcesChain
-    elif name == "InMemoryDocstore":
+    if name == "InMemoryDocstore":
         from langchain_community.docstore import InMemoryDocstore

         _warn_on_import(name, replacement="langchain.docstore.InMemoryDocstore")

         return InMemoryDocstore
-    elif name == "Wikipedia":
+    if name == "Wikipedia":
         from langchain_community.docstore import Wikipedia

         _warn_on_import(name, replacement="langchain.docstore.Wikipedia")

         return Wikipedia
-    elif name == "Anthropic":
+    if name == "Anthropic":
         from langchain_community.llms import Anthropic

         _warn_on_import(name, replacement="langchain_community.llms.Anthropic")

         return Anthropic
-    elif name == "Banana":
+    if name == "Banana":
         from langchain_community.llms import Banana

         _warn_on_import(name, replacement="langchain_community.llms.Banana")

         return Banana
-    elif name == "CerebriumAI":
+    if name == "CerebriumAI":
         from langchain_community.llms import CerebriumAI

         _warn_on_import(name, replacement="langchain_community.llms.CerebriumAI")

         return CerebriumAI
-    elif name == "Cohere":
+    if name == "Cohere":
         from langchain_community.llms import Cohere

         _warn_on_import(name, replacement="langchain_community.llms.Cohere")

         return Cohere
-    elif name == "ForefrontAI":
+    if name == "ForefrontAI":
         from langchain_community.llms import ForefrontAI

         _warn_on_import(name, replacement="langchain_community.llms.ForefrontAI")

         return ForefrontAI
-    elif name == "GooseAI":
+    if name == "GooseAI":
         from langchain_community.llms import GooseAI

         _warn_on_import(name, replacement="langchain_community.llms.GooseAI")

         return GooseAI
-    elif name == "HuggingFaceHub":
+    if name == "HuggingFaceHub":
         from langchain_community.llms import HuggingFaceHub

         _warn_on_import(name, replacement="langchain_community.llms.HuggingFaceHub")

         return HuggingFaceHub
-    elif name == "HuggingFaceTextGenInference":
+    if name == "HuggingFaceTextGenInference":
         from langchain_community.llms import HuggingFaceTextGenInference

         _warn_on_import(
@@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any:
         )

         return HuggingFaceTextGenInference
-    elif name == "LlamaCpp":
+    if name == "LlamaCpp":
         from langchain_community.llms import LlamaCpp

         _warn_on_import(name, replacement="langchain_community.llms.LlamaCpp")

         return LlamaCpp
-    elif name == "Modal":
+    if name == "Modal":
         from langchain_community.llms import Modal

         _warn_on_import(name, replacement="langchain_community.llms.Modal")

         return Modal
-    elif name == "OpenAI":
+    if name == "OpenAI":
         from langchain_community.llms import OpenAI

         _warn_on_import(name, replacement="langchain_community.llms.OpenAI")

         return OpenAI
-    elif name == "Petals":
+    if name == "Petals":
         from langchain_community.llms import Petals

         _warn_on_import(name, replacement="langchain_community.llms.Petals")

         return Petals
-    elif name == "PipelineAI":
+    if name == "PipelineAI":
         from langchain_community.llms import PipelineAI

         _warn_on_import(name, replacement="langchain_community.llms.PipelineAI")

         return PipelineAI
-    elif name == "SagemakerEndpoint":
+    if name == "SagemakerEndpoint":
         from langchain_community.llms import SagemakerEndpoint

         _warn_on_import(name, replacement="langchain_community.llms.SagemakerEndpoint")

         return SagemakerEndpoint
-    elif name == "StochasticAI":
+    if name == "StochasticAI":
         from langchain_community.llms import StochasticAI

         _warn_on_import(name, replacement="langchain_community.llms.StochasticAI")

         return StochasticAI
-    elif name == "Writer":
+    if name == "Writer":
         from langchain_community.llms import Writer

         _warn_on_import(name, replacement="langchain_community.llms.Writer")

         return Writer
-    elif name == "HuggingFacePipeline":
+    if name == "HuggingFacePipeline":
         from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

         _warn_on_import(
@@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any:
         )

         return HuggingFacePipeline
-    elif name == "FewShotPromptTemplate":
+    if name == "FewShotPromptTemplate":
         from langchain_core.prompts import FewShotPromptTemplate

         _warn_on_import(
@@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any:
         )

         return FewShotPromptTemplate
-    elif name == "Prompt":
+    if name == "Prompt":
         from langchain_core.prompts import PromptTemplate

         _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
@@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any:
         # it's renamed as prompt template anyways
         # this is just for backwards compat
         return PromptTemplate
-    elif name == "PromptTemplate":
+    if name == "PromptTemplate":
         from langchain_core.prompts import PromptTemplate

         _warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")

         return PromptTemplate
-    elif name == "BasePromptTemplate":
+    if name == "BasePromptTemplate":
         from langchain_core.prompts import BasePromptTemplate

         _warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate")

         return BasePromptTemplate
-    elif name == "ArxivAPIWrapper":
+    if name == "ArxivAPIWrapper":
         from langchain_community.utilities import ArxivAPIWrapper

         _warn_on_import(
@@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any:
         )

         return ArxivAPIWrapper
-    elif name == "GoldenQueryAPIWrapper":
+    if name == "GoldenQueryAPIWrapper":
         from langchain_community.utilities import GoldenQueryAPIWrapper

         _warn_on_import(
@@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any:
         )

         return GoldenQueryAPIWrapper
-    elif name == "GoogleSearchAPIWrapper":
+    if name == "GoogleSearchAPIWrapper":
         from langchain_community.utilities import GoogleSearchAPIWrapper

         _warn_on_import(
@@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any:
         )

         return GoogleSearchAPIWrapper
-    elif name == "GoogleSerperAPIWrapper":
+    if name == "GoogleSerperAPIWrapper":
         from langchain_community.utilities import GoogleSerperAPIWrapper

         _warn_on_import(
@@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any:
         )

         return GoogleSerperAPIWrapper
-    elif name == "PowerBIDataset":
+    if name == "PowerBIDataset":
         from langchain_community.utilities import PowerBIDataset

         _warn_on_import(
@@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any:
         )

         return PowerBIDataset
-    elif name == "SearxSearchWrapper":
+    if name == "SearxSearchWrapper":
         from langchain_community.utilities import SearxSearchWrapper

         _warn_on_import(
@@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any:
         )

         return SearxSearchWrapper
-    elif name == "WikipediaAPIWrapper":
+    if name == "WikipediaAPIWrapper":
         from langchain_community.utilities import WikipediaAPIWrapper

         _warn_on_import(
@@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any:
         )

         return WikipediaAPIWrapper
-    elif name == "WolframAlphaAPIWrapper":
+    if name == "WolframAlphaAPIWrapper":
         from langchain_community.utilities import WolframAlphaAPIWrapper

         _warn_on_import(
@@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any:
         )

         return WolframAlphaAPIWrapper
-    elif name == "SQLDatabase":
+    if name == "SQLDatabase":
         from langchain_community.utilities import SQLDatabase

         _warn_on_import(name, replacement="langchain_community.utilities.SQLDatabase")

         return SQLDatabase
-    elif name == "FAISS":
+    if name == "FAISS":
         from langchain_community.vectorstores import FAISS

         _warn_on_import(name, replacement="langchain_community.vectorstores.FAISS")

         return FAISS
-    elif name == "ElasticVectorSearch":
+    if name == "ElasticVectorSearch":
         from langchain_community.vectorstores import ElasticVectorSearch

         _warn_on_import(
@@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any:

         return ElasticVectorSearch
     # For backwards compatibility
-    elif name == "SerpAPIChain" or name == "SerpAPIWrapper":
+    if name == "SerpAPIChain" or name == "SerpAPIWrapper":
         from langchain_community.utilities import SerpAPIWrapper

         _warn_on_import(
@@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any:
         )

         return SerpAPIWrapper
-    elif name == "verbose":
+    if name == "verbose":
         from langchain.globals import _verbose

         _warn_on_import(
@@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any:
         )

         return _verbose
-    elif name == "debug":
+    if name == "debug":
         from langchain.globals import _debug

         _warn_on_import(
@@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any:
         )

         return _debug
-    elif name == "llm_cache":
+    if name == "llm_cache":
         from langchain.globals import _llm_cache

         _warn_on_import(
@@ -386,9 +386,8 @@ def __getattr__(name: str) -> Any:
         )

         return _llm_cache
-    else:
-        msg = f"Could not find: {name}"
-        raise AttributeError(msg)
+    msg = f"Could not find: {name}"
+    raise AttributeError(msg)


 __all__ = [
@@ -137,9 +137,8 @@ class BaseSingleActionAgent(BaseModel):
             return AgentFinish(
                 {"output": "Agent stopped due to iteration limit or time limit."}, ""
             )
-        else:
-            msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
-            raise ValueError(msg)
+        msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
+        raise ValueError(msg)

     @classmethod
     def from_llm_and_tools(
@@ -308,9 +307,8 @@ class BaseMultiActionAgent(BaseModel):
         if early_stopping_method == "force":
             # `force` just returns a constant string
             return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
-        else:
-            msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
-            raise ValueError(msg)
+        msg = f"Got unsupported early_stopping_method `{early_stopping_method}`"
+        raise ValueError(msg)

     @property
     def _agent_type(self) -> str:
@@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent):
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
         full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
-        agent_output = await self.output_parser.aparse(full_output)
-        return agent_output
+        return await self.output_parser.aparse(full_output)

     def get_full_inputs(
         self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
@@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent):
         """
         thoughts = self._construct_scratchpad(intermediate_steps)
         new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
-        full_inputs = {**kwargs, **new_inputs}
-        return full_inputs
+        return {**kwargs, **new_inputs}

     @property
     def input_keys(self) -> list[str]:
@@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent):
             return AgentFinish(
                 {"output": "Agent stopped due to iteration limit or time limit."}, ""
             )
-        elif early_stopping_method == "generate":
+        if early_stopping_method == "generate":
             # Generate does one final forward pass
             thoughts = ""
             for action, observation in intermediate_steps:
@@ -990,16 +986,14 @@ class Agent(BaseSingleActionAgent):
             if isinstance(parsed_output, AgentFinish):
                 # If we can extract, we send the correct stuff
                 return parsed_output
-            else:
-                # If we can extract, but the tool is not the final tool,
-                # we just return the full output
-                return AgentFinish({"output": full_output}, full_output)
-        else:
-            msg = (
-                "early_stopping_method should be one of `force` or `generate`, "
-                f"got {early_stopping_method}"
-            )
-            raise ValueError(msg)
+            # If we can extract, but the tool is not the final tool,
+            # we just return the full output
+            return AgentFinish({"output": full_output}, full_output)
+        msg = (
+            "early_stopping_method should be one of `force` or `generate`, "
+            f"got {early_stopping_method}"
+        )
+        raise ValueError(msg)

     def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
@@ -1179,8 +1173,7 @@ class AgentExecutor(Chain):
         """
         if isinstance(self.agent, Runnable):
             return cast(RunnableAgentType, self.agent)
-        else:
-            return self.agent
+        return self.agent

     def save(self, file_path: Union[Path, str]) -> None:
         """Raise error - saving not supported for Agent Executors.
@@ -1249,8 +1242,7 @@ class AgentExecutor(Chain):
         """
         if self.return_intermediate_steps:
             return self._action_agent.return_values + ["intermediate_steps"]
-        else:
-            return self._action_agent.return_values
+        return self._action_agent.return_values

     def lookup_tool(self, name: str) -> BaseTool:
         """Lookup tool by name.
@@ -1304,10 +1296,7 @@ class AgentExecutor(Chain):
                 msg = "Expected a single AgentFinish output, but got multiple values."
                 raise ValueError(msg)
             return values[-1]
-        else:
-            return [
-                (a.action, a.observation) for a in values if isinstance(a, AgentStep)
-            ]
+        return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)]

     def _take_next_step(
         self,
@@ -1727,10 +1716,9 @@ class AgentExecutor(Chain):
             and self.trim_intermediate_steps > 0
         ):
             return intermediate_steps[-self.trim_intermediate_steps :]
-        elif callable(self.trim_intermediate_steps):
+        if callable(self.trim_intermediate_steps):
             return self.trim_intermediate_steps(intermediate_steps)
-        else:
-            return intermediate_steps
+        return intermediate_steps

     @override
     def stream(
@@ -61,8 +61,7 @@ class ChatAgent(Agent):
                 f"(but I haven't seen any of it! I only see what "
                 f"you return as final answer):\n{agent_scratchpad}"
             )
-        else:
-            return agent_scratchpad
+        return agent_scratchpad

     @classmethod
     def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
@@ -39,15 +39,13 @@ class ConvoOutputParser(AgentOutputParser):
                 # If the action indicates a final answer, return an AgentFinish
                 if action == "Final Answer":
                     return AgentFinish({"output": action_input}, text)
-                else:
-                    # Otherwise, return an AgentAction with the specified action and
-                    # input
-                    return AgentAction(action, action_input, text)
-            else:
-                # If the necessary keys aren't present in the response, raise an
-                # exception
-                msg = f"Missing 'action' or 'action_input' in LLM output: {text}"
-                raise OutputParserException(msg)
+                # Otherwise, return an AgentAction with the specified action and
+                # input
+                return AgentAction(action, action_input, text)
+            # If the necessary keys aren't present in the response, raise an
+            # exception
+            msg = f"Missing 'action' or 'action_input' in LLM output: {text}"
+            raise OutputParserException(msg)
         except Exception as e:
             # If any other exception is raised during parsing, also raise an
             # OutputParserException
@@ -23,8 +23,7 @@ def _convert_agent_action_to_messages(
         return list(agent_action.message_log) + [
             _create_function_message(agent_action, observation)
         ]
-    else:
-        return [AIMessage(content=agent_action.log)]
+    return [AIMessage(content=agent_action.log)]


 def _create_function_message(
@@ -182,7 +182,7 @@ def create_json_chat_agent(
     else:
         llm_to_use = llm

-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_log_to_messages(
                 x["intermediate_steps"], template_tool_response=template_tool_response
@@ -192,4 +192,3 @@ def create_json_chat_agent(
         | llm_to_use
         | JSONAgentOutputParser()
     )
-    return agent
@@ -55,9 +55,8 @@ class MRKLOutputParser(AgentOutputParser):
                 return AgentFinish(
                     {"output": text[start_index:end_index].strip()}, text[:end_index]
                 )
-            else:
-                msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
-                raise OutputParserException(msg)
+            msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
+            raise OutputParserException(msg)

         if action_match:
             action = action_match.group(1).strip()
@@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser):

             return AgentAction(action, tool_input, text)

-        elif includes_answer:
+        if includes_answer:
             return AgentFinish(
                 {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
             )
@@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser):
                 llm_output=text,
                 send_to_llm=True,
             )
-        elif not re.search(
+        if not re.search(
             r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
         ):
             msg = f"Could not parse LLM output: `{text}`"
@@ -92,9 +91,8 @@ class MRKLOutputParser(AgentOutputParser):
                 llm_output=text,
                 send_to_llm=True,
             )
-        else:
-            msg = f"Could not parse LLM output: `{text}`"
-            raise OutputParserException(msg)
+        msg = f"Could not parse LLM output: `{text}`"
+        raise OutputParserException(msg)

     @property
     def _type(self) -> str:
@@ -128,8 +128,7 @@ def _get_assistants_tool(
     """
    if _is_assistants_builtin_tool(tool):
        return tool  # type: ignore[return-value]
-    else:
-        return convert_to_openai_tool(tool)
+    return convert_to_openai_tool(tool)


 OutputType = Union[
@@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
             for action, output in intermediate_steps
             if action.tool_call_id in required_tool_call_ids
         ]
-        submit_tool_outputs = {
+        return {
             "tool_outputs": tool_outputs,
             "run_id": last_action.run_id,
             "thread_id": last_action.thread_id,
         }
-        return submit_tool_outputs

     def _create_run(self, input_dict: dict) -> Any:
         params = {
@@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                 "run_metadata",
             )
         }
-        run = self.client.beta.threads.create_and_run(
+        return self.client.beta.threads.create_and_run(
             assistant_id=self.assistant_id,
             thread=thread,
             **params,
         )
-        return run

     def _get_response(self, run: Any) -> Any:
         # TODO: Pagination
@@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                 run_id=run.id,
                 thread_id=run.thread_id,
             )
-        elif run.status == "requires_action":
+        if run.status == "requires_action":
             if not self.as_agent:
                 return run.required_action.submit_tool_outputs.tool_calls
             actions = []
@@ -639,10 +636,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                     )
                 )
             return actions
-        else:
-            run_info = json.dumps(run.dict(), indent=2)
-            msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
-            raise ValueError(msg)
+        run_info = json.dumps(run.dict(), indent=2)
+        msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
+        raise ValueError(msg)

     def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
         in_progress = True
@@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
             for action, output in intermediate_steps
             if action.tool_call_id in required_tool_call_ids
         ]
-        submit_tool_outputs = {
+        return {
             "tool_outputs": tool_outputs,
             "run_id": last_action.run_id,
             "thread_id": last_action.thread_id,
         }
-        return submit_tool_outputs

     async def _acreate_run(self, input_dict: dict) -> Any:
         params = {
@@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                 "run_metadata",
             )
         }
-        run = await self.async_client.beta.threads.create_and_run(
+        return await self.async_client.beta.threads.create_and_run(
             assistant_id=self.assistant_id,
             thread=thread,
             **params,
         )
-        return run

     async def _aget_response(self, run: Any) -> Any:
         # TODO: Pagination
@@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                 run_id=run.id,
                 thread_id=run.thread_id,
             )
-        elif run.status == "requires_action":
+        if run.status == "requires_action":
             if not self.as_agent:
                 return run.required_action.submit_tool_outputs.tool_calls
             actions = []
@@ -793,10 +787,9 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
                     )
                 )
             return actions
-        else:
-            run_info = json.dumps(run.dict(), indent=2)
-            msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
-            raise ValueError(msg)
+        run_info = json.dumps(run.dict(), indent=2)
+        msg = f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
+        raise ValueError(msg)

     async def _await_for_run(self, run_id: str, thread_id: str) -> Any:
         in_progress = True
@@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
             messages,
             callbacks=callbacks,
         )
-        agent_decision = self.output_parser._parse_ai_message(predicted_message)
-        return agent_decision
+        return self.output_parser._parse_ai_message(predicted_message)

     async def aplan(
         self,
@@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         predicted_message = await self.llm.apredict_messages(
             messages, functions=self.functions, callbacks=callbacks
         )
-        agent_decision = self.output_parser._parse_ai_message(predicted_message)
-        return agent_decision
+        return self.output_parser._parse_ai_message(predicted_message)

     def return_stopped_response(
         self,
@@ -192,22 +190,20 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
             return AgentFinish(
                 {"output": "Agent stopped due to iteration limit or time limit."}, ""
             )
-        elif early_stopping_method == "generate":
+        if early_stopping_method == "generate":
             # Generate does one final forward pass
             agent_decision = self.plan(
                 intermediate_steps, with_functions=False, **kwargs
             )
             if isinstance(agent_decision, AgentFinish):
                 return agent_decision
-            else:
-                msg = f"got AgentAction with no functions provided: {agent_decision}"
-                raise ValueError(msg)
-        else:
-            msg = (
-                "early_stopping_method should be one of `force` or `generate`, "
-                f"got {early_stopping_method}"
-            )
-            raise ValueError(msg)
+            msg = f"got AgentAction with no functions provided: {agent_decision}"
+            raise ValueError(msg)
+        msg = (
+            "early_stopping_method should be one of `force` or `generate`, "
+            f"got {early_stopping_method}"
+        )
+        raise ValueError(msg)

     @classmethod
     def create_prompt(
@@ -358,7 +354,7 @@ def create_openai_functions_agent(
         )
         raise ValueError(msg)
     llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])
-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_to_openai_function_messages(
                 x["intermediate_steps"]
@@ -368,4 +364,3 @@ def create_openai_functions_agent(
         | llm_with_tools
         | OpenAIFunctionsAgentOutputParser()
     )
-    return agent
@@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         predicted_message = self.llm.predict_messages(
             messages, functions=self.functions, callbacks=callbacks
         )
-        agent_decision = _parse_ai_message(predicted_message)
-        return agent_decision
+        return _parse_ai_message(predicted_message)

     async def aplan(
         self,
@@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         predicted_message = await self.llm.apredict_messages(
             messages, functions=self.functions, callbacks=callbacks
         )
-        agent_decision = _parse_ai_message(predicted_message)
-        return agent_decision
+        return _parse_ai_message(predicted_message)

     @classmethod
     def create_prompt(
@@ -96,7 +96,7 @@ def create_openai_tools_agent(
         tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
     )

-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_to_openai_tool_messages(
                 x["intermediate_steps"]
@@ -106,4 +106,3 @@ def create_openai_tools_agent(
         | llm_with_tools
         | OpenAIToolsAgentOutputParser()
     )
-    return agent
@@ -49,11 +49,10 @@ class JSONAgentOutputParser(AgentOutputParser):
                 response = response[0]
             if response["action"] == "Final Answer":
                 return AgentFinish({"output": response["action_input"]}, text)
-            else:
-                action_input = response.get("action_input", {})
-                if action_input is None:
-                    action_input = {}
-                return AgentAction(response["action"], action_input, text)
+            action_input = response.get("action_input", {})
+            if action_input is None:
+                action_input = {}
+            return AgentAction(response["action"], action_input, text)
         except Exception as e:
             msg = f"Could not parse LLM output: {text}"
             raise OutputParserException(msg) from e
@@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):

             return AgentAction(action, tool_input, text)

-        elif includes_answer:
+        if includes_answer:
             return AgentFinish(
                 {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
             )
@@ -78,7 +78,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
                 llm_output=text,
                 send_to_llm=True,
             )
-        elif not re.search(
+        if not re.search(
             r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
         ):
             msg = f"Could not parse LLM output: `{text}`"
@@ -88,9 +88,8 @@ class ReActSingleInputOutputParser(AgentOutputParser):
                 llm_output=text,
                 send_to_llm=True,
             )
-        else:
-            msg = f"Could not parse LLM output: `{text}`"
-            raise OutputParserException(msg)
+        msg = f"Could not parse LLM output: `{text}`"
+        raise OutputParserException(msg)

     @property
     def _type(self) -> str:
@@ -36,13 +36,12 @@ class XMLAgentOutputParser(AgentOutputParser):
             if "</tool_input>" in _tool_input:
                 _tool_input = _tool_input.split("</tool_input>")[0]
             return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
-        elif "<final_answer>" in text:
+        if "<final_answer>" in text:
             _, answer = text.split("<final_answer>")
             if "</final_answer>" in answer:
                 answer = answer.split("</final_answer>")[0]
             return AgentFinish(return_values={"output": answer}, log=text)
-        else:
-            raise ValueError
+        raise ValueError

     def get_format_instructions(self) -> str:
         raise NotImplementedError
@@ -134,7 +134,7 @@ def create_react_agent(
     else:
         llm_with_stop = llm
     output_parser = output_parser or ReActSingleInputOutputParser()
-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
         )
@@ -142,4 +142,3 @@ def create_react_agent(
         | llm_with_stop
         | output_parser
     )
-    return agent
@@ -96,9 +96,8 @@ class DocstoreExplorer:
         if isinstance(result, Document):
             self.document = result
             return self._summary
-        else:
-            self.document = None
-            return result
+        self.document = None
+        return result

     def lookup(self, term: str) -> str:
         """Lookup a term in document (if saved)."""
@@ -113,11 +112,10 @@ class DocstoreExplorer:
         lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
         if len(lookups) == 0:
             return "No Results"
-        elif self.lookup_index >= len(lookups):
+        if self.lookup_index >= len(lookups):
             return "No More Results"
-        else:
-            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
-            return f"{result_prefix} {lookups[self.lookup_index]}"
+        result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
+        return f"{result_prefix} {lookups[self.lookup_index]}"

     @property
     def _summary(self) -> str:
@@ -26,8 +26,7 @@ class ReActOutputParser(AgentOutputParser):
         action, action_input = re_matches.group(1), re_matches.group(2)
         if action == "Finish":
             return AgentFinish({"output": action_input}, text)
-        else:
-            return AgentAction(action, action_input, text)
+        return AgentAction(action, action_input, text)

     @property
     def _type(self) -> str:
@@ -195,7 +195,7 @@ def create_self_ask_with_search_agent(
         raise ValueError(msg)

     llm_with_stop = llm.bind(stop=["\nIntermediate answer:"])
-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_log_to_str(
                 x["intermediate_steps"],
@@ -209,4 +209,3 @@ def create_self_ask_with_search_agent(
         | llm_with_stop
         | SelfAskOutputParser()
     )
-    return agent
@@ -62,8 +62,7 @@ class StructuredChatAgent(Agent):
                 f"(but I haven't seen any of it! I only see what "
                 f"you return as final answer):\n{agent_scratchpad}"
             )
-        else:
-            return agent_scratchpad
+        return agent_scratchpad

     @classmethod
     def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
@@ -292,7 +291,7 @@ def create_structured_chat_agent(
     else:
         llm_with_stop = llm

-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
         )
@@ -300,4 +299,3 @@ def create_structured_chat_agent(
         | llm_with_stop
         | JSONAgentOutputParser()
     )
-    return agent
@@ -42,12 +42,10 @@ class StructuredChatOutputParser(AgentOutputParser):
                     response = response[0]
                 if response["action"] == "Final Answer":
                     return AgentFinish({"output": response["action_input"]}, text)
-                else:
-                    return AgentAction(
-                        response["action"], response.get("action_input", {}), text
-                    )
-            else:
-                return AgentFinish({"output": text}, text)
+                return AgentAction(
+                    response["action"], response.get("action_input", {}), text
+                )
+            return AgentFinish({"output": text}, text)
         except Exception as e:
             msg = f"Could not parse LLM output: {text}"
             raise OutputParserException(msg) from e
@@ -93,10 +91,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
                 llm=llm, parser=base_parser
             )
             return cls(output_fixing_parser=output_fixing_parser)
-        elif base_parser is not None:
+        if base_parser is not None:
             return cls(base_parser=base_parser)
-        else:
-            return cls()
+        return cls()

     @property
     def _type(self) -> str:
@@ -100,7 +100,7 @@ def create_tool_calling_agent(
     )
     llm_with_tools = llm.bind_tools(tools)

-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
         )
@@ -108,4 +108,3 @@ def create_tool_calling_agent(
         | llm_with_tools
         | ToolsAgentOutputParser()
     )
-    return agent
@@ -221,7 +221,7 @@ def create_xml_agent(
     else:
         llm_with_stop = llm

-    agent = (
+    return (
         RunnablePassthrough.assign(
             agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
         )
@@ -229,4 +229,3 @@ def create_xml_agent(
         | llm_with_stop
         | XMLAgentOutputParser()
     )
-    return agent
@@ -24,8 +24,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def check_if_answer_reached(self) -> bool:
         if self.strip_tokens:
             return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
-        else:
-            return self.last_tokens == self.answer_prefix_tokens
+        return self.last_tokens == self.answer_prefix_tokens

     def __init__(
         self,
@@ -25,8 +25,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
     def check_if_answer_reached(self) -> bool:
         if self.strip_tokens:
             return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
-        else:
-            return self.last_tokens == self.answer_prefix_tokens
+        return self.last_tokens == self.answer_prefix_tokens

     def __init__(
         self,
@ -261,8 +261,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
"""
|
"""
|
||||||
if verbose is None:
|
if verbose is None:
|
||||||
return _get_verbosity()
|
return _get_verbosity()
|
||||||
else:
|
return verbose
|
||||||
return verbose
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -474,8 +473,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
self.memory.save_context(inputs, outputs)
|
self.memory.save_context(inputs, outputs)
|
||||||
if return_only_outputs:
|
if return_only_outputs:
|
||||||
return outputs
|
return outputs
|
||||||
else:
|
return {**inputs, **outputs}
|
||||||
return {**inputs, **outputs}
|
|
||||||
|
|
||||||
async def aprep_outputs(
|
async def aprep_outputs(
|
||||||
self,
|
self,
|
||||||
@ -500,8 +498,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
await self.memory.asave_context(inputs, outputs)
|
await self.memory.asave_context(inputs, outputs)
|
||||||
if return_only_outputs:
|
if return_only_outputs:
|
||||||
return outputs
|
return outputs
|
||||||
else:
|
return {**inputs, **outputs}
|
||||||
return {**inputs, **outputs}
|
|
||||||
|
|
||||||
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
|
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
|
||||||
"""Prepare chain inputs, including adding inputs from memory.
|
"""Prepare chain inputs, including adding inputs from memory.
|
||||||
@ -628,12 +625,11 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
" but none were provided."
|
" but none were provided."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
msg = (
|
||||||
msg = (
|
f"`run` supported with either positional arguments or keyword arguments"
|
||||||
f"`run` supported with either positional arguments or keyword arguments"
|
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
)
|
||||||
)
|
raise ValueError(msg)
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
@deprecated("0.1.0", alternative="ainvoke", removal="1.0")
|
@deprecated("0.1.0", alternative="ainvoke", removal="1.0")
|
||||||
async def arun(
|
async def arun(
|
||||||
@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
f"one output key. Got {self.output_keys}."
|
f"one output key. Got {self.output_keys}."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif args and not kwargs:
|
if args and not kwargs:
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
msg = "`run` supports only one positional argument."
|
msg = "`run` supports only one positional argument."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
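The hunks above are all instances of ruff's flake8-return rule RET505 (superfluous `else` after `return`). As a point of reference, here is a minimal standalone sketch of that rewrite — illustrative only, not code taken from this commit:

    # Illustrative sketch of the RET505 auto-fix (not from this commit).
    def parity(n: int) -> str:
        if n % 2 == 0:
            return "even"
        else:  # flagged: the else adds nothing after an unconditional return
            return "odd"

    # After the fix, the else is dropped and its body is dedented.
    def parity_fixed(n: int) -> str:
        if n % 2 == 0:
            return "even"
        return "odd"

Behaviour is unchanged; only the redundant branch structure is removed.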
@@ -208,28 +208,25 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
             if self.reduce_documents_chain.collapse_documents_chain:
                 return self.reduce_documents_chain.collapse_documents_chain
-            else:
-                return self.reduce_documents_chain.combine_documents_chain
-        else:
-            msg = (
-                f"`reduce_documents_chain` is of type "
-                f"{type(self.reduce_documents_chain)} so it does not have "
-                f"this attribute."
-            )
-            raise ValueError(msg)
+            return self.reduce_documents_chain.combine_documents_chain
+        msg = (
+            f"`reduce_documents_chain` is of type "
+            f"{type(self.reduce_documents_chain)} so it does not have "
+            f"this attribute."
+        )
+        raise ValueError(msg)
 
     @property
     def combine_document_chain(self) -> BaseCombineDocumentsChain:
         """Kept for backward compatibility."""
         if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
             return self.reduce_documents_chain.combine_documents_chain
-        else:
-            msg = (
-                f"`reduce_documents_chain` is of type "
-                f"{type(self.reduce_documents_chain)} so it does not have "
-                f"this attribute."
-            )
-            raise ValueError(msg)
+        msg = (
+            f"`reduce_documents_chain` is of type "
+            f"{type(self.reduce_documents_chain)} so it does not have "
+            f"this attribute."
+        )
+        raise ValueError(msg)
 
     def combine_docs(
         self,
@@ -225,8 +225,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def _collapse_chain(self) -> BaseCombineDocumentsChain:
         if self.collapse_documents_chain is not None:
             return self.collapse_documents_chain
-        else:
-            return self.combine_documents_chain
+        return self.combine_documents_chain
 
     def combine_docs(
         self,
@@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
         base_inputs: dict = {
             self.document_variable_name: self.document_prompt.format(**document_info)
         }
-        inputs = {**base_inputs, **kwargs}
-        return inputs
+        return {**base_inputs, **kwargs}
 
     @property
     def _chain_type(self) -> str:
@@ -201,8 +201,7 @@ class ConstitutionalChain(Chain):
     ) -> list[ConstitutionalPrinciple]:
         if names is None:
             return list(PRINCIPLES.values())
-        else:
-            return [PRINCIPLES[name] for name in names]
+        return [PRINCIPLES[name] for name in names]
 
     @classmethod
     def from_llm(
@@ -80,8 +80,7 @@ class ElasticsearchDatabaseChain(Chain):
         """
         if not self.return_intermediate_steps:
             return [self.output_key]
-        else:
-            return [self.output_key, INTERMEDIATE_STEPS_KEY]
+        return [self.output_key, INTERMEDIATE_STEPS_KEY]
 
     def _list_indices(self) -> list[str]:
         all_indices = [
@@ -47,8 +47,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         """Output keys for Hyde's LLM chain."""
         if isinstance(self.llm_chain, LLMChain):
             return self.llm_chain.output_keys
-        else:
-            return ["text"]
+        return ["text"]
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call the base embeddings."""
@@ -116,8 +116,7 @@ class LLMChain(Chain):
         """
         if self.return_final_only:
             return [self.output_key]
-        else:
-            return [self.output_key, "full_generation"]
+        return [self.output_key, "full_generation"]
 
     def _call(
         self,
@@ -142,17 +141,16 @@ class LLMChain(Chain):
                 callbacks=callbacks,
                 **self.llm_kwargs,
             )
-        else:
-            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
-                cast(list, prompts), {"callbacks": callbacks}
-            )
-            generations: list[list[Generation]] = []
-            for res in results:
-                if isinstance(res, BaseMessage):
-                    generations.append([ChatGeneration(message=res)])
-                else:
-                    generations.append([Generation(text=res)])
-            return LLMResult(generations=generations)
+        results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+            cast(list, prompts), {"callbacks": callbacks}
+        )
+        generations: list[list[Generation]] = []
+        for res in results:
+            if isinstance(res, BaseMessage):
+                generations.append([ChatGeneration(message=res)])
+            else:
+                generations.append([Generation(text=res)])
+        return LLMResult(generations=generations)
 
     async def agenerate(
         self,
@@ -169,17 +167,16 @@ class LLMChain(Chain):
                 callbacks=callbacks,
                 **self.llm_kwargs,
             )
-        else:
-            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
-                cast(list, prompts), {"callbacks": callbacks}
-            )
-            generations: list[list[Generation]] = []
-            for res in results:
-                if isinstance(res, BaseMessage):
-                    generations.append([ChatGeneration(message=res)])
-                else:
-                    generations.append([Generation(text=res)])
-            return LLMResult(generations=generations)
+        results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+            cast(list, prompts), {"callbacks": callbacks}
+        )
+        generations: list[list[Generation]] = []
+        for res in results:
+            if isinstance(res, BaseMessage):
+                generations.append([ChatGeneration(message=res)])
+            else:
+                generations.append([Generation(text=res)])
+        return LLMResult(generations=generations)
 
     def prep_prompts(
         self,
@@ -344,8 +341,7 @@ class LLMChain(Chain):
         result = self.predict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
-        else:
-            return result
+        return result
 
     async def apredict_and_parse(
         self, callbacks: Callbacks = None, **kwargs: Any
@@ -358,8 +354,7 @@ class LLMChain(Chain):
         result = await self.apredict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
-        else:
-            return result
+        return result
 
     def apply_and_parse(
         self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
@@ -380,8 +375,7 @@ class LLMChain(Chain):
                 self.prompt.output_parser.parse(res[self.output_key])
                 for res in generation
             ]
-        else:
-            return generation
+        return generation
 
     async def aapply_and_parse(
         self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
@@ -411,15 +405,14 @@ class LLMChain(Chain):
 def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
     if isinstance(llm_like, BaseLanguageModel):
         return llm_like
-    elif isinstance(llm_like, RunnableBinding):
+    if isinstance(llm_like, RunnableBinding):
         return _get_language_model(llm_like.bound)
-    elif isinstance(llm_like, RunnableWithFallbacks):
+    if isinstance(llm_like, RunnableWithFallbacks):
        return _get_language_model(llm_like.runnable)
-    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+    if isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
        return _get_language_model(llm_like.default)
-    else:
-        msg = (
-            f"Unable to extract BaseLanguageModel from llm_like object of type "
-            f"{type(llm_like)}"
-        )
-        raise ValueError(msg)
+    msg = (
+        f"Unable to extract BaseLanguageModel from llm_like object of type "
+        f"{type(llm_like)}"
+    )
+    raise ValueError(msg)
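Several of the hunks above end the same way: an `else:` that holds only error construction and a `raise` is removed and its body dedented (the no-else-raise side of the RET rules). A minimal sketch of that pattern, illustrative only and not taken from this diff:

    # Illustrative sketch (not from this commit): no superfluous else before raise.
    def pick(value: str) -> int:
        mapping = {"low": 1, "high": 2}
        if value in mapping:
            return mapping[value]
        else:  # flagged: redundant after the early return
            msg = f"Unknown value: {value}"
            raise ValueError(msg)

    def pick_fixed(value: str) -> int:
        mapping = {"low": 1, "high": 2}
        if value in mapping:
            return mapping[value]
        msg = f"Unknown value: {value}"
        raise ValueError(msg)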
@@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain(
         check_assertions_chain,
         revised_answer_chain,
     ]
-    question_to_checked_assertions_chain = SequentialChain(
+    return SequentialChain(
         chains=chains,  # type: ignore[arg-type]
         input_variables=["question"],
         output_variables=["revised_statement"],
         verbose=True,
     )
-    return question_to_checked_assertions_chain
 
 
 @deprecated(
@@ -32,7 +32,7 @@ def _load_sequential_chain(
     are_all_true_prompt: PromptTemplate,
     verbose: bool = False,
 ) -> SequentialChain:
-    chain = SequentialChain(
+    return SequentialChain(
         chains=[
             LLMChain(
                 llm=llm,
@@ -63,7 +63,6 @@ def _load_sequential_chain(
         output_variables=["all_true", "revised_summary"],
         verbose=verbose,
     )
-    return chain
 
 
 @deprecated(
@@ -311,8 +311,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
         prompt = load_prompt(config.pop("prompt_path"))
     if llm_chain:
         return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)  # type: ignore[arg-type]
-    else:
-        return LLMMathChain(llm=llm, prompt=prompt, **config)
+    return LLMMathChain(llm=llm, prompt=prompt, **config)
 
 
 def _load_map_rerank_documents_chain(
@@ -609,8 +608,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
         return LLMRequestsChain(
             llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
         )
-    else:
-        return LLMRequestsChain(llm_chain=llm_chain, **config)
+    return LLMRequestsChain(llm_chain=llm_chain, **config)
 
 
 type_to_loader_dict = {
@@ -100,8 +100,7 @@ class OpenAIModerationChain(Chain):
             error_str = "Text was found that violates OpenAI's content policy."
             if self.error:
                 raise ValueError(error_str)
-            else:
-                return error_str
+            return error_str
         return text
 
     def _call(
@@ -132,7 +132,7 @@ def create_openai_fn_chain(
     }
     if len(openai_functions) == 1 and enforce_single_function_usage:
         llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
-    llm_chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=prompt,
         output_parser=output_parser,
@@ -140,7 +140,6 @@ def create_openai_fn_chain(
         output_key=output_key,
         **kwargs,
     )
-    return llm_chain
 
 
 @deprecated(
@@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
     ]
     prompt = ChatPromptTemplate(messages=messages)  # type: ignore[arg-type]
 
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
-    return chain
@@ -103,7 +103,7 @@ def create_extraction_chain(
     extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
     output_parser = JsonKeyOutputFunctionsParser(key_name="info")
     llm_kwargs = get_llm_kwargs(function)
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=extraction_prompt,
         llm_kwargs=llm_kwargs,
@@ -111,7 +111,6 @@ def create_extraction_chain(
         tags=tags,
         verbose=verbose,
     )
-    return chain
 
 
 @deprecated(
@@ -187,11 +186,10 @@ def create_extraction_chain_pydantic(
         pydantic_schema=PydanticSchema, attr_name="info"
     )
     llm_kwargs = get_llm_kwargs(function)
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=extraction_prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
         verbose=verbose,
     )
-    return chain
@@ -100,14 +100,13 @@ def create_qa_with_structure_chain(
     ]
     prompt = prompt or ChatPromptTemplate(messages=messages)  # type: ignore[arg-type]
 
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=_output_parser,
         verbose=verbose,
     )
-    return chain
 
 
 @deprecated(
@@ -91,14 +91,13 @@ def create_tagging_chain(
     prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = JsonOutputFunctionsParser()
     llm_kwargs = get_llm_kwargs(function)
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
         **kwargs,
     )
-    return chain
 
 
 @deprecated(
@@ -164,11 +163,10 @@ def create_tagging_chain_pydantic(
     prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
     llm_kwargs = get_llm_kwargs(function)
-    chain = LLMChain(
+    return LLMChain(
         llm=llm,
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
         **kwargs,
     )
-    return chain
@@ -76,5 +76,4 @@ def create_extraction_chain_pydantic(
     functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
     tools = [{"type": "function", "function": d} for d in functions]
     model = llm.bind(tools=tools)
-    chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)
-    return chain
+    return prompt | model | PydanticToolsParser(tools=pydantic_schemas)
@@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
                 filter_directive = cast(
                     Optional[FilterDirective], get_parser().parse(raw_filter)
                 )
-                fixed = fix_filter_directive(
+                return fix_filter_directive(
                     filter_directive,
                     allowed_comparators=allowed_comparators,
                     allowed_operators=allowed_operators,
                     allowed_attributes=allowed_attributes,
                 )
-                return fixed
 
         else:
             ast_parse = get_parser(
@@ -131,13 +130,13 @@ def fix_filter_directive(
     ) or not filter:
         return filter
 
-    elif isinstance(filter, Comparison):
+    if isinstance(filter, Comparison):
         if allowed_comparators and filter.comparator not in allowed_comparators:
             return None
         if allowed_attributes and filter.attribute not in allowed_attributes:
             return None
         return filter
-    elif isinstance(filter, Operation):
+    if isinstance(filter, Operation):
         if allowed_operators and filter.operator not in allowed_operators:
             return None
         args = [
@@ -155,15 +154,13 @@ def fix_filter_directive(
         ]
         if not args:
             return None
-        elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
+        if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
            return args[0]
-        else:
-            return Operation(
-                operator=filter.operator,
-                arguments=args,
-            )
-    else:
-        return filter
+        return Operation(
+            operator=filter.operator,
+            arguments=args,
+        )
+    return filter
 
 
 def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
@@ -101,10 +101,9 @@ class QueryTransformer(Transformer):
                 )
                 raise ValueError(msg)
             return Comparison(comparator=func, attribute=args[0], value=args[1])
-        elif len(args) == 1 and func in (Operator.AND, Operator.OR):
+        if len(args) == 1 and func in (Operator.AND, Operator.OR):
             return args[0]
-        else:
-            return Operation(operator=func, arguments=args)
+        return Operation(operator=func, arguments=args)
 
     def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
         if func_name in set(Comparator):
@@ -118,7 +117,7 @@ class QueryTransformer(Transformer):
                 )
                 raise ValueError(msg)
             return Comparator(func_name)
-        elif func_name in set(Operator):
+        if func_name in set(Operator):
             if (
                 self.allowed_operators is not None
                 and func_name not in self.allowed_operators
@@ -129,12 +128,11 @@ class QueryTransformer(Transformer):
                 )
                 raise ValueError(msg)
             return Operator(func_name)
-        else:
-            msg = (
-                f"Received unrecognized function {func_name}. Valid functions are "
-                f"{list(Operator) + list(Comparator)}"
-            )
-            raise ValueError(msg)
+        msg = (
+            f"Received unrecognized function {func_name}. Valid functions are "
+            f"{list(Operator) + list(Comparator)}"
+        )
+        raise ValueError(msg)
 
     def args(self, *items: Any) -> tuple:
         return items
@@ -60,10 +60,8 @@ def create_retrieval_chain(
     else:
         retrieval_docs = (lambda x: x["input"]) | retriever
 
-    retrieval_chain = (
+    return (
         RunnablePassthrough.assign(
             context=retrieval_docs.with_config(run_name="retrieve_documents"),
         ).assign(answer=combine_docs_chain)
     ).with_config(run_name="retrieval_chain")
-
-    return retrieval_chain
@@ -157,8 +157,7 @@ class BaseRetrievalQA(Chain):
 
         if self.return_source_documents:
             return {self.output_key: answer, "source_documents": docs}
-        else:
-            return {self.output_key: answer}
+        return {self.output_key: answer}
 
     @abstractmethod
     async def _aget_docs(
@@ -200,8 +199,7 @@ class BaseRetrievalQA(Chain):
 
         if self.return_source_documents:
             return {self.output_key: answer, "source_documents": docs}
-        else:
-            return {self.output_key: answer}
+        return {self.output_key: answer}
 
 
 @deprecated(
@@ -97,15 +97,14 @@ class MultiRouteChain(Chain):
         )
         if not route.destination:
             return self.default_chain(route.next_inputs, callbacks=callbacks)
-        elif route.destination in self.destination_chains:
+        if route.destination in self.destination_chains:
             return self.destination_chains[route.destination](
                 route.next_inputs, callbacks=callbacks
             )
-        elif self.silent_errors:
+        if self.silent_errors:
             return self.default_chain(route.next_inputs, callbacks=callbacks)
-        else:
-            msg = f"Received invalid destination chain name '{route.destination}'"
-            raise ValueError(msg)
+        msg = f"Received invalid destination chain name '{route.destination}'"
+        raise ValueError(msg)
 
     async def _acall(
         self,
@@ -123,14 +122,13 @@ class MultiRouteChain(Chain):
             return await self.default_chain.acall(
                 route.next_inputs, callbacks=callbacks
            )
-        elif route.destination in self.destination_chains:
+        if route.destination in self.destination_chains:
             return await self.destination_chains[route.destination].acall(
                 route.next_inputs, callbacks=callbacks
             )
-        elif self.silent_errors:
+        if self.silent_errors:
             return await self.default_chain.acall(
                 route.next_inputs, callbacks=callbacks
             )
-        else:
-            msg = f"Received invalid destination chain name '{route.destination}'"
-            raise ValueError(msg)
+        msg = f"Received invalid destination chain name '{route.destination}'"
+        raise ValueError(msg)
@@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain):
         callbacks = _run_manager.get_child()
 
         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
-        output = cast(
+        return cast(
             dict[str, Any],
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
-        return output
 
     async def _acall(
         self,
@@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain):
     ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
-        output = cast(
+        return cast(
             dict[str, Any],
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )
-        return output
 
     @classmethod
     def from_llm(
@@ -141,8 +141,7 @@ def create_sql_query_chain(
             f"{db.dialect}"
         )
         raise ValueError(msg)
-    else:
-        table_info_kwargs["get_col_comments"] = True
+    table_info_kwargs["get_col_comments"] = True
 
     inputs = {
         "input": lambda x: x["question"] + "\nSQLQuery: ",
@@ -143,8 +143,7 @@ def create_openai_fn_runnable(
     output_parser = output_parser or get_openai_output_parser(functions)
     if prompt:
         return prompt | llm.bind(**llm_kwargs_) | output_parser
-    else:
-        return llm.bind(**llm_kwargs_) | output_parser
+    return llm.bind(**llm_kwargs_) | output_parser
 
 
 @deprecated(
@@ -413,7 +412,7 @@ def create_structured_output_runnable(
             first_tool_only=return_single,
         )
 
-    elif mode == "openai-functions":
+    if mode == "openai-functions":
         return _create_openai_functions_structured_output_runnable(
             output_schema,
             llm,
@@ -422,7 +421,7 @@ def create_structured_output_runnable(
             enforce_single_function_usage=force_function_usage,
             **kwargs,  # llm-specific kwargs
         )
-    elif mode == "openai-json":
+    if mode == "openai-json":
         if force_function_usage:
             msg = (
                 "enforce_single_function_usage is not supported for mode='openai-json'."
@@ -431,12 +430,11 @@ def create_structured_output_runnable(
         return _create_openai_json_runnable(
             output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
         )
-    else:
-        msg = (
-            f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
-            f"'openai-json'."
-        )
-        raise ValueError(msg)
+    msg = (
+        f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
+        f"'openai-json'."
+    )
+    raise ValueError(msg)
 
 
 def _create_openai_tools_runnable(
@@ -460,8 +458,7 @@ def _create_openai_tools_runnable(
     )
     if prompt:
         return prompt | llm.bind(**llm_kwargs) | output_parser
-    else:
-        return llm.bind(**llm_kwargs) | output_parser
+    return llm.bind(**llm_kwargs) | output_parser
 
 
 def _get_openai_tool_output_parser(
@@ -535,8 +532,7 @@ def _create_openai_json_runnable(
         prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
 
         return prompt | llm | output_parser
-    else:
-        return llm | output_parser
+    return llm | output_parser
 
 
 def _create_openai_functions_structured_output_runnable(
@@ -77,9 +77,8 @@ class TransformChain(Chain):
     ) -> dict[str, Any]:
         if self.atransform_cb is not None:
             return await self.atransform_cb(inputs)
-        else:
-            self._log_once(
-                "TransformChain's atransform is not provided, falling"
-                " back to synchronous transform"
-            )
-            return self.transform_cb(inputs)
+        self._log_once(
+            "TransformChain's atransform is not provided, falling"
+            " back to synchronous transform"
+        )
+        return self.transform_cb(inputs)
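Many of the chain-constructor hunks above apply ruff's RET504: a value assigned to a local name and immediately returned is now returned directly. A small sketch of that transformation, illustrative only and not taken from this commit:

    # Illustrative sketch of the RET504 auto-fix (not from this commit).
    def build_url(host: str, path: str) -> str:
        url = f"https://{host}/{path.lstrip('/')}"
        return url

    # After the fix, the intermediate variable disappears:
    def build_url_fixed(host: str, path: str) -> str:
        return f"https://{host}/{path.lstrip('/')}"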
@@ -322,16 +322,15 @@ def init_chat_model(
         return _init_chat_model_helper(
             cast(str, model), model_provider=model_provider, **kwargs
         )
-    else:
-        if model:
-            kwargs["model"] = model
-        if model_provider:
-            kwargs["model_provider"] = model_provider
-        return _ConfigurableModel(
-            default_config=kwargs,
-            config_prefix=config_prefix,
-            configurable_fields=configurable_fields,
-        )
+    if model:
+        kwargs["model"] = model
+    if model_provider:
+        kwargs["model_provider"] = model_provider
+    return _ConfigurableModel(
+        default_config=kwargs,
+        config_prefix=config_prefix,
+        configurable_fields=configurable_fields,
+    )
 
 
 def _init_chat_model_helper(
@@ -343,42 +342,42 @@ def _init_chat_model_helper(
         from langchain_openai import ChatOpenAI
 
         return ChatOpenAI(model=model, **kwargs)
-    elif model_provider == "anthropic":
+    if model_provider == "anthropic":
         _check_pkg("langchain_anthropic")
         from langchain_anthropic import ChatAnthropic
 
         return ChatAnthropic(model=model, **kwargs)  # type: ignore[call-arg,unused-ignore]
-    elif model_provider == "azure_openai":
+    if model_provider == "azure_openai":
         _check_pkg("langchain_openai")
         from langchain_openai import AzureChatOpenAI
 
         return AzureChatOpenAI(model=model, **kwargs)
-    elif model_provider == "azure_ai":
+    if model_provider == "azure_ai":
         _check_pkg("langchain_azure_ai")
         from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
 
         return AzureAIChatCompletionsModel(model=model, **kwargs)
-    elif model_provider == "cohere":
+    if model_provider == "cohere":
         _check_pkg("langchain_cohere")
         from langchain_cohere import ChatCohere
 
         return ChatCohere(model=model, **kwargs)
-    elif model_provider == "google_vertexai":
+    if model_provider == "google_vertexai":
         _check_pkg("langchain_google_vertexai")
         from langchain_google_vertexai import ChatVertexAI
 
         return ChatVertexAI(model=model, **kwargs)
-    elif model_provider == "google_genai":
+    if model_provider == "google_genai":
         _check_pkg("langchain_google_genai")
         from langchain_google_genai import ChatGoogleGenerativeAI
 
         return ChatGoogleGenerativeAI(model=model, **kwargs)
-    elif model_provider == "fireworks":
+    if model_provider == "fireworks":
         _check_pkg("langchain_fireworks")
         from langchain_fireworks import ChatFireworks
 
         return ChatFireworks(model=model, **kwargs)
-    elif model_provider == "ollama":
+    if model_provider == "ollama":
         try:
             _check_pkg("langchain_ollama")
             from langchain_ollama import ChatOllama
|
|||||||
_check_pkg("langchain_ollama")
|
_check_pkg("langchain_ollama")
|
||||||
|
|
||||||
return ChatOllama(model=model, **kwargs)
|
return ChatOllama(model=model, **kwargs)
|
||||||
elif model_provider == "together":
|
if model_provider == "together":
|
||||||
_check_pkg("langchain_together")
|
_check_pkg("langchain_together")
|
||||||
from langchain_together import ChatTogether
|
from langchain_together import ChatTogether
|
||||||
|
|
||||||
return ChatTogether(model=model, **kwargs)
|
return ChatTogether(model=model, **kwargs)
|
||||||
elif model_provider == "mistralai":
|
if model_provider == "mistralai":
|
||||||
_check_pkg("langchain_mistralai")
|
_check_pkg("langchain_mistralai")
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
|
|
||||||
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||||
elif model_provider == "huggingface":
|
if model_provider == "huggingface":
|
||||||
_check_pkg("langchain_huggingface")
|
_check_pkg("langchain_huggingface")
|
||||||
from langchain_huggingface import ChatHuggingFace
|
from langchain_huggingface import ChatHuggingFace
|
||||||
|
|
||||||
return ChatHuggingFace(model_id=model, **kwargs)
|
return ChatHuggingFace(model_id=model, **kwargs)
|
||||||
elif model_provider == "groq":
|
if model_provider == "groq":
|
||||||
_check_pkg("langchain_groq")
|
_check_pkg("langchain_groq")
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
|
|
||||||
return ChatGroq(model=model, **kwargs)
|
return ChatGroq(model=model, **kwargs)
|
||||||
elif model_provider == "bedrock":
|
if model_provider == "bedrock":
|
||||||
_check_pkg("langchain_aws")
|
_check_pkg("langchain_aws")
|
||||||
from langchain_aws import ChatBedrock
|
from langchain_aws import ChatBedrock
|
||||||
|
|
||||||
# TODO: update to use model= once ChatBedrock supports
|
# TODO: update to use model= once ChatBedrock supports
|
||||||
return ChatBedrock(model_id=model, **kwargs)
|
return ChatBedrock(model_id=model, **kwargs)
|
||||||
elif model_provider == "bedrock_converse":
|
if model_provider == "bedrock_converse":
|
||||||
_check_pkg("langchain_aws")
|
_check_pkg("langchain_aws")
|
||||||
from langchain_aws import ChatBedrockConverse
|
from langchain_aws import ChatBedrockConverse
|
||||||
|
|
||||||
return ChatBedrockConverse(model=model, **kwargs)
|
return ChatBedrockConverse(model=model, **kwargs)
|
||||||
elif model_provider == "google_anthropic_vertex":
|
if model_provider == "google_anthropic_vertex":
|
||||||
_check_pkg("langchain_google_vertexai")
|
_check_pkg("langchain_google_vertexai")
|
||||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||||
|
|
||||||
return ChatAnthropicVertex(model=model, **kwargs)
|
return ChatAnthropicVertex(model=model, **kwargs)
|
||||||
elif model_provider == "deepseek":
|
if model_provider == "deepseek":
|
||||||
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
return ChatDeepSeek(model=model, **kwargs)
|
return ChatDeepSeek(model=model, **kwargs)
|
||||||
elif model_provider == "nvidia":
|
if model_provider == "nvidia":
|
||||||
_check_pkg("langchain_nvidia_ai_endpoints")
|
_check_pkg("langchain_nvidia_ai_endpoints")
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
|
|
||||||
return ChatNVIDIA(model=model, **kwargs)
|
return ChatNVIDIA(model=model, **kwargs)
|
||||||
elif model_provider == "ibm":
|
if model_provider == "ibm":
|
||||||
_check_pkg("langchain_ibm")
|
_check_pkg("langchain_ibm")
|
||||||
from langchain_ibm import ChatWatsonx
|
from langchain_ibm import ChatWatsonx
|
||||||
|
|
||||||
return ChatWatsonx(model_id=model, **kwargs)
|
return ChatWatsonx(model_id=model, **kwargs)
|
||||||
elif model_provider == "xai":
|
if model_provider == "xai":
|
||||||
_check_pkg("langchain_xai")
|
_check_pkg("langchain_xai")
|
||||||
from langchain_xai import ChatXAI
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
return ChatXAI(model=model, **kwargs)
|
return ChatXAI(model=model, **kwargs)
|
||||||
elif model_provider == "perplexity":
|
if model_provider == "perplexity":
|
||||||
_check_pkg("langchain_perplexity")
|
_check_pkg("langchain_perplexity")
|
||||||
from langchain_perplexity import ChatPerplexity
|
from langchain_perplexity import ChatPerplexity
|
||||||
|
|
||||||
return ChatPerplexity(model=model, **kwargs)
|
return ChatPerplexity(model=model, **kwargs)
|
||||||
else:
|
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
||||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
msg = (
|
||||||
msg = (
|
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||||
f"Unsupported {model_provider=}.\n\nSupported model providers are: "
|
)
|
||||||
f"{supported}"
|
raise ValueError(msg)
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
_SUPPORTED_PROVIDERS = {
|
_SUPPORTED_PROVIDERS = {
|
||||||
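The `_init_chat_model_helper` hunk above flattens a long `elif` ladder: because every branch returns, each `elif` can become a plain `if` without changing which line runs (RET505 applied repeatedly). A reduced sketch of the same shape, illustrative only and not taken from this diff:

    # Illustrative sketch (not from this commit): an elif ladder flattened after returns.
    def provider_for(model: str) -> str:
        if model.startswith("gpt-"):
            return "openai"
        if model.startswith("claude"):
            return "anthropic"
        if model.startswith("command"):
            return "cohere"
        return "unknown"

Each branch exits the function, so dropping `elif`/`else` cannot alter the control flow.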
@@ -490,26 +487,25 @@ _SUPPORTED_PROVIDERS = {
 def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
         return "openai"
-    elif model_name.startswith("claude"):
+    if model_name.startswith("claude"):
         return "anthropic"
-    elif model_name.startswith("command"):
+    if model_name.startswith("command"):
         return "cohere"
-    elif model_name.startswith("accounts/fireworks"):
+    if model_name.startswith("accounts/fireworks"):
         return "fireworks"
-    elif model_name.startswith("gemini"):
+    if model_name.startswith("gemini"):
         return "google_vertexai"
-    elif model_name.startswith("amazon."):
+    if model_name.startswith("amazon."):
         return "bedrock"
-    elif model_name.startswith("mistral"):
+    if model_name.startswith("mistral"):
         return "mistralai"
-    elif model_name.startswith("deepseek"):
+    if model_name.startswith("deepseek"):
         return "deepseek"
-    elif model_name.startswith("grok"):
+    if model_name.startswith("grok"):
         return "xai"
-    elif model_name.startswith("sonar"):
+    if model_name.startswith("sonar"):
         return "perplexity"
-    else:
-        return None
+    return None
 
 
 def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
@@ -595,14 +591,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             )
 
             return queue
-        elif self._default_config and (model := self._model()) and hasattr(model, name):
+        if self._default_config and (model := self._model()) and hasattr(model, name):
             return getattr(model, name)
-        else:
-            msg = f"{name} is not a BaseChatModel attribute"
-            if self._default_config:
-                msg += " and is not implemented on the default model"
-            msg += "."
-            raise AttributeError(msg)
+        msg = f"{name} is not a BaseChatModel attribute"
+        if self._default_config:
+            msg += " and is not implemented on the default model"
+        msg += "."
+        raise AttributeError(msg)
 
     def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
         params = {**self._default_config, **self._model_params(config)}
@@ -728,10 +723,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             )
         # If multiple configs default to Runnable.batch which uses executor to invoke
         # in parallel.
-        else:
-            return super().batch(
-                inputs, config=config, return_exceptions=return_exceptions, **kwargs
-            )
+        return super().batch(
+            inputs, config=config, return_exceptions=return_exceptions, **kwargs
+        )
 
     async def abatch(
         self,
@@ -751,10 +745,9 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             )
         # If multiple configs default to Runnable.batch which uses executor to invoke
        # in parallel.
-        else:
-            return await super().abatch(
-                inputs, config=config, return_exceptions=return_exceptions, **kwargs
-            )
+        return await super().abatch(
+            inputs, config=config, return_exceptions=return_exceptions, **kwargs
+        )
 
     def batch_as_completed(
         self,
@@ -192,41 +192,40 @@ def init_embeddings(
         from langchain_openai import OpenAIEmbeddings
 
         return OpenAIEmbeddings(model=model_name, **kwargs)
-    elif provider == "azure_openai":
+    if provider == "azure_openai":
         from langchain_openai import AzureOpenAIEmbeddings
 
         return AzureOpenAIEmbeddings(model=model_name, **kwargs)
-    elif provider == "google_vertexai":
+    if provider == "google_vertexai":
         from langchain_google_vertexai import VertexAIEmbeddings
 
         return VertexAIEmbeddings(model=model_name, **kwargs)
-    elif provider == "bedrock":
+    if provider == "bedrock":
         from langchain_aws import BedrockEmbeddings
 
         return BedrockEmbeddings(model_id=model_name, **kwargs)
-    elif provider == "cohere":
+    if provider == "cohere":
         from langchain_cohere import CohereEmbeddings
 
         return CohereEmbeddings(model=model_name, **kwargs)
-    elif provider == "mistralai":
+    if provider == "mistralai":
         from langchain_mistralai import MistralAIEmbeddings
 
         return MistralAIEmbeddings(model=model_name, **kwargs)
-    elif provider == "huggingface":
+    if provider == "huggingface":
         from langchain_huggingface import HuggingFaceEmbeddings
 
         return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
-    elif provider == "ollama":
+    if provider == "ollama":
         from langchain_ollama import OllamaEmbeddings
 
         return OllamaEmbeddings(model=model_name, **kwargs)
-    else:
-        msg = (
-            f"Provider '{provider}' is not supported.\n"
-            f"Supported providers and their required packages:\n"
-            f"{_get_provider_list()}"
-        )
-        raise ValueError(msg)
+    msg = (
+        f"Provider '{provider}' is not supported.\n"
+        f"Supported providers and their required packages:\n"
+        f"{_get_provider_list()}"
+    )
+    raise ValueError(msg)
 
 
 __all__ = [
@@ -70,7 +70,7 @@ def resolve_pairwise_criteria(
             Criteria.DEPTH,
         ]
         return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
-    elif isinstance(criteria, Criteria):
+    if isinstance(criteria, Criteria):
         criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
     elif isinstance(criteria, str):
         if criteria in _SUPPORTED_CRITERIA:
@@ -186,9 +186,8 @@ class _EmbeddingDistanceChainMixin(Chain):
         }
         if metric in metrics:
             return metrics[metric]
-        else:
-            msg = f"Invalid metric: {metric}"
-            raise ValueError(msg)
+        msg = f"Invalid metric: {metric}"
+        raise ValueError(msg)
 
     @staticmethod
     def _cosine_distance(a: Any, b: Any) -> Any:
@@ -162,8 +162,7 @@ def load_evaluator(
             )
             raise ValueError(msg) from e
         return evaluator_cls.from_llm(llm=llm, **kwargs)
-    else:
-        return evaluator_cls(**kwargs)
+    return evaluator_cls(**kwargs)
 
 
 def load_evaluators(
@@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator):
     def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
         if isinstance(node, str):
             return parse_json_markdown(node)
-        elif hasattr(node, "schema") and callable(getattr(node, "schema")):
+        if hasattr(node, "schema") and callable(getattr(node, "schema")):
             # Pydantic model
             return getattr(node, "schema")()
         return node
@@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
     if match:
         if match.group(1).upper() == "CORRECT":
             return "CORRECT", 1
-        elif match.group(1).upper() == "INCORRECT":
+        if match.group(1).upper() == "INCORRECT":
             return "INCORRECT", 0
     try:
         first_word = (
@@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
         )
         if first_word.upper() == "CORRECT":
             return "CORRECT", 1
-        elif first_word.upper() == "INCORRECT":
+        if first_word.upper() == "INCORRECT":
             return "INCORRECT", 0
         last_word = (
             text.strip()
@@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
         )
         if last_word.upper() == "CORRECT":
             return "CORRECT", 1
-        elif last_word.upper() == "INCORRECT":
+        if last_word.upper() == "INCORRECT":
             return "INCORRECT", 0
     except IndexError:
         pass
@@ -123,12 +123,12 @@ class _EvalArgsMixin:
         if self.requires_input and input is None:
             msg = f"{self.__class__.__name__} requires an input string."
             raise ValueError(msg)
-        elif input is not None and not self.requires_input:
+        if input is not None and not self.requires_input:
             warn(self._skip_input_warning)
         if self.requires_reference and reference is None:
             msg = f"{self.__class__.__name__} requires a reference string."
             raise ValueError(msg)
-        elif reference is not None and not self.requires_reference:
+        if reference is not None and not self.requires_reference:
             warn(self._skip_reference_warning)
 
 
@ -70,7 +70,7 @@ def resolve_criteria(
Criteria.DEPTH,
]
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
elif isinstance(criteria, Criteria):
if isinstance(criteria, Criteria):
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, str):
if criteria in _SUPPORTED_CRITERIA:
@ -138,8 +138,7 @@ class _RapidFuzzChainMixin(Chain):
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else:
return module.distance
return module.distance

@property
def metric(self) -> Callable:
@ -21,10 +21,9 @@ def _get_client(
ls_client = LangSmithClient(api_url, api_key=api_key)
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
return ls_client
else:
from langchainhub import Client as LangChainHubClient
from langchainhub import Client as LangChainHubClient

return LangChainHubClient(api_url, api_key=api_key)
except ImportError:
try:
from langchainhub import Client as LangChainHubClient
@ -82,14 +81,13 @@ def push(

# Then it's langchainhub
manifest_json = dumps(object)
message = client.push(
return client.push(
repo_full_name,
manifest_json,
parent_commit_hash=parent_commit_hash,
new_repo_is_public=new_repo_is_public,
new_repo_description=new_repo_description,
)
return message


def pull(
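The hunk above is the assign-then-return cleanup (flake8-return rule RET504): a value bound to a temporary only to be returned on the next line is returned directly instead. A minimal sketch of the pattern, with an illustrative helper name rather than anything taken from this diff:

    # Illustrative example, not from the langchain codebase.
    # Before: RET504 flags the temporary binding.
    def fetch_total(items: list[int]) -> int:
        total = sum(items)
        return total

    # After: the expression is returned directly.
    def fetch_total(items: list[int]) -> int:
        return sum(items)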
@ -113,8 +111,7 @@ def pull(

# Then it's langsmith
if hasattr(client, "pull_prompt"):
response = client.pull_prompt(owner_repo_commit, include_model=include_model)
return response
return client.pull_prompt(owner_repo_commit, include_model=include_model)

# Then it's langchainhub
if hasattr(client, "pull_repo"):
@ -561,8 +561,7 @@ def __getattr__(name: str) -> Any:
k: v() for k, v in get_type_to_cls_dict().items()
}
return type_to_cls_dict
else:
return getattr(llms, name)
return getattr(llms, name)


__all__ = [
@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None

def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore):
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
return None

def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore):
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
)
self._execute_query(query, (key, value))
return None

def delete(self, key: str) -> None:
"""Deletes a key-value pair, safely quoting the table name."""
@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]):
)
raise ValueError(msg)
return True
elif self.false_val.upper() in truthy:
if self.false_val.upper() in truthy:
if self.true_val.upper() in truthy:
msg = (
f"Ambiguous response. Both {self.true_val} and {self.false_val} "
@ -71,31 +71,30 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
try:
completion = self.retry_chain.invoke(
{
"instructions": self.parser.get_format_instructions(),  # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
{
"completion": completion,
"error": repr(e),
}
)
else:
try:
completion = self.retry_chain.invoke(
{
"instructions": self.parser.get_format_instructions(),  # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
{
"completion": completion,
"error": repr(e),
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
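What the flattened hunk above shows is the superfluous-else-after-raise fix (flake8-return rule RET506): once the `if` branch ends in `raise e`, the `else:` is dropped and its body de-indented, which is why the same statements appear on both the old and new sides of the diff at different positions. A minimal sketch of the transformation, using an illustrative function rather than code from this diff:

    # Illustrative example, not from the langchain codebase.
    # Before: RET506 flags the else following a raise.
    def check_limit(count: int, limit: int) -> int:
        if count > limit:
            raise ValueError("over limit")
        else:
            return limit - count

    # After: the else is removed and its body de-indented.
    def check_limit(count: int, limit: int) -> int:
        if count > limit:
            raise ValueError("over limit")
        return limit - count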
@ -109,31 +108,30 @@ class OutputFixingParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
try:
completion = await self.retry_chain.ainvoke(
{
"instructions": self.parser.get_format_instructions(),  # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
{
"completion": completion,
"error": repr(e),
}
)
else:
try:
completion = await self.retry_chain.ainvoke(
{
"instructions": self.parser.get_format_instructions(),  # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
{
"completion": completion,
"error": repr(e),
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
|
@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
msg = f"Invalid array format in '{original_request_params}'. \
|
msg = f"Invalid array format in '{original_request_params}'. \
|
||||||
Please check the format instructions."
|
Please check the format instructions."
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
elif (
|
if (
|
||||||
isinstance(parsed_array[0], int)
|
isinstance(parsed_array[0], int)
|
||||||
and parsed_array[-1] > self.dataframe.index.max()
|
and parsed_array[-1] > self.dataframe.index.max()
|
||||||
):
|
):
|
||||||
|
@ -30,12 +30,10 @@ class RegexParser(BaseOutputParser[dict[str, str]]):
|
|||||||
match = re.search(self.regex, text)
|
match = re.search(self.regex, text)
|
||||||
if match:
|
if match:
|
||||||
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
|
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
|
||||||
else:
|
if self.default_output_key is None:
|
||||||
if self.default_output_key is None:
|
msg = f"Could not parse output: {text}"
|
||||||
msg = f"Could not parse output: {text}"
|
raise ValueError(msg)
|
||||||
raise ValueError(msg)
|
return {
|
||||||
else:
|
key: text if key == self.default_output_key else ""
|
||||||
return {
|
for key in self.output_keys
|
||||||
key: text if key == self.default_output_key else ""
|
}
|
||||||
for key in self.output_keys
|
|
||||||
}
|
|
||||||
|
@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]):
{expected_format} on text {text}"
)
raise ValueError(msg)
elif len(matches) > 1:
if len(matches) > 1:
msg = f"Multiple matches found for output key: {output_key} with \
expected format {expected_format} on text {text}"
raise ValueError(msg)
elif (
self.no_update_value is not None and matches[0] == self.no_update_value
):
if self.no_update_value is not None and matches[0] == self.no_update_value:
continue
else:
result[output_key] = matches[0]
result[output_key] = matches[0]
return result
@ -107,20 +107,19 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
completion = self.retry_chain.invoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
else:
completion = self.retry_chain.invoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
@ -143,21 +142,20 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
@ -234,22 +232,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
completion = self.retry_chain.invoke(
{
"completion": completion,
"prompt": prompt_value.to_string(),
"error": repr(e),
}
)
else:
completion = self.retry_chain.invoke(
{
"completion": completion,
"prompt": prompt_value.to_string(),
"error": repr(e),
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
@ -263,22 +260,21 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
"error": repr(e),
}
)
else:
completion = await self.retry_chain.ainvoke(
{
"prompt": prompt_value.to_string(),
"completion": completion,
"error": repr(e),
}
)

msg = "Failed to parse"
raise OutputParserException(msg)
|
@ -89,8 +89,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
if only_json:
|
if only_json:
|
||||||
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
|
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
|
||||||
else:
|
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
|
||||||
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> dict[str, Any]:
|
def parse(self, text: str) -> dict[str, Any]:
|
||||||
expected_keys = [rs.name for rs in self.response_schemas]
|
expected_keys = [rs.name for rs in self.response_schemas]
|
||||||
|
@ -33,8 +33,7 @@ class YamlOutputParser(BaseOutputParser[T]):
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(json_object)
else:
return self.pydantic_object.parse_obj(json_object)
return self.pydantic_object.parse_obj(json_object)

except (yaml.YAMLError, ValidationError) as e:
name = self.pydantic_object.__name__
@ -45,8 +45,7 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []
return []

async def _aget_relevant_documents(
self,
@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []
return []
@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager)
return self.rank_fusion(query, run_manager)

return fused_documents

async def _aget_relevant_documents(
self,
@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager)
return await self.arank_fusion(query, run_manager)

return fused_documents

def rank_fusion(
self,
@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
return self.weighted_reciprocal_rank(retriever_docs)

return fused_documents

async def arank_fusion(
self,
@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
return self.weighted_reciprocal_rank(retriever_docs)

return fused_documents

def weighted_reciprocal_rank(
self, doc_lists: list[list[Document]]
@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever):

# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
sorted_docs = sorted(
return sorted(
unique_by_key(
all_docs,
lambda doc: (
@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever):
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
],
)
return sorted_docs
@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
merged_documents = self.merge_documents(query, run_manager)
return self.merge_documents(query, run_manager)

return merged_documents

async def _aget_relevant_documents(
self,
@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
merged_documents = await self.amerge_documents(query, run_manager)
return await self.amerge_documents(query, run_manager)

return merged_documents

def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun
@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever):
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke(
return self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
)
return docs

async def _aget_relevant_documents(
self,
|
@ -119,117 +119,116 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
|||||||
}
|
}
|
||||||
if isinstance(vectorstore, DatabricksVectorSearch):
|
if isinstance(vectorstore, DatabricksVectorSearch):
|
||||||
return DatabricksVectorSearchTranslator()
|
return DatabricksVectorSearchTranslator()
|
||||||
elif isinstance(vectorstore, MyScale):
|
if isinstance(vectorstore, MyScale):
|
||||||
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
|
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
|
||||||
elif isinstance(vectorstore, Redis):
|
if isinstance(vectorstore, Redis):
|
||||||
return RedisTranslator.from_vectorstore(vectorstore)
|
return RedisTranslator.from_vectorstore(vectorstore)
|
||||||
elif isinstance(vectorstore, TencentVectorDB):
|
if isinstance(vectorstore, TencentVectorDB):
|
||||||
fields = [
|
fields = [
|
||||||
field.name for field in (vectorstore.meta_fields or []) if field.index
|
field.name for field in (vectorstore.meta_fields or []) if field.index
|
||||||
]
|
]
|
||||||
return TencentVectorDBTranslator(fields)
|
return TencentVectorDBTranslator(fields)
|
||||||
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
|
if vectorstore.__class__ in BUILTIN_TRANSLATORS:
|
||||||
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
|
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
|
||||||
|
try:
|
||||||
|
from langchain_astradb.vectorstores import AstraDBVectorStore
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
try:
|
if isinstance(vectorstore, AstraDBVectorStore):
|
||||||
from langchain_astradb.vectorstores import AstraDBVectorStore
|
return AstraDBTranslator()
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if isinstance(vectorstore, AstraDBVectorStore):
|
|
||||||
return AstraDBTranslator()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_elasticsearch.vectorstores import ElasticsearchStore
|
from langchain_elasticsearch.vectorstores import ElasticsearchStore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, ElasticsearchStore):
|
if isinstance(vectorstore, ElasticsearchStore):
|
||||||
return ElasticsearchTranslator()
|
return ElasticsearchTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_pinecone import PineconeVectorStore
|
from langchain_pinecone import PineconeVectorStore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, PineconeVectorStore):
|
if isinstance(vectorstore, PineconeVectorStore):
|
||||||
return PineconeTranslator()
|
return PineconeTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_milvus import Milvus
|
from langchain_milvus import Milvus
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, Milvus):
|
if isinstance(vectorstore, Milvus):
|
||||||
return MilvusTranslator()
|
return MilvusTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, MongoDBAtlasVectorSearch):
|
if isinstance(vectorstore, MongoDBAtlasVectorSearch):
|
||||||
return MongoDBAtlasTranslator()
|
return MongoDBAtlasTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_neo4j import Neo4jVector
|
from langchain_neo4j import Neo4jVector
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, Neo4jVector):
|
if isinstance(vectorstore, Neo4jVector):
|
||||||
return Neo4jTranslator()
|
return Neo4jTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Trying langchain_chroma import if exists
|
# Trying langchain_chroma import if exists
|
||||||
from langchain_chroma import Chroma
|
from langchain_chroma import Chroma
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, Chroma):
|
if isinstance(vectorstore, Chroma):
|
||||||
return ChromaTranslator()
|
return ChromaTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_postgres import PGVector
|
from langchain_postgres import PGVector
|
||||||
from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator
|
from langchain_postgres import PGVectorTranslator as NewPGVectorTranslator
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, PGVector):
|
if isinstance(vectorstore, PGVector):
|
||||||
return NewPGVectorTranslator()
|
return NewPGVectorTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langchain_qdrant import QdrantVectorStore
|
from langchain_qdrant import QdrantVectorStore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, QdrantVectorStore):
|
if isinstance(vectorstore, QdrantVectorStore):
|
||||||
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
|
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Added in langchain-community==0.2.11
|
# Added in langchain-community==0.2.11
|
||||||
from langchain_community.query_constructors.hanavector import HanaTranslator
|
from langchain_community.query_constructors.hanavector import HanaTranslator
|
||||||
from langchain_community.vectorstores import HanaDB
|
from langchain_community.vectorstores import HanaDB
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, HanaDB):
|
if isinstance(vectorstore, HanaDB):
|
||||||
return HanaTranslator()
|
return HanaTranslator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Trying langchain_weaviate (weaviate v4) import if exists
|
# Trying langchain_weaviate (weaviate v4) import if exists
|
||||||
from langchain_weaviate.vectorstores import WeaviateVectorStore
|
from langchain_weaviate.vectorstores import WeaviateVectorStore
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(vectorstore, WeaviateVectorStore):
|
if isinstance(vectorstore, WeaviateVectorStore):
|
||||||
return WeaviateTranslator()
|
return WeaviateTranslator()
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Self query retriever with Vector Store type {vectorstore.__class__}"
|
f"Self query retriever with Vector Store type {vectorstore.__class__}"
|
||||||
f" not supported."
|
f" not supported."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
class SelfQueryRetriever(BaseRetriever):
|
class SelfQueryRetriever(BaseRetriever):
|
||||||
@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever):
def _get_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
return self.vectorstore.search(query, self.search_type, **search_kwargs)

async def _aget_docs_with_query(
self, query: str, search_kwargs: dict[str, Any]
) -> list[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
return await self.vectorstore.asearch(query, self.search_type, **search_kwargs)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs
return self._get_docs_with_query(new_query, search_kwargs)

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs
return await self._aget_docs_with_query(new_query, search_kwargs)

@classmethod
def from_llm(
@ -191,13 +191,13 @@ def _wrap_in_chain_factory(
)
raise ValueError(msg)
return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel):
if isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable):
if isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory):
if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_
@ -215,15 +215,14 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
elif is_traceable_function(cast(Callable, _model)):
if is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_
elif not isinstance(_model, Runnable):
if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
else:
# Typical correct case
return constructor
# Typical correct case
return constructor
return llm_or_chain_factory
@ -272,9 +271,8 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
raise InputFormatError(msg)
if len(prompts) == 1:
return prompts[0]
else:
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
raise InputFormatError(msg)
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
raise InputFormatError(msg)


class ChatModelInput(TypedDict):
@ -321,12 +319,11 @@ def _get_messages(inputs: dict[str, Any]) -> dict:
)
raise InputFormatError(msg)
return input_copy
else:
msg = (
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
f" input. Got {inputs}"
)
raise InputFormatError(msg)
msg = (
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
f" input. Got {inputs}"
)
raise InputFormatError(msg)


## Shared data validation utilities
@ -707,31 +704,29 @@ async def _arun_llm(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
else:
msg = (
"Input mapper returned invalid format"
f" {prompt_or_messages}"
"\nExpected a single string or list of chat messages."
)
raise InputFormatError(msg)
msg = (
"Input mapper returned invalid format"
f" {prompt_or_messages}"
"\nExpected a single string or list of chat messages."
)
raise InputFormatError(msg)

else:
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.ainvoke(
prompt,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
except InputFormatError:
llm_inputs = _get_messages(inputs)
llm_output = await llm.ainvoke(
**llm_inputs,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.ainvoke(
prompt,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)
except InputFormatError:
llm_inputs = _get_messages(inputs)
llm_output = await llm.ainvoke(
**llm_inputs,
config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
),
)

return llm_output
@ -28,8 +28,7 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]:
first_message = messages[0]
if "lc" in first_message:
return [load(dumpd(message)) for message in messages]
else:
return messages_from_dict(messages)
return messages_from_dict(messages)


class StringRunMapper(Serializable):
@ -106,25 +105,23 @@ class LLMStringRunMapper(StringRunMapper):
if run.run_type != "llm":
msg = "LLM RunMapper only supports LLM runs."
raise ValueError(msg)
elif not run.outputs:
if not run.outputs:
if run.error:
msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}"
raise ValueError(msg)
else:
msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
raise ValueError(msg)
else:
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:
msg = f"Could not parse LM input from run inputs {run.inputs}"
raise ValueError(msg) from e
try:
output_ = self.serialize_outputs(run.outputs)
except Exception as e:
msg = f"Could not parse LM prediction from run outputs {run.outputs}"
raise ValueError(msg) from e
return {"input": inputs, "prediction": output_}
msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
raise ValueError(msg)
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:
msg = f"Could not parse LM input from run inputs {run.inputs}"
raise ValueError(msg) from e
try:
output_ = self.serialize_outputs(run.outputs)
except Exception as e:
msg = f"Could not parse LM prediction from run outputs {run.outputs}"
raise ValueError(msg) from e
return {"input": inputs, "prediction": output_}


class ChainStringRunMapper(StringRunMapper):
@ -142,14 +139,13 @@ class ChainStringRunMapper(StringRunMapper):
def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
if key is not None:
return source[key]
elif len(source) == 1:
if len(source) == 1:
return next(iter(source.values()))
else:
msg = (
f"Could not map run {which} with multiple keys: "
f"{source}\nPlease manually specify a {which}_key"
)
raise ValueError(msg)
msg = (
f"Could not map run {which} with multiple keys: "
f"{source}\nPlease manually specify a {which}_key"
)
raise ValueError(msg)

def map(self, run: Run) -> dict[str, str]:
"""Maps the Run to a dictionary."""
@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
f" '{self.input_key}'."
)
raise ValueError(msg)
elif self.prediction_key is not None and self.prediction_key not in run.outputs:
if self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys())
msg = (
f"Run with ID {run.id} doesn't have the expected prediction key
@ -178,13 +174,12 @@ class ChainStringRunMapper(StringRunMapper):
)
raise ValueError(msg)

else:
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
return {
"input": input_,
"prediction": prediction,
}
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
return {
"input": input_,
"prediction": prediction,
}


class ToolStringRunMapper(StringRunMapper):
@ -224,8 +219,7 @@ class StringExampleMapper(Serializable):
" specify a reference_key."
)
raise ValueError(msg)
else:
output = list(example.outputs.values())[0]
output = list(example.outputs.values())[0]
elif self.reference_key not in example.outputs:
msg = (
f"Example {example.id} does not have reference key"
@ -65,24 +65,23 @@ def _import_python_tool_PythonREPLTool() -> Any:
def __getattr__(name: str) -> Any:
if name == "PythonAstREPLTool":
return _import_python_tool_PythonAstREPLTool()
elif name == "PythonREPLTool":
if name == "PythonREPLTool":
return _import_python_tool_PythonREPLTool()
else:
from langchain_community import tools
from langchain_community import tools

# If not in interactive env, raise warning.
if not is_interactive_env():
warnings.warn(
"Importing tools from langchain is deprecated. Importing from "
"langchain will no longer be supported as of langchain==0.2.0. "
"Please import from langchain-community instead:\n\n"
f"`from langchain_community.tools import {name}`.\n\n"
"To install langchain-community run "
"`pip install -U langchain-community`.",
category=LangChainDeprecationWarning,
)

return getattr(tools, name)


__all__ = [
@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"

[tool.ruff.lint]
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
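This is the configuration change that enables the whole family: adding "RET" to the lint `select` list turns on the flake8-return rules (RET501-RET508), and the rest of the diff is essentially what `ruff check --fix` produces for them. The most common case in this commit is the superfluous-else-after-return fix (RET505); a minimal sketch with an illustrative function, not code from this repository:

    # Illustrative example, not from the langchain codebase.
    # Before: RET505 flags the else following a return.
    def label(score: float) -> str:
        if score >= 0.5:
            return "positive"
        else:
            return "negative"

    # After: the else branch is flattened.
    def label(score: float) -> str:
        if score >= 0.5:
            return "positive"
        return "negative"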
@ -139,8 +139,7 @@ async def get_state(
async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
if said_please:
return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
else:
return {"passphrase": "I won't share the passphrase without saying 'please'."}
return {"passphrase": "I won't share the passphrase without saying 'please'."}


@app.delete(
@ -153,12 +152,11 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
if password.pw == PASS_PHRASE:
_ROBOT_STATE["destruct"] = True
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
else:
_ROBOT_STATE["destruct"] = False
raise HTTPException(
status_code=400,
detail="Pass phrase required. You should have thought to ask for it.",
)
_ROBOT_STATE["destruct"] = False
raise HTTPException(
status_code=400,
detail="Pass phrase required. You should have thought to ask for it.",
)


@app.post(
@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent


def test_agent_bad_action() -> None:
@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

agent = initialize_agent(
return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
return agent


async def test_agent_bad_action() -> None:
@ -11,8 +11,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]
return "Final Answer", output.return_values["output"]


def test_parse_with_language() -> None:
@ -16,8 +16,7 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]
return "Final Answer", output.return_values["output"]


def test_get_action_and_input() -> None:
@ -21,11 +21,10 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
elif isinstance(output, AgentFinish):
if isinstance(output, AgentFinish):
return output.return_values["output"], output.log
else:
msg = "Unexpected output type"
raise ValueError(msg)
msg = "Unexpected output type"
raise ValueError(msg)


def test_parse_with_language() -> None:
@ -58,8 +58,7 @@ class FakeChain(Chain):
) -> dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
else:
return {"baz": "bar"}
return {"baz": "bar"}


def test_bad_inputs() -> None:
@ -54,9 +54,8 @@ class _FakeTrajectoryChatModel(FakeChatModel):
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
else:
prompt = messages[0].content
return self.queries[prompt]
prompt = messages[0].content
return self.queries[prompt]


def test_trajectory_output_parser_parse() -> None:
@ -45,8 +45,7 @@ class FakeLLM(LLM):
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
return "bar"

@property
def _identifying_params(self) -> dict[str, Any]:
@ -14,9 +14,8 @@ class SequentialRetriever(BaseRetriever):
) -> list[Document]:
if self.response_index >= len(self.sequential_responses):
return []
else:
self.response_index += 1
return self.sequential_responses[self.response_index - 1]
self.response_index += 1
return self.sequential_responses[self.response_index - 1]

async def _aget_relevant_documents(  # type: ignore[override]
self,