mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
core: allow artifact in create_retriever_tool (#28903)
Add option to return content and artifacts, to also be able to access the full info of the retrieved documents. They are returned as a list of dicts in the `artifacts` property if parameter `response_format` is set to `"content_and_artifact"`. Defaults to `"content"` to keep current behavior. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -30,8 +30,13 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import ToolCall, ToolMessage
|
||||
from langchain_core.messages.tool import ToolOutputMixin
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
@@ -2118,6 +2123,57 @@ def test_tool_annotations_preserved() -> None:
|
||||
assert schema.__annotations__ == expected_type_hints
|
||||
|
||||
|
||||
def test_create_retriever_tool() -> None:
|
||||
class MyRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
return [Document(page_content=f"foo {query}"), Document(page_content="bar")]
|
||||
|
||||
retriever = MyRetriever()
|
||||
retriever_tool = tools.create_retriever_tool(
|
||||
retriever, "retriever_tool_content", "Retriever Tool Content"
|
||||
)
|
||||
assert isinstance(retriever_tool, BaseTool)
|
||||
assert retriever_tool.name == "retriever_tool_content"
|
||||
assert retriever_tool.description == "Retriever Tool Content"
|
||||
assert retriever_tool.invoke("bar") == "foo bar\n\nbar"
|
||||
assert retriever_tool.invoke(
|
||||
ToolCall(
|
||||
name="retriever_tool_content",
|
||||
args={"query": "bar"},
|
||||
id="123",
|
||||
type="tool_call",
|
||||
)
|
||||
) == ToolMessage(
|
||||
"foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content"
|
||||
)
|
||||
|
||||
retriever_tool_artifact = tools.create_retriever_tool(
|
||||
retriever,
|
||||
"retriever_tool_artifact",
|
||||
"Retriever Tool Artifact",
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
assert isinstance(retriever_tool_artifact, BaseTool)
|
||||
assert retriever_tool_artifact.name == "retriever_tool_artifact"
|
||||
assert retriever_tool_artifact.description == "Retriever Tool Artifact"
|
||||
assert retriever_tool_artifact.invoke("bar") == "foo bar\n\nbar"
|
||||
assert retriever_tool_artifact.invoke(
|
||||
ToolCall(
|
||||
name="retriever_tool_artifact",
|
||||
args={"query": "bar"},
|
||||
id="123",
|
||||
type="tool_call",
|
||||
)
|
||||
) == ToolMessage(
|
||||
"foo bar\n\nbar",
|
||||
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
|
||||
tool_call_id="123",
|
||||
name="retriever_tool_artifact",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
Reference in New Issue
Block a user