core: allow artifact in create_retriever_tool (#28903)

Add an option to return both content and artifacts, so that callers can
also access the full information of the retrieved documents.

The retrieved documents are returned as a list of `Document` objects in the
`artifact` property of the resulting `ToolMessage` when the parameter
`response_format` is set to `"content_and_artifact"`.

Defaults to `"content"` to keep current behavior.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Adrián Panella 2025-01-03 17:10:31 -05:00 committed by GitHub
parent 3e618b16cd
commit acddfc772e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 6 deletions

View File

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
from functools import partial from functools import partial
from typing import Optional from typing import Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.prompts import ( from langchain_core.prompts import (
BasePromptTemplate, BasePromptTemplate,
PromptTemplate, PromptTemplate,
@ -28,11 +29,16 @@ def _get_relevant_documents(
document_prompt: BasePromptTemplate, document_prompt: BasePromptTemplate,
document_separator: str, document_separator: str,
callbacks: Callbacks = None, callbacks: Callbacks = None,
) -> str: response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[Document]]]:
docs = retriever.invoke(query, config={"callbacks": callbacks}) docs = retriever.invoke(query, config={"callbacks": callbacks})
return document_separator.join( content = document_separator.join(
format_document(doc, document_prompt) for doc in docs format_document(doc, document_prompt) for doc in docs
) )
if response_format == "content_and_artifact":
return (content, docs)
return content
async def _aget_relevant_documents( async def _aget_relevant_documents(
@ -41,12 +47,18 @@ async def _aget_relevant_documents(
document_prompt: BasePromptTemplate, document_prompt: BasePromptTemplate,
document_separator: str, document_separator: str,
callbacks: Callbacks = None, callbacks: Callbacks = None,
) -> str: response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[Document]]]:
docs = await retriever.ainvoke(query, config={"callbacks": callbacks}) docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
return document_separator.join( content = document_separator.join(
[await aformat_document(doc, document_prompt) for doc in docs] [await aformat_document(doc, document_prompt) for doc in docs]
) )
if response_format == "content_and_artifact":
return (content, docs)
return content
def create_retriever_tool( def create_retriever_tool(
retriever: BaseRetriever, retriever: BaseRetriever,
@ -55,6 +67,7 @@ def create_retriever_tool(
*, *,
document_prompt: Optional[BasePromptTemplate] = None, document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n", document_separator: str = "\n\n",
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Tool: ) -> Tool:
"""Create a tool to do retrieval of documents. """Create a tool to do retrieval of documents.
@ -66,6 +79,11 @@ def create_retriever_tool(
model, so should be descriptive. model, so should be descriptive.
document_prompt: The prompt to use for the document. Defaults to None. document_prompt: The prompt to use for the document. Defaults to None.
document_separator: The separator to use between documents. Defaults to "\n\n". document_separator: The separator to use between documents. Defaults to "\n\n".
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage (artifact
being a list of documents in this case). Defaults to "content".
Returns: Returns:
Tool class to pass to an agent. Tool class to pass to an agent.
@ -76,12 +94,14 @@ def create_retriever_tool(
retriever=retriever, retriever=retriever,
document_prompt=document_prompt, document_prompt=document_prompt,
document_separator=document_separator, document_separator=document_separator,
response_format=response_format,
) )
afunc = partial( afunc = partial(
_aget_relevant_documents, _aget_relevant_documents,
retriever=retriever, retriever=retriever,
document_prompt=document_prompt, document_prompt=document_prompt,
document_separator=document_separator, document_separator=document_separator,
response_format=response_format,
) )
return Tool( return Tool(
name=name, name=name,
@ -89,4 +109,5 @@ def create_retriever_tool(
func=func, func=func,
coroutine=afunc, coroutine=afunc,
args_schema=RetrieverInput, args_schema=RetrieverInput,
response_format=response_format,
) )

View File

@ -30,8 +30,13 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.messages import ToolMessage from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ( from langchain_core.runnables import (
Runnable, Runnable,
RunnableConfig, RunnableConfig,
@ -2118,6 +2123,57 @@ def test_tool_annotations_preserved() -> None:
assert schema.__annotations__ == expected_type_hints assert schema.__annotations__ == expected_type_hints
def test_create_retriever_tool() -> None:
    """Exercise ``create_retriever_tool`` with both response formats.

    The default format is checked to yield a ``ToolMessage`` holding only the
    joined page contents. With ``response_format="content_and_artifact"`` the
    message is additionally expected to carry the retrieved documents as its
    ``artifact``. In both cases, invoking the tool with a bare string is
    expected to return just the content string.
    """

    # Stub retriever: one document echoing the query, followed by a fixed one.
    class MyRetriever(BaseRetriever):
        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> list[Document]:
            echoed = Document(page_content=f"foo {query}")
            fixed = Document(page_content="bar")
            return [echoed, fixed]

    retriever = MyRetriever()
    expected_content = "foo bar\n\nbar"

    # --- default response format: content only -----------------------------
    content_tool = tools.create_retriever_tool(
        retriever, "retriever_tool_content", "Retriever Tool Content"
    )
    assert isinstance(content_tool, BaseTool)
    assert content_tool.name == "retriever_tool_content"
    assert content_tool.description == "Retriever Tool Content"
    assert content_tool.invoke("bar") == expected_content

    content_call = ToolCall(
        name="retriever_tool_content",
        args={"query": "bar"},
        id="123",
        type="tool_call",
    )
    content_message = content_tool.invoke(content_call)
    assert content_message == ToolMessage(
        expected_content, tool_call_id="123", name="retriever_tool_content"
    )

    # --- content_and_artifact: documents attached as the artifact ----------
    artifact_tool = tools.create_retriever_tool(
        retriever,
        "retriever_tool_artifact",
        "Retriever Tool Artifact",
        response_format="content_and_artifact",
    )
    assert isinstance(artifact_tool, BaseTool)
    assert artifact_tool.name == "retriever_tool_artifact"
    assert artifact_tool.description == "Retriever Tool Artifact"
    assert artifact_tool.invoke("bar") == expected_content

    artifact_call = ToolCall(
        name="retriever_tool_artifact",
        args={"query": "bar"},
        id="123",
        type="tool_call",
    )
    artifact_message = artifact_tool.invoke(artifact_call)
    assert artifact_message == ToolMessage(
        expected_content,
        artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
        tool_call_id="123",
        name="retriever_tool_artifact",
    )
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None: def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2