Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-14 08:56:27 +00:00
langchain: Add ruff rule RET (#31875)
All auto-fixes. See https://docs.astral.sh/ruff/rules/#flake8-return-ret
Co-authored-by: Mason Daugherty <github@mdrxy.com>
parent fceebbb387
commit 56bbfd9723
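For reference, the diff below consists of two families of flake8-return auto-fixes: RET505 (an unnecessary `elif`/`else` after a branch that returns) and RET504 (an assignment used only for the following `return`). A minimal sketch of both patterns, as `ruff check --fix` would rewrite them, is shown here; the function names are illustrative and do not appear in the LangChain source.

# Before: flagged by RET505 (superfluous elif/else after return)
# and RET504 (assignment immediately followed by return).
def lookup_before(name: str) -> str:
    if name == "first":
        return "first value"
    elif name == "second":
        value = "second value"
        return value
    else:
        raise AttributeError(f"Could not find: {name}")

# After the auto-fix: each returning branch stands alone, and the
# intermediate assignment is collapsed into a direct return.
def lookup_after(name: str) -> str:
    if name == "first":
        return "first value"
    if name == "second":
        return "second value"
    raise AttributeError(f"Could not find: {name}")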
@@ -48,25 +48,25 @@ def __getattr__(name: str) -> Any:
- elif name == "ReActChain":
+ if name == "ReActChain":
- elif name == "SelfAskWithSearchChain":
+ if name == "SelfAskWithSearchChain":
- elif name == "ConversationChain":
+ if name == "ConversationChain":
- elif name == "LLMBashChain":
+ if name == "LLMBashChain":
@@ -77,97 +77,97 @@ def __getattr__(name: str) -> Any:
- elif name == "LLMChain":
+ if name == "LLMChain":
- elif name == "LLMCheckerChain":
+ if name == "LLMCheckerChain":
- elif name == "LLMMathChain":
+ if name == "LLMMathChain":
- elif name == "QAWithSourcesChain":
+ if name == "QAWithSourcesChain":
- elif name == "VectorDBQA":
+ if name == "VectorDBQA":
- elif name == "VectorDBQAWithSourcesChain":
+ if name == "VectorDBQAWithSourcesChain":
- elif name == "InMemoryDocstore":
+ if name == "InMemoryDocstore":
- elif name == "Wikipedia":
+ if name == "Wikipedia":
- elif name == "Anthropic":
+ if name == "Anthropic":
- elif name == "Banana":
+ if name == "Banana":
- elif name == "CerebriumAI":
+ if name == "CerebriumAI":
- elif name == "Cohere":
+ if name == "Cohere":
- elif name == "ForefrontAI":
+ if name == "ForefrontAI":
- elif name == "GooseAI":
+ if name == "GooseAI":
- elif name == "HuggingFaceHub":
+ if name == "HuggingFaceHub":
- elif name == "HuggingFaceTextGenInference":
+ if name == "HuggingFaceTextGenInference":
@@ -175,55 +175,55 @@ def __getattr__(name: str) -> Any:
- elif name == "LlamaCpp":
+ if name == "LlamaCpp":
- elif name == "Modal":
+ if name == "Modal":
- elif name == "OpenAI":
+ if name == "OpenAI":
- elif name == "Petals":
+ if name == "Petals":
- elif name == "PipelineAI":
+ if name == "PipelineAI":
- elif name == "SagemakerEndpoint":
+ if name == "SagemakerEndpoint":
- elif name == "StochasticAI":
+ if name == "StochasticAI":
- elif name == "Writer":
+ if name == "Writer":
- elif name == "HuggingFacePipeline":
+ if name == "HuggingFacePipeline":
@@ -232,7 +232,7 @@ def __getattr__(name: str) -> Any:
- elif name == "FewShotPromptTemplate":
+ if name == "FewShotPromptTemplate":
@@ -240,7 +240,7 @@ def __getattr__(name: str) -> Any:
- elif name == "Prompt":
+ if name == "Prompt":
@@ -248,19 +248,19 @@ def __getattr__(name: str) -> Any:
- elif name == "PromptTemplate":
+ if name == "PromptTemplate":
- elif name == "BasePromptTemplate":
+ if name == "BasePromptTemplate":
- elif name == "ArxivAPIWrapper":
+ if name == "ArxivAPIWrapper":
@@ -268,7 +268,7 @@ def __getattr__(name: str) -> Any:
- elif name == "GoldenQueryAPIWrapper":
+ if name == "GoldenQueryAPIWrapper":
@@ -276,7 +276,7 @@ def __getattr__(name: str) -> Any:
- elif name == "GoogleSearchAPIWrapper":
+ if name == "GoogleSearchAPIWrapper":
@@ -284,7 +284,7 @@ def __getattr__(name: str) -> Any:
- elif name == "GoogleSerperAPIWrapper":
+ if name == "GoogleSerperAPIWrapper":
@@ -292,7 +292,7 @@ def __getattr__(name: str) -> Any:
- elif name == "PowerBIDataset":
+ if name == "PowerBIDataset":
@@ -300,7 +300,7 @@ def __getattr__(name: str) -> Any:
- elif name == "SearxSearchWrapper":
+ if name == "SearxSearchWrapper":
@@ -308,7 +308,7 @@ def __getattr__(name: str) -> Any:
- elif name == "WikipediaAPIWrapper":
+ if name == "WikipediaAPIWrapper":
@@ -316,7 +316,7 @@ def __getattr__(name: str) -> Any:
- elif name == "WolframAlphaAPIWrapper":
+ if name == "WolframAlphaAPIWrapper":
@@ -324,19 +324,19 @@ def __getattr__(name: str) -> Any:
- elif name == "SQLDatabase":
+ if name == "SQLDatabase":
- elif name == "FAISS":
+ if name == "FAISS":
- elif name == "ElasticVectorSearch":
+ if name == "ElasticVectorSearch":
@@ -345,7 +345,7 @@ def __getattr__(name: str) -> Any:
- elif name == "SerpAPIChain" or name == "SerpAPIWrapper":
+ if name == "SerpAPIChain" or name == "SerpAPIWrapper":
@@ -353,7 +353,7 @@ def __getattr__(name: str) -> Any:
- elif name == "verbose":
+ if name == "verbose":
@@ -364,7 +364,7 @@ def __getattr__(name: str) -> Any:
- elif name == "debug":
+ if name == "debug":
@@ -375,7 +375,7 @@ def __getattr__(name: str) -> Any:
- elif name == "llm_cache":
+ if name == "llm_cache":
@@ -386,7 +386,6 @@ def __getattr__(name: str) -> Any:
- else:
@@ -137,7 +137,6 @@ class BaseSingleActionAgent(BaseModel):
- else:
@@ -308,7 +307,6 @@ class BaseMultiActionAgent(BaseModel):
- else:
@@ -815,8 +813,7 @@ class Agent(BaseSingleActionAgent):
- agent_output = await self.output_parser.aparse(full_output)
- return agent_output
+ return await self.output_parser.aparse(full_output)
@@ -833,8 +830,7 @@ class Agent(BaseSingleActionAgent):
- full_inputs = {**kwargs, **new_inputs}
- return full_inputs
+ return {**kwargs, **new_inputs}
@@ -970,7 +966,7 @@ class Agent(BaseSingleActionAgent):
- elif early_stopping_method == "generate":
+ if early_stopping_method == "generate":
@@ -990,11 +986,9 @@ class Agent(BaseSingleActionAgent):
- else:
- else:
@@ -1179,7 +1173,6 @@ class AgentExecutor(Chain):
- else:
@@ -1249,7 +1242,6 @@ class AgentExecutor(Chain):
- else:
@@ -1304,10 +1296,7 @@ class AgentExecutor(Chain):
- else:
-     return [
-         (a.action, a.observation) for a in values if isinstance(a, AgentStep)
-     ]
+ return [(a.action, a.observation) for a in values if isinstance(a, AgentStep)]
@@ -1727,9 +1716,8 @@ class AgentExecutor(Chain):
- elif callable(self.trim_intermediate_steps):
+ if callable(self.trim_intermediate_steps):
- else:
@@ -61,7 +61,6 @@ class ChatAgent(Agent):
- else:
@@ -39,11 +39,9 @@ class ConvoOutputParser(AgentOutputParser):
- else:
- else:
@@ -23,7 +23,6 @@ def _convert_agent_action_to_messages(
- else:
@@ -182,7 +182,7 @@ def create_json_chat_agent(
- agent = (
+ return (
@@ -192,4 +192,3 @@ def create_json_chat_agent(
- return agent
@@ -55,7 +55,6 @@ class MRKLOutputParser(AgentOutputParser):
- else:
@@ -69,7 +68,7 @@ class MRKLOutputParser(AgentOutputParser):
- elif includes_answer:
+ if includes_answer:
@@ -82,7 +81,7 @@ class MRKLOutputParser(AgentOutputParser):
- elif not re.search(
+ if not re.search(
@@ -92,7 +91,6 @@ class MRKLOutputParser(AgentOutputParser):
- else:
@@ -128,7 +128,6 @@ def _get_assistants_tool(
- else:
@@ -510,12 +509,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- submit_tool_outputs = {
+ return {
- return submit_tool_outputs
@@ -558,12 +556,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- run = self.client.beta.threads.create_and_run(
+ return self.client.beta.threads.create_and_run(
- return run
@@ -612,7 +609,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- elif run.status == "requires_action":
+ if run.status == "requires_action":
@@ -639,7 +636,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- else:
@@ -668,12 +664,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- submit_tool_outputs = {
+ return {
- return submit_tool_outputs
@@ -716,12 +711,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- run = await self.async_client.beta.threads.create_and_run(
+ return await self.async_client.beta.threads.create_and_run(
- return run
@@ -766,7 +760,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- elif run.status == "requires_action":
+ if run.status == "requires_action":
@@ -793,7 +787,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
- else:
@@ -132,8 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
- agent_decision = self.output_parser._parse_ai_message(predicted_message)
- return agent_decision
+ return self.output_parser._parse_ai_message(predicted_message)
@@ -164,8 +163,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
- agent_decision = self.output_parser._parse_ai_message(predicted_message)
- return agent_decision
+ return self.output_parser._parse_ai_message(predicted_message)
@@ -192,17 +190,15 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
- elif early_stopping_method == "generate":
+ if early_stopping_method == "generate":
- else:
- else:
@@ -358,7 +354,7 @@ def create_openai_functions_agent(
- agent = (
+ return (
@@ -368,4 +364,3 @@ def create_openai_functions_agent(
- return agent
@@ -224,8 +224,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
- agent_decision = _parse_ai_message(predicted_message)
- return agent_decision
+ return _parse_ai_message(predicted_message)
@@ -254,8 +253,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
- agent_decision = _parse_ai_message(predicted_message)
- return agent_decision
+ return _parse_ai_message(predicted_message)
@@ -96,7 +96,7 @@ def create_openai_tools_agent(
- agent = (
+ return (
@@ -106,4 +106,3 @@ def create_openai_tools_agent(
- return agent
@@ -49,7 +49,6 @@ class JSONAgentOutputParser(AgentOutputParser):
- else:
@@ -65,7 +65,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
- elif includes_answer:
+ if includes_answer:
@@ -78,7 +78,7 @@ class ReActSingleInputOutputParser(AgentOutputParser):
- elif not re.search(
+ if not re.search(
@@ -88,7 +88,6 @@ class ReActSingleInputOutputParser(AgentOutputParser):
- else:
@@ -36,12 +36,11 @@ class XMLAgentOutputParser(AgentOutputParser):
- elif "<final_answer>" in text:
+ if "<final_answer>" in text:
- else:
@@ -134,7 +134,7 @@ def create_react_agent(
- agent = (
+ return (
@@ -142,4 +142,3 @@ def create_react_agent(
- return agent
@@ -96,7 +96,6 @@ class DocstoreExplorer:
- else:
@@ -113,9 +112,8 @@ class DocstoreExplorer:
- elif self.lookup_index >= len(lookups):
+ if self.lookup_index >= len(lookups):
- else:
@@ -26,7 +26,6 @@ class ReActOutputParser(AgentOutputParser):
- else:
@@ -195,7 +195,7 @@ def create_self_ask_with_search_agent(
- agent = (
+ return (
@@ -209,4 +209,3 @@ def create_self_ask_with_search_agent(
- return agent
@@ -62,7 +62,6 @@ class StructuredChatAgent(Agent):
- else:
@@ -292,7 +291,7 @@ def create_structured_chat_agent(
- agent = (
+ return (
@@ -300,4 +299,3 @@ def create_structured_chat_agent(
- return agent
@@ -42,11 +42,9 @@ class StructuredChatOutputParser(AgentOutputParser):
- else:
- else:
@@ -93,9 +91,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
- elif base_parser is not None:
+ if base_parser is not None:
- else:
|
@ -100,7 +100,7 @@ def create_tool_calling_agent(
|
|||||||
)
|
)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
|
||||||
agent = (
|
return (
|
||||||
RunnablePassthrough.assign(
|
RunnablePassthrough.assign(
|
||||||
agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
|
agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"])
|
||||||
)
|
)
|
||||||
@ -108,4 +108,3 @@ def create_tool_calling_agent(
|
|||||||
| llm_with_tools
|
| llm_with_tools
|
||||||
| ToolsAgentOutputParser()
|
| ToolsAgentOutputParser()
|
||||||
)
|
)
|
||||||
return agent
|
|
||||||
|
@ -221,7 +221,7 @@ def create_xml_agent(
|
|||||||
else:
|
else:
|
||||||
llm_with_stop = llm
|
llm_with_stop = llm
|
||||||
|
|
||||||
agent = (
|
return (
|
||||||
RunnablePassthrough.assign(
|
RunnablePassthrough.assign(
|
||||||
agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
|
agent_scratchpad=lambda x: format_xml(x["intermediate_steps"]),
|
||||||
)
|
)
|
||||||
@ -229,4 +229,3 @@ def create_xml_agent(
|
|||||||
| llm_with_stop
|
| llm_with_stop
|
||||||
| XMLAgentOutputParser()
|
| XMLAgentOutputParser()
|
||||||
)
|
)
|
||||||
return agent
|
|
||||||
|
@@ -24,7 +24,6 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
- else:
@@ -25,7 +25,6 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
- else:
@@ -261,7 +261,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
- else:
@@ -474,7 +473,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
- else:
@@ -500,7 +498,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
- else:
@@ -628,7 +625,6 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
- else:
@@ -687,7 +683,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
- elif args and not kwargs:
+ if args and not kwargs:
@@ -208,9 +208,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
- else:
- else:
@@ -223,7 +221,6 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
- else:
@@ -225,7 +225,6 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
- else:
@@ -222,8 +222,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
- inputs = {**base_inputs, **kwargs}
- return inputs
+ return {**base_inputs, **kwargs}
@ -201,7 +201,6 @@ class ConstitutionalChain(Chain):
|
|||||||
) -> list[ConstitutionalPrinciple]:
|
) -> list[ConstitutionalPrinciple]:
|
||||||
if names is None:
|
if names is None:
|
||||||
return list(PRINCIPLES.values())
|
return list(PRINCIPLES.values())
|
||||||
else:
|
|
||||||
return [PRINCIPLES[name] for name in names]
|
return [PRINCIPLES[name] for name in names]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -80,7 +80,6 @@ class ElasticsearchDatabaseChain(Chain):
|
|||||||
"""
|
"""
|
||||||
if not self.return_intermediate_steps:
|
if not self.return_intermediate_steps:
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
else:
|
|
||||||
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
||||||
|
|
||||||
def _list_indices(self) -> list[str]:
|
def _list_indices(self) -> list[str]:
|
||||||
|
@ -47,7 +47,6 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
|||||||
"""Output keys for Hyde's LLM chain."""
|
"""Output keys for Hyde's LLM chain."""
|
||||||
if isinstance(self.llm_chain, LLMChain):
|
if isinstance(self.llm_chain, LLMChain):
|
||||||
return self.llm_chain.output_keys
|
return self.llm_chain.output_keys
|
||||||
else:
|
|
||||||
return ["text"]
|
return ["text"]
|
||||||
|
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
@ -116,7 +116,6 @@ class LLMChain(Chain):
|
|||||||
"""
|
"""
|
||||||
if self.return_final_only:
|
if self.return_final_only:
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
else:
|
|
||||||
return [self.output_key, "full_generation"]
|
return [self.output_key, "full_generation"]
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
@ -142,7 +141,6 @@ class LLMChain(Chain):
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**self.llm_kwargs,
|
**self.llm_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
||||||
cast(list, prompts), {"callbacks": callbacks}
|
cast(list, prompts), {"callbacks": callbacks}
|
||||||
)
|
)
|
||||||
@ -169,7 +167,6 @@ class LLMChain(Chain):
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**self.llm_kwargs,
|
**self.llm_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
||||||
cast(list, prompts), {"callbacks": callbacks}
|
cast(list, prompts), {"callbacks": callbacks}
|
||||||
)
|
)
|
||||||
@ -344,7 +341,6 @@ class LLMChain(Chain):
|
|||||||
result = self.predict(callbacks=callbacks, **kwargs)
|
result = self.predict(callbacks=callbacks, **kwargs)
|
||||||
if self.prompt.output_parser is not None:
|
if self.prompt.output_parser is not None:
|
||||||
return self.prompt.output_parser.parse(result)
|
return self.prompt.output_parser.parse(result)
|
||||||
else:
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def apredict_and_parse(
|
async def apredict_and_parse(
|
||||||
@ -358,7 +354,6 @@ class LLMChain(Chain):
|
|||||||
result = await self.apredict(callbacks=callbacks, **kwargs)
|
result = await self.apredict(callbacks=callbacks, **kwargs)
|
||||||
if self.prompt.output_parser is not None:
|
if self.prompt.output_parser is not None:
|
||||||
return self.prompt.output_parser.parse(result)
|
return self.prompt.output_parser.parse(result)
|
||||||
else:
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def apply_and_parse(
|
def apply_and_parse(
|
||||||
@ -380,7 +375,6 @@ class LLMChain(Chain):
|
|||||||
self.prompt.output_parser.parse(res[self.output_key])
|
self.prompt.output_parser.parse(res[self.output_key])
|
||||||
for res in generation
|
for res in generation
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
return generation
|
return generation
|
||||||
|
|
||||||
async def aapply_and_parse(
|
async def aapply_and_parse(
|
||||||
@ -411,13 +405,12 @@ class LLMChain(Chain):
|
|||||||
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
|
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
|
||||||
if isinstance(llm_like, BaseLanguageModel):
|
if isinstance(llm_like, BaseLanguageModel):
|
||||||
return llm_like
|
return llm_like
|
||||||
elif isinstance(llm_like, RunnableBinding):
|
if isinstance(llm_like, RunnableBinding):
|
||||||
return _get_language_model(llm_like.bound)
|
return _get_language_model(llm_like.bound)
|
||||||
elif isinstance(llm_like, RunnableWithFallbacks):
|
if isinstance(llm_like, RunnableWithFallbacks):
|
||||||
return _get_language_model(llm_like.runnable)
|
return _get_language_model(llm_like.runnable)
|
||||||
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
|
if isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
|
||||||
return _get_language_model(llm_like.default)
|
return _get_language_model(llm_like.default)
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Unable to extract BaseLanguageModel from llm_like object of type "
|
f"Unable to extract BaseLanguageModel from llm_like object of type "
|
||||||
f"{type(llm_like)}"
|
f"{type(llm_like)}"
|
||||||
|
@ -55,13 +55,12 @@ def _load_question_to_checked_assertions_chain(
|
|||||||
check_assertions_chain,
|
check_assertions_chain,
|
||||||
revised_answer_chain,
|
revised_answer_chain,
|
||||||
]
|
]
|
||||||
question_to_checked_assertions_chain = SequentialChain(
|
return SequentialChain(
|
||||||
chains=chains, # type: ignore[arg-type]
|
chains=chains, # type: ignore[arg-type]
|
||||||
input_variables=["question"],
|
input_variables=["question"],
|
||||||
output_variables=["revised_statement"],
|
output_variables=["revised_statement"],
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
return question_to_checked_assertions_chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
|
@ -32,7 +32,7 @@ def _load_sequential_chain(
|
|||||||
are_all_true_prompt: PromptTemplate,
|
are_all_true_prompt: PromptTemplate,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> SequentialChain:
|
) -> SequentialChain:
|
||||||
chain = SequentialChain(
|
return SequentialChain(
|
||||||
chains=[
|
chains=[
|
||||||
LLMChain(
|
LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
@ -63,7 +63,6 @@ def _load_sequential_chain(
|
|||||||
output_variables=["all_true", "revised_summary"],
|
output_variables=["all_true", "revised_summary"],
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
|
@ -311,7 +311,6 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
|
|||||||
prompt = load_prompt(config.pop("prompt_path"))
|
prompt = load_prompt(config.pop("prompt_path"))
|
||||||
if llm_chain:
|
if llm_chain:
|
||||||
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
|
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
|
||||||
else:
|
|
||||||
return LLMMathChain(llm=llm, prompt=prompt, **config)
|
return LLMMathChain(llm=llm, prompt=prompt, **config)
|
||||||
|
|
||||||
|
|
||||||
@ -609,7 +608,6 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
|
|||||||
return LLMRequestsChain(
|
return LLMRequestsChain(
|
||||||
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
|
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return LLMRequestsChain(llm_chain=llm_chain, **config)
|
return LLMRequestsChain(llm_chain=llm_chain, **config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,7 +100,6 @@ class OpenAIModerationChain(Chain):
|
|||||||
error_str = "Text was found that violates OpenAI's content policy."
|
error_str = "Text was found that violates OpenAI's content policy."
|
||||||
if self.error:
|
if self.error:
|
||||||
raise ValueError(error_str)
|
raise ValueError(error_str)
|
||||||
else:
|
|
||||||
return error_str
|
return error_str
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ def create_openai_fn_chain(
|
|||||||
}
|
}
|
||||||
if len(openai_functions) == 1 and enforce_single_function_usage:
|
if len(openai_functions) == 1 and enforce_single_function_usage:
|
||||||
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
|
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
|
||||||
llm_chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
@ -140,7 +140,6 @@ def create_openai_fn_chain(
|
|||||||
output_key=output_key,
|
output_key=output_key,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return llm_chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
|
@ -149,10 +149,9 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
|
|||||||
]
|
]
|
||||||
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
|
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
|
||||||
|
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
@ -103,7 +103,7 @@ def create_extraction_chain(
|
|||||||
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||||
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
|
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=extraction_prompt,
|
prompt=extraction_prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
@ -111,7 +111,6 @@ def create_extraction_chain(
|
|||||||
tags=tags,
|
tags=tags,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -187,11 +186,10 @@ def create_extraction_chain_pydantic(
|
|||||||
pydantic_schema=PydanticSchema, attr_name="info"
|
pydantic_schema=PydanticSchema, attr_name="info"
|
||||||
)
|
)
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=extraction_prompt,
|
prompt=extraction_prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
@ -100,14 +100,13 @@ def create_qa_with_structure_chain(
|
|||||||
]
|
]
|
||||||
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
|
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type]
|
||||||
|
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=_output_parser,
|
output_parser=_output_parser,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
|
@ -91,14 +91,13 @@ def create_tagging_chain(
|
|||||||
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||||
output_parser = JsonOutputFunctionsParser()
|
output_parser = JsonOutputFunctionsParser()
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -164,11 +163,10 @@ def create_tagging_chain_pydantic(
|
|||||||
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||||
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
return LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return chain
|
|
||||||
|
@ -76,5 +76,4 @@ def create_extraction_chain_pydantic(
|
|||||||
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
|
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
|
||||||
tools = [{"type": "function", "function": d} for d in functions]
|
tools = [{"type": "function", "function": d} for d in functions]
|
||||||
model = llm.bind(tools=tools)
|
model = llm.bind(tools=tools)
|
||||||
chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)
|
return prompt | model | PydanticToolsParser(tools=pydantic_schemas)
|
||||||
return chain
|
|
||||||
|
@ -91,13 +91,12 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
filter_directive = cast(
|
filter_directive = cast(
|
||||||
Optional[FilterDirective], get_parser().parse(raw_filter)
|
Optional[FilterDirective], get_parser().parse(raw_filter)
|
||||||
)
|
)
|
||||||
fixed = fix_filter_directive(
|
return fix_filter_directive(
|
||||||
filter_directive,
|
filter_directive,
|
||||||
allowed_comparators=allowed_comparators,
|
allowed_comparators=allowed_comparators,
|
||||||
allowed_operators=allowed_operators,
|
allowed_operators=allowed_operators,
|
||||||
allowed_attributes=allowed_attributes,
|
allowed_attributes=allowed_attributes,
|
||||||
)
|
)
|
||||||
return fixed
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ast_parse = get_parser(
|
ast_parse = get_parser(
|
||||||
@ -131,13 +130,13 @@ def fix_filter_directive(
|
|||||||
) or not filter:
|
) or not filter:
|
||||||
return filter
|
return filter
|
||||||
|
|
||||||
elif isinstance(filter, Comparison):
|
if isinstance(filter, Comparison):
|
||||||
if allowed_comparators and filter.comparator not in allowed_comparators:
|
if allowed_comparators and filter.comparator not in allowed_comparators:
|
||||||
return None
|
return None
|
||||||
if allowed_attributes and filter.attribute not in allowed_attributes:
|
if allowed_attributes and filter.attribute not in allowed_attributes:
|
||||||
return None
|
return None
|
||||||
return filter
|
return filter
|
||||||
elif isinstance(filter, Operation):
|
if isinstance(filter, Operation):
|
||||||
if allowed_operators and filter.operator not in allowed_operators:
|
if allowed_operators and filter.operator not in allowed_operators:
|
||||||
return None
|
return None
|
||||||
args = [
|
args = [
|
||||||
@ -155,14 +154,12 @@ def fix_filter_directive(
|
|||||||
]
|
]
|
||||||
if not args:
|
if not args:
|
||||||
return None
|
return None
|
||||||
elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
|
if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
|
||||||
return args[0]
|
return args[0]
|
||||||
else:
|
|
||||||
return Operation(
|
return Operation(
|
||||||
operator=filter.operator,
|
operator=filter.operator,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return filter
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,9 +101,8 @@ class QueryTransformer(Transformer):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return Comparison(comparator=func, attribute=args[0], value=args[1])
|
return Comparison(comparator=func, attribute=args[0], value=args[1])
|
||||||
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
|
if len(args) == 1 and func in (Operator.AND, Operator.OR):
|
||||||
return args[0]
|
return args[0]
|
||||||
else:
|
|
||||||
return Operation(operator=func, arguments=args)
|
return Operation(operator=func, arguments=args)
|
||||||
|
|
||||||
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
||||||
@ -118,7 +117,7 @@ class QueryTransformer(Transformer):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return Comparator(func_name)
|
return Comparator(func_name)
|
||||||
elif func_name in set(Operator):
|
if func_name in set(Operator):
|
||||||
if (
|
if (
|
||||||
self.allowed_operators is not None
|
self.allowed_operators is not None
|
||||||
and func_name not in self.allowed_operators
|
and func_name not in self.allowed_operators
|
||||||
@ -129,7 +128,6 @@ class QueryTransformer(Transformer):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return Operator(func_name)
|
return Operator(func_name)
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Received unrecognized function {func_name}. Valid functions are "
|
f"Received unrecognized function {func_name}. Valid functions are "
|
||||||
f"{list(Operator) + list(Comparator)}"
|
f"{list(Operator) + list(Comparator)}"
|
||||||
|
@ -60,10 +60,8 @@ def create_retrieval_chain(
|
|||||||
else:
|
else:
|
||||||
retrieval_docs = (lambda x: x["input"]) | retriever
|
retrieval_docs = (lambda x: x["input"]) | retriever
|
||||||
|
|
||||||
retrieval_chain = (
|
return (
|
||||||
RunnablePassthrough.assign(
|
RunnablePassthrough.assign(
|
||||||
context=retrieval_docs.with_config(run_name="retrieve_documents"),
|
context=retrieval_docs.with_config(run_name="retrieve_documents"),
|
||||||
).assign(answer=combine_docs_chain)
|
).assign(answer=combine_docs_chain)
|
||||||
).with_config(run_name="retrieval_chain")
|
).with_config(run_name="retrieval_chain")
|
||||||
|
|
||||||
return retrieval_chain
|
|
||||||
|
@ -157,7 +157,6 @@ class BaseRetrievalQA(Chain):
|
|||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
return {self.output_key: answer, "source_documents": docs}
|
return {self.output_key: answer, "source_documents": docs}
|
||||||
else:
|
|
||||||
return {self.output_key: answer}
|
return {self.output_key: answer}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -200,7 +199,6 @@ class BaseRetrievalQA(Chain):
|
|||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
return {self.output_key: answer, "source_documents": docs}
|
return {self.output_key: answer, "source_documents": docs}
|
||||||
else:
|
|
||||||
return {self.output_key: answer}
|
return {self.output_key: answer}
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,13 +97,12 @@ class MultiRouteChain(Chain):
|
|||||||
)
|
)
|
||||||
if not route.destination:
|
if not route.destination:
|
||||||
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
||||||
elif route.destination in self.destination_chains:
|
if route.destination in self.destination_chains:
|
||||||
return self.destination_chains[route.destination](
|
return self.destination_chains[route.destination](
|
||||||
route.next_inputs, callbacks=callbacks
|
route.next_inputs, callbacks=callbacks
|
||||||
)
|
)
|
||||||
elif self.silent_errors:
|
if self.silent_errors:
|
||||||
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
||||||
else:
|
|
||||||
msg = f"Received invalid destination chain name '{route.destination}'"
|
msg = f"Received invalid destination chain name '{route.destination}'"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
@ -123,14 +122,13 @@ class MultiRouteChain(Chain):
|
|||||||
return await self.default_chain.acall(
|
return await self.default_chain.acall(
|
||||||
route.next_inputs, callbacks=callbacks
|
route.next_inputs, callbacks=callbacks
|
||||||
)
|
)
|
||||||
elif route.destination in self.destination_chains:
|
if route.destination in self.destination_chains:
|
||||||
return await self.destination_chains[route.destination].acall(
|
return await self.destination_chains[route.destination].acall(
|
||||||
route.next_inputs, callbacks=callbacks
|
route.next_inputs, callbacks=callbacks
|
||||||
)
|
)
|
||||||
elif self.silent_errors:
|
if self.silent_errors:
|
||||||
return await self.default_chain.acall(
|
return await self.default_chain.acall(
|
||||||
route.next_inputs, callbacks=callbacks
|
route.next_inputs, callbacks=callbacks
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
msg = f"Received invalid destination chain name '{route.destination}'"
|
msg = f"Received invalid destination chain name '{route.destination}'"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -136,11 +136,10 @@ class LLMRouterChain(RouterChain):
|
|||||||
callbacks = _run_manager.get_child()
|
callbacks = _run_manager.get_child()
|
||||||
|
|
||||||
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
||||||
output = cast(
|
return cast(
|
||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
self.llm_chain.prompt.output_parser.parse(prediction),
|
self.llm_chain.prompt.output_parser.parse(prediction),
|
||||||
)
|
)
|
||||||
return output
|
|
||||||
|
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
@ -149,11 +148,10 @@ class LLMRouterChain(RouterChain):
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
callbacks = _run_manager.get_child()
|
callbacks = _run_manager.get_child()
|
||||||
output = cast(
|
return cast(
|
||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
||||||
)
|
)
|
||||||
return output
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
|
@ -141,7 +141,6 @@ def create_sql_query_chain(
|
|||||||
f"{db.dialect}"
|
f"{db.dialect}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
|
||||||
table_info_kwargs["get_col_comments"] = True
|
table_info_kwargs["get_col_comments"] = True
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
|
@ -143,7 +143,6 @@ def create_openai_fn_runnable(
|
|||||||
output_parser = output_parser or get_openai_output_parser(functions)
|
output_parser = output_parser or get_openai_output_parser(functions)
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt | llm.bind(**llm_kwargs_) | output_parser
|
return prompt | llm.bind(**llm_kwargs_) | output_parser
|
||||||
else:
|
|
||||||
return llm.bind(**llm_kwargs_) | output_parser
|
return llm.bind(**llm_kwargs_) | output_parser
|
||||||
|
|
||||||
|
|
||||||
@ -413,7 +412,7 @@ def create_structured_output_runnable(
|
|||||||
first_tool_only=return_single,
|
first_tool_only=return_single,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif mode == "openai-functions":
|
if mode == "openai-functions":
|
||||||
return _create_openai_functions_structured_output_runnable(
|
return _create_openai_functions_structured_output_runnable(
|
||||||
output_schema,
|
output_schema,
|
||||||
llm,
|
llm,
|
||||||
@ -422,7 +421,7 @@ def create_structured_output_runnable(
|
|||||||
enforce_single_function_usage=force_function_usage,
|
enforce_single_function_usage=force_function_usage,
|
||||||
**kwargs, # llm-specific kwargs
|
**kwargs, # llm-specific kwargs
|
||||||
)
|
)
|
||||||
elif mode == "openai-json":
|
if mode == "openai-json":
|
||||||
if force_function_usage:
|
if force_function_usage:
|
||||||
msg = (
|
msg = (
|
||||||
"enforce_single_function_usage is not supported for mode='openai-json'."
|
"enforce_single_function_usage is not supported for mode='openai-json'."
|
||||||
@ -431,7 +430,6 @@ def create_structured_output_runnable(
|
|||||||
return _create_openai_json_runnable(
|
return _create_openai_json_runnable(
|
||||||
output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
|
output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
|
f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', "
|
||||||
f"'openai-json'."
|
f"'openai-json'."
|
||||||
@ -460,7 +458,6 @@ def _create_openai_tools_runnable(
|
|||||||
)
|
)
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt | llm.bind(**llm_kwargs) | output_parser
|
return prompt | llm.bind(**llm_kwargs) | output_parser
|
||||||
else:
|
|
||||||
return llm.bind(**llm_kwargs) | output_parser
|
return llm.bind(**llm_kwargs) | output_parser
|
||||||
|
|
||||||
|
|
||||||
@ -535,7 +532,6 @@ def _create_openai_json_runnable(
|
|||||||
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
|
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
|
||||||
|
|
||||||
return prompt | llm | output_parser
|
return prompt | llm | output_parser
|
||||||
else:
|
|
||||||
return llm | output_parser
|
return llm | output_parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +77,6 @@ class TransformChain(Chain):
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if self.atransform_cb is not None:
|
if self.atransform_cb is not None:
|
||||||
return await self.atransform_cb(inputs)
|
return await self.atransform_cb(inputs)
|
||||||
else:
|
|
||||||
self._log_once(
|
self._log_once(
|
||||||
"TransformChain's atransform is not provided, falling"
|
"TransformChain's atransform is not provided, falling"
|
||||||
" back to synchronous transform"
|
" back to synchronous transform"
|
||||||
|
@ -322,7 +322,6 @@ def init_chat_model(
|
|||||||
return _init_chat_model_helper(
|
return _init_chat_model_helper(
|
||||||
cast(str, model), model_provider=model_provider, **kwargs
|
cast(str, model), model_provider=model_provider, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
if model:
|
if model:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
if model_provider:
|
if model_provider:
|
||||||
@ -343,42 +342,42 @@ def _init_chat_model_helper(
|
|||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
return ChatOpenAI(model=model, **kwargs)
|
return ChatOpenAI(model=model, **kwargs)
|
||||||
elif model_provider == "anthropic":
|
if model_provider == "anthropic":
|
||||||
_check_pkg("langchain_anthropic")
|
_check_pkg("langchain_anthropic")
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||||
elif model_provider == "azure_openai":
|
if model_provider == "azure_openai":
|
||||||
_check_pkg("langchain_openai")
|
_check_pkg("langchain_openai")
|
||||||
from langchain_openai import AzureChatOpenAI
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
return AzureChatOpenAI(model=model, **kwargs)
|
return AzureChatOpenAI(model=model, **kwargs)
|
||||||
elif model_provider == "azure_ai":
|
if model_provider == "azure_ai":
|
||||||
_check_pkg("langchain_azure_ai")
|
_check_pkg("langchain_azure_ai")
|
||||||
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
||||||
|
|
||||||
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
||||||
elif model_provider == "cohere":
|
if model_provider == "cohere":
|
||||||
_check_pkg("langchain_cohere")
|
_check_pkg("langchain_cohere")
|
||||||
from langchain_cohere import ChatCohere
|
from langchain_cohere import ChatCohere
|
||||||
|
|
||||||
return ChatCohere(model=model, **kwargs)
|
return ChatCohere(model=model, **kwargs)
|
||||||
elif model_provider == "google_vertexai":
|
if model_provider == "google_vertexai":
|
||||||
_check_pkg("langchain_google_vertexai")
|
_check_pkg("langchain_google_vertexai")
|
||||||
from langchain_google_vertexai import ChatVertexAI
|
from langchain_google_vertexai import ChatVertexAI
|
||||||
|
|
||||||
return ChatVertexAI(model=model, **kwargs)
|
return ChatVertexAI(model=model, **kwargs)
|
||||||
elif model_provider == "google_genai":
|
if model_provider == "google_genai":
|
||||||
_check_pkg("langchain_google_genai")
|
_check_pkg("langchain_google_genai")
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
||||||
elif model_provider == "fireworks":
|
if model_provider == "fireworks":
|
||||||
_check_pkg("langchain_fireworks")
|
_check_pkg("langchain_fireworks")
|
||||||
from langchain_fireworks import ChatFireworks
|
from langchain_fireworks import ChatFireworks
|
||||||
|
|
||||||
return ChatFireworks(model=model, **kwargs)
|
return ChatFireworks(model=model, **kwargs)
|
||||||
elif model_provider == "ollama":
|
if model_provider == "ollama":
|
||||||
try:
|
try:
|
||||||
_check_pkg("langchain_ollama")
|
_check_pkg("langchain_ollama")
|
||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
@ -393,72 +392,70 @@ def _init_chat_model_helper(
|
|||||||
_check_pkg("langchain_ollama")
|
_check_pkg("langchain_ollama")
|
||||||
|
|
||||||
return ChatOllama(model=model, **kwargs)
|
return ChatOllama(model=model, **kwargs)
|
||||||
elif model_provider == "together":
|
if model_provider == "together":
|
||||||
_check_pkg("langchain_together")
|
_check_pkg("langchain_together")
|
||||||
from langchain_together import ChatTogether
|
from langchain_together import ChatTogether
|
||||||
|
|
||||||
return ChatTogether(model=model, **kwargs)
|
return ChatTogether(model=model, **kwargs)
|
||||||
elif model_provider == "mistralai":
|
if model_provider == "mistralai":
|
||||||
_check_pkg("langchain_mistralai")
|
_check_pkg("langchain_mistralai")
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
|
|
||||||
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||||
elif model_provider == "huggingface":
|
if model_provider == "huggingface":
|
||||||
_check_pkg("langchain_huggingface")
|
_check_pkg("langchain_huggingface")
|
||||||
from langchain_huggingface import ChatHuggingFace
|
from langchain_huggingface import ChatHuggingFace
|
||||||
|
|
||||||
return ChatHuggingFace(model_id=model, **kwargs)
|
return ChatHuggingFace(model_id=model, **kwargs)
|
||||||
elif model_provider == "groq":
|
if model_provider == "groq":
|
||||||
_check_pkg("langchain_groq")
|
_check_pkg("langchain_groq")
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
|
|
||||||
return ChatGroq(model=model, **kwargs)
|
return ChatGroq(model=model, **kwargs)
|
||||||
elif model_provider == "bedrock":
|
if model_provider == "bedrock":
|
||||||
_check_pkg("langchain_aws")
|
_check_pkg("langchain_aws")
|
||||||
from langchain_aws import ChatBedrock
|
from langchain_aws import ChatBedrock
|
||||||
|
|
||||||
# TODO: update to use model= once ChatBedrock supports
|
# TODO: update to use model= once ChatBedrock supports
|
||||||
return ChatBedrock(model_id=model, **kwargs)
|
return ChatBedrock(model_id=model, **kwargs)
|
||||||
elif model_provider == "bedrock_converse":
|
if model_provider == "bedrock_converse":
|
||||||
_check_pkg("langchain_aws")
|
_check_pkg("langchain_aws")
|
||||||
from langchain_aws import ChatBedrockConverse
|
from langchain_aws import ChatBedrockConverse
|
||||||
|
|
||||||
return ChatBedrockConverse(model=model, **kwargs)
|
return ChatBedrockConverse(model=model, **kwargs)
|
||||||
elif model_provider == "google_anthropic_vertex":
|
if model_provider == "google_anthropic_vertex":
|
||||||
_check_pkg("langchain_google_vertexai")
|
_check_pkg("langchain_google_vertexai")
|
||||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||||
|
|
||||||
return ChatAnthropicVertex(model=model, **kwargs)
|
return ChatAnthropicVertex(model=model, **kwargs)
|
||||||
elif model_provider == "deepseek":
|
if model_provider == "deepseek":
|
||||||
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
return ChatDeepSeek(model=model, **kwargs)
|
return ChatDeepSeek(model=model, **kwargs)
|
||||||
elif model_provider == "nvidia":
|
if model_provider == "nvidia":
|
||||||
_check_pkg("langchain_nvidia_ai_endpoints")
|
_check_pkg("langchain_nvidia_ai_endpoints")
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
|
|
||||||
return ChatNVIDIA(model=model, **kwargs)
|
return ChatNVIDIA(model=model, **kwargs)
|
||||||
elif model_provider == "ibm":
|
if model_provider == "ibm":
|
||||||
_check_pkg("langchain_ibm")
|
_check_pkg("langchain_ibm")
|
||||||
from langchain_ibm import ChatWatsonx
|
from langchain_ibm import ChatWatsonx
|
||||||
|
|
||||||
return ChatWatsonx(model_id=model, **kwargs)
|
return ChatWatsonx(model_id=model, **kwargs)
|
||||||
elif model_provider == "xai":
|
if model_provider == "xai":
|
||||||
_check_pkg("langchain_xai")
|
_check_pkg("langchain_xai")
|
||||||
from langchain_xai import ChatXAI
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
return ChatXAI(model=model, **kwargs)
|
return ChatXAI(model=model, **kwargs)
|
||||||
elif model_provider == "perplexity":
|
if model_provider == "perplexity":
|
||||||
_check_pkg("langchain_perplexity")
|
_check_pkg("langchain_perplexity")
|
||||||
from langchain_perplexity import ChatPerplexity
|
from langchain_perplexity import ChatPerplexity
|
||||||
|
|
||||||
return ChatPerplexity(model=model, **kwargs)
|
return ChatPerplexity(model=model, **kwargs)
|
||||||
else:
|
|
||||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
||||||
msg = (
|
msg = (
|
||||||
f"Unsupported {model_provider=}.\n\nSupported model providers are: "
|
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||||
f"{supported}"
|
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
@ -490,25 +487,24 @@ _SUPPORTED_PROVIDERS = {
|
|||||||
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
||||||
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
|
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
|
||||||
return "openai"
|
return "openai"
|
||||||
elif model_name.startswith("claude"):
|
if model_name.startswith("claude"):
|
||||||
return "anthropic"
|
return "anthropic"
|
||||||
elif model_name.startswith("command"):
|
if model_name.startswith("command"):
|
||||||
return "cohere"
|
return "cohere"
|
||||||
elif model_name.startswith("accounts/fireworks"):
|
if model_name.startswith("accounts/fireworks"):
|
||||||
return "fireworks"
|
return "fireworks"
|
||||||
elif model_name.startswith("gemini"):
|
if model_name.startswith("gemini"):
|
||||||
return "google_vertexai"
|
return "google_vertexai"
|
||||||
elif model_name.startswith("amazon."):
|
if model_name.startswith("amazon."):
|
||||||
return "bedrock"
|
return "bedrock"
|
||||||
elif model_name.startswith("mistral"):
|
if model_name.startswith("mistral"):
|
||||||
return "mistralai"
|
return "mistralai"
|
||||||
elif model_name.startswith("deepseek"):
|
if model_name.startswith("deepseek"):
|
||||||
return "deepseek"
|
return "deepseek"
|
||||||
elif model_name.startswith("grok"):
|
if model_name.startswith("grok"):
|
||||||
return "xai"
|
return "xai"
|
||||||
elif model_name.startswith("sonar"):
|
if model_name.startswith("sonar"):
|
||||||
return "perplexity"
|
return "perplexity"
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -595,9 +591,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return queue
|
return queue
|
||||||
elif self._default_config and (model := self._model()) and hasattr(model, name):
|
if self._default_config and (model := self._model()) and hasattr(model, name):
|
||||||
return getattr(model, name)
|
return getattr(model, name)
|
||||||
else:
|
|
||||||
msg = f"{name} is not a BaseChatModel attribute"
|
msg = f"{name} is not a BaseChatModel attribute"
|
||||||
if self._default_config:
|
if self._default_config:
|
||||||
msg += " and is not implemented on the default model"
|
msg += " and is not implemented on the default model"
|
||||||
@ -728,7 +723,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|||||||
)
|
)
|
||||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||||
# in parallel.
|
# in parallel.
|
||||||
else:
|
|
||||||
return super().batch(
|
return super().batch(
|
||||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||||
)
|
)
|
||||||
@ -751,7 +745,6 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|||||||
)
|
)
|
||||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||||
# in parallel.
|
# in parallel.
|
||||||
else:
|
|
||||||
return await super().abatch(
|
return await super().abatch(
|
||||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -192,35 +192,34 @@ def init_embeddings(
|
|||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
|
||||||
return OpenAIEmbeddings(model=model_name, **kwargs)
|
return OpenAIEmbeddings(model=model_name, **kwargs)
|
||||||
elif provider == "azure_openai":
|
if provider == "azure_openai":
|
||||||
from langchain_openai import AzureOpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings
|
||||||
|
|
||||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
||||||
elif provider == "google_vertexai":
|
if provider == "google_vertexai":
|
||||||
from langchain_google_vertexai import VertexAIEmbeddings
|
from langchain_google_vertexai import VertexAIEmbeddings
|
||||||
|
|
||||||
return VertexAIEmbeddings(model=model_name, **kwargs)
|
return VertexAIEmbeddings(model=model_name, **kwargs)
|
||||||
elif provider == "bedrock":
|
if provider == "bedrock":
|
||||||
from langchain_aws import BedrockEmbeddings
|
from langchain_aws import BedrockEmbeddings
|
||||||
|
|
||||||
return BedrockEmbeddings(model_id=model_name, **kwargs)
|
return BedrockEmbeddings(model_id=model_name, **kwargs)
|
||||||
elif provider == "cohere":
|
if provider == "cohere":
|
||||||
from langchain_cohere import CohereEmbeddings
|
from langchain_cohere import CohereEmbeddings
|
||||||
|
|
||||||
return CohereEmbeddings(model=model_name, **kwargs)
|
return CohereEmbeddings(model=model_name, **kwargs)
|
||||||
elif provider == "mistralai":
|
if provider == "mistralai":
|
||||||
from langchain_mistralai import MistralAIEmbeddings
|
from langchain_mistralai import MistralAIEmbeddings
|
||||||
|
|
||||||
return MistralAIEmbeddings(model=model_name, **kwargs)
|
return MistralAIEmbeddings(model=model_name, **kwargs)
|
||||||
elif provider == "huggingface":
|
if provider == "huggingface":
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
||||||
elif provider == "ollama":
|
if provider == "ollama":
|
||||||
from langchain_ollama import OllamaEmbeddings
|
from langchain_ollama import OllamaEmbeddings
|
||||||
|
|
||||||
return OllamaEmbeddings(model=model_name, **kwargs)
|
return OllamaEmbeddings(model=model_name, **kwargs)
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Provider '{provider}' is not supported.\n"
|
f"Provider '{provider}' is not supported.\n"
|
||||||
f"Supported providers and their required packages:\n"
|
f"Supported providers and their required packages:\n"
|
||||||
|
@ -70,7 +70,7 @@ def resolve_pairwise_criteria(
|
|||||||
Criteria.DEPTH,
|
Criteria.DEPTH,
|
||||||
]
|
]
|
||||||
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
|
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
|
||||||
elif isinstance(criteria, Criteria):
|
if isinstance(criteria, Criteria):
|
||||||
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
|
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
|
||||||
elif isinstance(criteria, str):
|
elif isinstance(criteria, str):
|
||||||
if criteria in _SUPPORTED_CRITERIA:
|
if criteria in _SUPPORTED_CRITERIA:
|
||||||
|
@ -186,7 +186,6 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|||||||
}
|
}
|
||||||
if metric in metrics:
|
if metric in metrics:
|
||||||
return metrics[metric]
|
return metrics[metric]
|
||||||
else:
|
|
||||||
msg = f"Invalid metric: {metric}"
|
msg = f"Invalid metric: {metric}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@ -162,7 +162,6 @@ def load_evaluator(
|
|||||||
)
|
)
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e
|
||||||
return evaluator_cls.from_llm(llm=llm, **kwargs)
|
return evaluator_cls.from_llm(llm=llm, **kwargs)
|
||||||
else:
|
|
||||||
return evaluator_cls(**kwargs)
|
return evaluator_cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class JsonSchemaEvaluator(StringEvaluator):
|
|||||||
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
|
def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
|
||||||
if isinstance(node, str):
|
if isinstance(node, str):
|
||||||
return parse_json_markdown(node)
|
return parse_json_markdown(node)
|
||||||
elif hasattr(node, "schema") and callable(getattr(node, "schema")):
|
if hasattr(node, "schema") and callable(getattr(node, "schema")):
|
||||||
# Pydantic model
|
# Pydantic model
|
||||||
return getattr(node, "schema")()
|
return getattr(node, "schema")()
|
||||||
return node
|
return node
|
||||||
|
@ -24,7 +24,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
|
|||||||
if match:
|
if match:
|
||||||
if match.group(1).upper() == "CORRECT":
|
if match.group(1).upper() == "CORRECT":
|
||||||
return "CORRECT", 1
|
return "CORRECT", 1
|
||||||
elif match.group(1).upper() == "INCORRECT":
|
if match.group(1).upper() == "INCORRECT":
|
||||||
return "INCORRECT", 0
|
return "INCORRECT", 0
|
||||||
try:
|
try:
|
||||||
first_word = (
|
first_word = (
|
||||||
@ -32,7 +32,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
|
|||||||
)
|
)
|
||||||
if first_word.upper() == "CORRECT":
|
if first_word.upper() == "CORRECT":
|
||||||
return "CORRECT", 1
|
return "CORRECT", 1
|
||||||
elif first_word.upper() == "INCORRECT":
|
if first_word.upper() == "INCORRECT":
|
||||||
return "INCORRECT", 0
|
return "INCORRECT", 0
|
||||||
last_word = (
|
last_word = (
|
||||||
text.strip()
|
text.strip()
|
||||||
@ -41,7 +41,7 @@ def _get_score(text: str) -> Optional[tuple[str, int]]:
|
|||||||
)
|
)
|
||||||
if last_word.upper() == "CORRECT":
|
if last_word.upper() == "CORRECT":
|
||||||
return "CORRECT", 1
|
return "CORRECT", 1
|
||||||
elif last_word.upper() == "INCORRECT":
|
if last_word.upper() == "INCORRECT":
|
||||||
return "INCORRECT", 0
|
return "INCORRECT", 0
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
|
@ -123,12 +123,12 @@ class _EvalArgsMixin:
|
|||||||
if self.requires_input and input is None:
|
if self.requires_input and input is None:
|
||||||
msg = f"{self.__class__.__name__} requires an input string."
|
msg = f"{self.__class__.__name__} requires an input string."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif input is not None and not self.requires_input:
|
if input is not None and not self.requires_input:
|
||||||
warn(self._skip_input_warning)
|
warn(self._skip_input_warning)
|
||||||
if self.requires_reference and reference is None:
|
if self.requires_reference and reference is None:
|
||||||
msg = f"{self.__class__.__name__} requires a reference string."
|
msg = f"{self.__class__.__name__} requires a reference string."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif reference is not None and not self.requires_reference:
|
if reference is not None and not self.requires_reference:
|
||||||
warn(self._skip_reference_warning)
|
warn(self._skip_reference_warning)
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ def resolve_criteria(
|
|||||||
Criteria.DEPTH,
|
Criteria.DEPTH,
|
||||||
]
|
]
|
||||||
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
|
return {k.value: _SUPPORTED_CRITERIA[k] for k in _default_criteria}
|
||||||
elif isinstance(criteria, Criteria):
|
if isinstance(criteria, Criteria):
|
||||||
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
|
criteria_ = {criteria.value: _SUPPORTED_CRITERIA[criteria]}
|
||||||
elif isinstance(criteria, str):
|
elif isinstance(criteria, str):
|
||||||
if criteria in _SUPPORTED_CRITERIA:
|
if criteria in _SUPPORTED_CRITERIA:
|
||||||
|
@ -138,7 +138,6 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
module = module_map[distance]
|
module = module_map[distance]
|
||||||
if normalize_score:
|
if normalize_score:
|
||||||
return module.normalized_distance
|
return module.normalized_distance
|
||||||
else:
|
|
||||||
return module.distance
|
return module.distance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -21,7 +21,6 @@ def _get_client(
|
|||||||
ls_client = LangSmithClient(api_url, api_key=api_key)
|
ls_client = LangSmithClient(api_url, api_key=api_key)
|
||||||
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
|
if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
|
||||||
return ls_client
|
return ls_client
|
||||||
else:
|
|
||||||
from langchainhub import Client as LangChainHubClient
|
from langchainhub import Client as LangChainHubClient
|
||||||
|
|
||||||
return LangChainHubClient(api_url, api_key=api_key)
|
return LangChainHubClient(api_url, api_key=api_key)
|
||||||
@ -82,14 +81,13 @@ def push(
|
|||||||
|
|
||||||
# Then it's langchainhub
|
# Then it's langchainhub
|
||||||
manifest_json = dumps(object)
|
manifest_json = dumps(object)
|
||||||
message = client.push(
|
return client.push(
|
||||||
repo_full_name,
|
repo_full_name,
|
||||||
manifest_json,
|
manifest_json,
|
||||||
parent_commit_hash=parent_commit_hash,
|
parent_commit_hash=parent_commit_hash,
|
||||||
new_repo_is_public=new_repo_is_public,
|
new_repo_is_public=new_repo_is_public,
|
||||||
new_repo_description=new_repo_description,
|
new_repo_description=new_repo_description,
|
||||||
)
|
)
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def pull(
|
def pull(
|
||||||
@ -113,8 +111,7 @@ def pull(
|
|||||||
|
|
||||||
# Then it's langsmith
|
# Then it's langsmith
|
||||||
if hasattr(client, "pull_prompt"):
|
if hasattr(client, "pull_prompt"):
|
||||||
response = client.pull_prompt(owner_repo_commit, include_model=include_model)
|
return client.pull_prompt(owner_repo_commit, include_model=include_model)
|
||||||
return response
|
|
||||||
|
|
||||||
# Then it's langchainhub
|
# Then it's langchainhub
|
||||||
if hasattr(client, "pull_repo"):
|
if hasattr(client, "pull_repo"):
|
||||||
|
@ -561,7 +561,6 @@ def __getattr__(name: str) -> Any:
|
|||||||
k: v() for k, v in get_type_to_cls_dict().items()
|
k: v() for k, v in get_type_to_cls_dict().items()
|
||||||
}
|
}
|
||||||
return type_to_cls_dict
|
return type_to_cls_dict
|
||||||
else:
|
|
||||||
return getattr(llms, name)
|
return getattr(llms, name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,6 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None:
|
||||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||||
@ -255,6 +256,7 @@ class RedisEntityStore(BaseEntityStore):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None:
|
||||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||||
@ -362,6 +364,7 @@ class SQLiteEntityStore(BaseEntityStore):
|
|||||||
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
|
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
|
||||||
)
|
)
|
||||||
self._execute_query(query, (key, value))
|
self._execute_query(query, (key, value))
|
||||||
|
return None
|
||||||
|
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None:
|
||||||
"""Deletes a key-value pair, safely quoting the table name."""
|
"""Deletes a key-value pair, safely quoting the table name."""
|
||||||
|
@ -34,7 +34,7 @@ class BooleanOutputParser(BaseOutputParser[bool]):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return True
|
return True
|
||||||
elif self.false_val.upper() in truthy:
|
if self.false_val.upper() in truthy:
|
||||||
if self.true_val.upper() in truthy:
|
if self.true_val.upper() in truthy:
|
||||||
msg = (
|
msg = (
|
||||||
f"Ambiguous response. Both {self.true_val} and {self.false_val} "
|
f"Ambiguous response. Both {self.true_val} and {self.false_val} "
|
||||||
|
@ -71,7 +71,6 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise e
|
||||||
else:
|
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||||
completion = self.retry_chain.run(
|
completion = self.retry_chain.run(
|
||||||
@ -109,7 +108,6 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if retries == self.max_retries:
|
if retries == self.max_retries:
|
||||||
raise e
|
raise e
|
||||||
else:
|
|
||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||||
completion = await self.retry_chain.arun(
|
completion = await self.retry_chain.arun(
|
||||||
|
@ -64,7 +64,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
msg = f"Invalid array format in '{original_request_params}'. \
|
msg = f"Invalid array format in '{original_request_params}'. \
|
||||||
Please check the format instructions."
|
Please check the format instructions."
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
elif (
|
if (
|
||||||
isinstance(parsed_array[0], int)
|
isinstance(parsed_array[0], int)
|
||||||
and parsed_array[-1] > self.dataframe.index.max()
|
and parsed_array[-1] > self.dataframe.index.max()
|
||||||
):
|
):
|
||||||
|
@ -30,11 +30,9 @@ class RegexParser(BaseOutputParser[dict[str, str]]):
|
|||||||
match = re.search(self.regex, text)
|
match = re.search(self.regex, text)
|
||||||
if match:
|
if match:
|
||||||
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
|
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
|
||||||
else:
|
|
||||||
if self.default_output_key is None:
|
if self.default_output_key is None:
|
||||||
msg = f"Could not parse output: {text}"
|
msg = f"Could not parse output: {text}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
|
||||||
return {
|
return {
|
||||||
key: text if key == self.default_output_key else ""
|
key: text if key == self.default_output_key else ""
|
||||||
for key in self.output_keys
|
for key in self.output_keys
|
||||||
|
@ -33,14 +33,11 @@ class RegexDictParser(BaseOutputParser[dict[str, str]]):
|
|||||||
{expected_format} on text {text}"
|
{expected_format} on text {text}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif len(matches) > 1:
|
if len(matches) > 1:
|
||||||
msg = f"Multiple matches found for output key: {output_key} with \
|
msg = f"Multiple matches found for output key: {output_key} with \
|
||||||
expected format {expected_format} on text {text}"
|
expected format {expected_format} on text {text}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif (
|
if self.no_update_value is not None and matches[0] == self.no_update_value:
|
||||||
self.no_update_value is not None and matches[0] == self.no_update_value
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
else:
|
|
||||||
result[output_key] = matches[0]
|
result[output_key] = matches[0]
|
||||||
return result
|
return result
|
||||||
@@ -107,7 +107,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
-else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(

@@ -143,7 +142,6 @@ class RetryOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
-else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

@@ -234,7 +232,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
-else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(

@@ -263,7 +260,6 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
except OutputParserException as e:
if retries == self.max_retries:
raise e
-else:
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(

@@ -89,7 +89,6 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
)
if only_json:
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
-else:
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)

def parse(self, text: str) -> dict[str, Any]:

@@ -33,7 +33,6 @@ class YamlOutputParser(BaseOutputParser[T]):
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(json_object)
-else:
return self.pydantic_object.parse_obj(json_object)

except (yaml.YAMLError, ValidationError) as e:
@@ -45,7 +45,6 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
-else:
return []

async def _aget_relevant_documents(

@@ -71,5 +70,4 @@ class ContextualCompressionRetriever(BaseRetriever):
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
-else:
return []
@@ -174,9 +174,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
-fused_documents = self.rank_fusion(query, run_manager)
-
-return fused_documents
+return self.rank_fusion(query, run_manager)

async def _aget_relevant_documents(
self,

@@ -195,9 +193,7 @@ class EnsembleRetriever(BaseRetriever):
"""

# Get fused result of the retrievers.
-fused_documents = await self.arank_fusion(query, run_manager)
-
-return fused_documents
+return await self.arank_fusion(query, run_manager)

def rank_fusion(
self,

@@ -236,9 +232,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
-fused_documents = self.weighted_reciprocal_rank(retriever_docs)
-
-return fused_documents
+return self.weighted_reciprocal_rank(retriever_docs)

async def arank_fusion(
self,

@@ -280,9 +274,7 @@ class EnsembleRetriever(BaseRetriever):
]

# apply rank fusion
-fused_documents = self.weighted_reciprocal_rank(retriever_docs)
-
-return fused_documents
+return self.weighted_reciprocal_rank(retriever_docs)

def weighted_reciprocal_rank(
self, doc_lists: list[list[Document]]

@@ -318,7 +310,7 @@ class EnsembleRetriever(BaseRetriever):

# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
-sorted_docs = sorted(
+return sorted(
unique_by_key(
all_docs,
lambda doc: (

@@ -332,4 +324,3 @@ class EnsembleRetriever(BaseRetriever):
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
],
)
-return sorted_docs
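The EnsembleRetriever hunks above apply a different rule, RET504 (unnecessary assignment before `return`): a local that is assigned and immediately returned is returned directly. A minimal sketch with invented names:

def ranked_before(docs, score_fn):
    ranked = sorted(docs, key=score_fn, reverse=True)
    return ranked  # flagged: the temporary adds nothing

def ranked_after(docs, score_fn):
    return sorted(docs, key=score_fn, reverse=True)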
@@ -31,9 +31,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
-merged_documents = self.merge_documents(query, run_manager)
-
-return merged_documents
+return self.merge_documents(query, run_manager)

async def _aget_relevant_documents(
self,

@@ -52,9 +50,7 @@ class MergerRetriever(BaseRetriever):
"""

# Merge the results of the retrievers.
-merged_documents = await self.amerge_documents(query, run_manager)
-
-return merged_documents
+return await self.amerge_documents(query, run_manager)

def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun

@@ -74,10 +74,9 @@ class RePhraseQueryRetriever(BaseRetriever):
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
-docs = self.retriever.invoke(
+return self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
)
-return docs

async def _aget_relevant_documents(
self,
@@ -119,18 +119,17 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
}
if isinstance(vectorstore, DatabricksVectorSearch):
return DatabricksVectorSearchTranslator()
-elif isinstance(vectorstore, MyScale):
+if isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
-elif isinstance(vectorstore, Redis):
+if isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
-elif isinstance(vectorstore, TencentVectorDB):
+if isinstance(vectorstore, TencentVectorDB):
fields = [
field.name for field in (vectorstore.meta_fields or []) if field.index
]
return TencentVectorDBTranslator(fields)
-elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
+if vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
-else:
try:
from langchain_astradb.vectorstores import AstraDBVectorStore
except ImportError:
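The `_get_builtin_translator` hunk shows RET505 applied to a dispatch chain: because every branch returns, each `elif` can become a plain `if` and the trailing `else:` can be flattened without changing behaviour. An illustrative sketch; the function, types, and strings are invented for the example:

def type_name_before(value):
    if isinstance(value, bool):
        return "bool"
    elif isinstance(value, int):  # flagged: the previous branch already returned
        return "int"
    elif isinstance(value, str):
        return "str"
    else:
        return "other"

def type_name_after(value):
    if isinstance(value, bool):
        return "bool"
    if isinstance(value, int):
        return "int"
    if isinstance(value, str):
        return "str"
    return "other"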
@ -289,14 +288,12 @@ class SelfQueryRetriever(BaseRetriever):
|
|||||||
def _get_docs_with_query(
|
def _get_docs_with_query(
|
||||||
self, query: str, search_kwargs: dict[str, Any]
|
self, query: str, search_kwargs: dict[str, Any]
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
return self.vectorstore.search(query, self.search_type, **search_kwargs)
|
||||||
return docs
|
|
||||||
|
|
||||||
async def _aget_docs_with_query(
|
async def _aget_docs_with_query(
|
||||||
self, query: str, search_kwargs: dict[str, Any]
|
self, query: str, search_kwargs: dict[str, Any]
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
|
return await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
|
||||||
return docs
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
@ -315,8 +312,7 @@ class SelfQueryRetriever(BaseRetriever):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated Query: {structured_query}")
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
docs = self._get_docs_with_query(new_query, search_kwargs)
|
return self._get_docs_with_query(new_query, search_kwargs)
|
||||||
return docs
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
@ -335,8 +331,7 @@ class SelfQueryRetriever(BaseRetriever):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated Query: {structured_query}")
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
docs = await self._aget_docs_with_query(new_query, search_kwargs)
|
return await self._aget_docs_with_query(new_query, search_kwargs)
|
||||||
return docs
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
|
@@ -191,13 +191,13 @@ def _wrap_in_chain_factory(
)
raise ValueError(msg)
return lambda: chain
-elif isinstance(llm_or_chain_factory, BaseLanguageModel):
+if isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
-elif isinstance(llm_or_chain_factory, Runnable):
+if isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
-elif callable(llm_or_chain_factory):
+if callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_

@@ -215,13 +215,12 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
-elif is_traceable_function(cast(Callable, _model)):
+if is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_
-elif not isinstance(_model, Runnable):
+if not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
-else:
# Typical correct case
return constructor
return llm_or_chain_factory
@ -272,7 +271,6 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
|
|||||||
raise InputFormatError(msg)
|
raise InputFormatError(msg)
|
||||||
if len(prompts) == 1:
|
if len(prompts) == 1:
|
||||||
return prompts[0]
|
return prompts[0]
|
||||||
else:
|
|
||||||
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
|
msg = f"LLM Run expects single prompt input. Got {len(prompts)} prompts."
|
||||||
raise InputFormatError(msg)
|
raise InputFormatError(msg)
|
||||||
|
|
||||||
@ -321,7 +319,6 @@ def _get_messages(inputs: dict[str, Any]) -> dict:
|
|||||||
)
|
)
|
||||||
raise InputFormatError(msg)
|
raise InputFormatError(msg)
|
||||||
return input_copy
|
return input_copy
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
|
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
|
||||||
f" input. Got {inputs}"
|
f" input. Got {inputs}"
|
||||||
@ -707,7 +704,6 @@ async def _arun_llm(
|
|||||||
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
|
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
msg = (
|
msg = (
|
||||||
"Input mapper returned invalid format"
|
"Input mapper returned invalid format"
|
||||||
f" {prompt_or_messages}"
|
f" {prompt_or_messages}"
|
||||||
@ -715,7 +711,6 @@ async def _arun_llm(
|
|||||||
)
|
)
|
||||||
raise InputFormatError(msg)
|
raise InputFormatError(msg)
|
||||||
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
prompt = _get_prompt(inputs)
|
prompt = _get_prompt(inputs)
|
||||||
llm_output: Union[str, BaseMessage] = await llm.ainvoke(
|
llm_output: Union[str, BaseMessage] = await llm.ainvoke(
|
||||||
|
@@ -28,7 +28,6 @@ def _get_messages_from_run_dict(messages: list[dict]) -> list[BaseMessage]:
first_message = messages[0]
if "lc" in first_message:
return [load(dumpd(message)) for message in messages]
-else:
return messages_from_dict(messages)

@@ -106,14 +105,12 @@ class LLMStringRunMapper(StringRunMapper):
if run.run_type != "llm":
msg = "LLM RunMapper only supports LLM runs."
raise ValueError(msg)
-elif not run.outputs:
+if not run.outputs:
if run.error:
msg = f"Cannot evaluate errored LLM run {run.id}: {run.error}"
raise ValueError(msg)
-else:
msg = f"Run {run.id} has no outputs. Cannot evaluate this run."
raise ValueError(msg)
-else:
try:
inputs = self.serialize_inputs(run.inputs)
except Exception as e:

@@ -142,9 +139,8 @@ class ChainStringRunMapper(StringRunMapper):
def _get_key(self, source: dict, key: Optional[str], which: str) -> str:
if key is not None:
return source[key]
-elif len(source) == 1:
+if len(source) == 1:
return next(iter(source.values()))
-else:
msg = (
f"Could not map run {which} with multiple keys: "
f"{source}\nPlease manually specify a {which}_key"

@@ -168,7 +164,7 @@ class ChainStringRunMapper(StringRunMapper):
f" '{self.input_key}'."
)
raise ValueError(msg)
-elif self.prediction_key is not None and self.prediction_key not in run.outputs:
+if self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys())
msg = (
f"Run with ID {run.id} doesn't have the expected prediction key"

@@ -178,7 +174,6 @@ class ChainStringRunMapper(StringRunMapper):
)
raise ValueError(msg)

-else:
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
return {

@@ -224,7 +219,6 @@ class StringExampleMapper(Serializable):
" specify a reference_key."
)
raise ValueError(msg)
-else:
output = list(example.outputs.values())[0]
elif self.reference_key not in example.outputs:
msg = (
@@ -65,9 +65,8 @@ def _import_python_tool_PythonREPLTool() -> Any:
def __getattr__(name: str) -> Any:
if name == "PythonAstREPLTool":
return _import_python_tool_PythonAstREPLTool()
-elif name == "PythonREPLTool":
+if name == "PythonREPLTool":
return _import_python_tool_PythonREPLTool()
-else:
from langchain_community import tools

# If not in interactive env, raise warning.
@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"

[tool.ruff.lint]
-select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
+select = ["A", "C4", "D", "E", "EM", "F", "I", "PGH003", "PIE", "RET", "S", "SIM", "T201", "UP", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
@@ -139,7 +139,6 @@ async def get_state(
async def ask_for_passphrase(said_please: bool) -> dict[str, Any]:
if said_please:
return {"passphrase": f"The passphrase is {PASS_PHRASE}"}
-else:
return {"passphrase": "I won't share the passphrase without saying 'please'."}

@@ -153,7 +152,6 @@ async def recycle(password: SecretPassPhrase) -> dict[str, Any]:
if password.pw == PASS_PHRASE:
_ROBOT_STATE["destruct"] = True
return {"status": "Self-destruct initiated", "state": _ROBOT_STATE}
-else:
_ROBOT_STATE["destruct"] = False
raise HTTPException(
status_code=400,
@@ -100,14 +100,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

-agent = initialize_agent(
+return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
-return agent

def test_agent_bad_action() -> None:

@@ -71,14 +71,13 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
),
]

-agent = initialize_agent(
+return initialize_agent(
tools,
fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
**kwargs,
)
-return agent

async def test_agent_bad_action() -> None:
@@ -11,7 +11,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
-else:
return "Final Answer", output.return_values["output"]

@@ -16,7 +16,6 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
-else:
return "Final Answer", output.return_values["output"]

@@ -21,9 +21,8 @@ def get_action_and_input(text: str) -> tuple[str, str]:
output = output_parser.parse(text)
if isinstance(output, AgentAction):
return output.tool, str(output.tool_input)
-elif isinstance(output, AgentFinish):
+if isinstance(output, AgentFinish):
return output.return_values["output"], output.log
-else:
msg = "Unexpected output type"
raise ValueError(msg)
@@ -58,7 +58,6 @@ class FakeChain(Chain):
) -> dict[str, str]:
if self.be_correct:
return {"bar": "baz"}
-else:
return {"baz": "bar"}

@@ -54,7 +54,6 @@ class _FakeTrajectoryChatModel(FakeChatModel):
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
-else:
prompt = messages[0].content
return self.queries[prompt]

@@ -45,7 +45,6 @@ class FakeLLM(LLM):
return self.queries[prompt]
if stop is None:
return "foo"
-else:
return "bar"

@property

@@ -14,7 +14,6 @@ class SequentialRetriever(BaseRetriever):
) -> list[Document]:
if self.response_index >= len(self.sequential_responses):
return []
-else:
self.response_index += 1
return self.sequential_responses[self.response_index - 1]