core: allow artifact in create_retriever_tool (#28903)

Add option to return content and artifacts, to also be able to access
the full info of the retrieved documents.

The retrieved documents are returned as a list of `Document` objects in the
`artifact` field of the resulting `ToolMessage` if parameter `response_format`
is set to `"content_and_artifact"`.

Defaults to `"content"` to keep current behavior.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Adrián Panella
2025-01-03 17:10:31 -05:00
committed by GitHub
parent 3e618b16cd
commit acddfc772e
2 changed files with 83 additions and 6 deletions

View File

@@ -30,8 +30,13 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
RunnableConfig,
@@ -2118,6 +2123,57 @@ def test_tool_annotations_preserved() -> None:
assert schema.__annotations__ == expected_type_hints
def test_create_retriever_tool() -> None:
    """Exercise ``tools.create_retriever_tool`` with and without artifacts.

    A stub retriever yields two documents; the test expects their page
    contents joined by a blank line as the tool's content. It checks the
    tool built without ``response_format`` and the one built with
    ``response_format="content_and_artifact"``, invoking each with a plain
    string and with a ``ToolCall``.
    """

    # Stub retriever: first document echoes the query, second one is fixed.
    class MyRetriever(BaseRetriever):
        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> list[Document]:
            echoed = Document(page_content=f"foo {query}")
            fixed = Document(page_content="bar")
            return [echoed, fixed]

    retriever = MyRetriever()

    # --- tool created without an explicit response_format ---
    retriever_tool = tools.create_retriever_tool(
        retriever,
        "retriever_tool_content",
        "Retriever Tool Content",
    )
    assert isinstance(retriever_tool, BaseTool)
    assert retriever_tool.name == "retriever_tool_content"
    assert retriever_tool.description == "Retriever Tool Content"

    # A plain string input yields the joined page contents.
    plain_result = retriever_tool.invoke("bar")
    assert plain_result == "foo bar\n\nbar"

    # A ToolCall input yields a ToolMessage carrying the call id and tool name.
    content_call = ToolCall(
        name="retriever_tool_content",
        args={"query": "bar"},
        id="123",
        type="tool_call",
    )
    content_message = retriever_tool.invoke(content_call)
    expected_content_message = ToolMessage(
        "foo bar\n\nbar",
        tool_call_id="123",
        name="retriever_tool_content",
    )
    assert content_message == expected_content_message

    # --- tool created with response_format="content_and_artifact" ---
    retriever_tool_artifact = tools.create_retriever_tool(
        retriever,
        "retriever_tool_artifact",
        "Retriever Tool Artifact",
        response_format="content_and_artifact",
    )
    assert isinstance(retriever_tool_artifact, BaseTool)
    assert retriever_tool_artifact.name == "retriever_tool_artifact"
    assert retriever_tool_artifact.description == "Retriever Tool Artifact"

    # A plain string input still yields only the content string.
    plain_artifact_result = retriever_tool_artifact.invoke("bar")
    assert plain_artifact_result == "foo bar\n\nbar"

    # A ToolCall input yields a ToolMessage that also carries the documents.
    artifact_call = ToolCall(
        name="retriever_tool_artifact",
        args={"query": "bar"},
        id="123",
        type="tool_call",
    )
    artifact_message = retriever_tool_artifact.invoke(artifact_call)
    expected_documents = [
        Document(page_content="foo bar"),
        Document(page_content="bar"),
    ]
    expected_artifact_message = ToolMessage(
        "foo bar\n\nbar",
        artifact=expected_documents,
        tool_call_id="123",
        name="retriever_tool_artifact",
    )
    assert artifact_message == expected_artifact_message
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2