mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
core: allow artifact in create_retriever_tool (#28903)
Add an option to return both the content and an artifact, so that callers can also access the full information of the retrieved documents. The documents are returned as a list of `Document` objects in the `artifact` field of the resulting `ToolMessage` when the `response_format` parameter is set to `"content_and_artifact"`. `response_format` defaults to `"content"` to keep the current behavior. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
3e618b16cd
commit
acddfc772e
@ -1,11 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import Document
|
||||||
from langchain_core.prompts import (
|
from langchain_core.prompts import (
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
@ -28,11 +29,16 @@ def _get_relevant_documents(
|
|||||||
document_prompt: BasePromptTemplate,
|
document_prompt: BasePromptTemplate,
|
||||||
document_separator: str,
|
document_separator: str,
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
) -> str:
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
) -> Union[str, tuple[str, list[Document]]]:
|
||||||
docs = retriever.invoke(query, config={"callbacks": callbacks})
|
docs = retriever.invoke(query, config={"callbacks": callbacks})
|
||||||
return document_separator.join(
|
content = document_separator.join(
|
||||||
format_document(doc, document_prompt) for doc in docs
|
format_document(doc, document_prompt) for doc in docs
|
||||||
)
|
)
|
||||||
|
if response_format == "content_and_artifact":
|
||||||
|
return (content, docs)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
@ -41,12 +47,18 @@ async def _aget_relevant_documents(
|
|||||||
document_prompt: BasePromptTemplate,
|
document_prompt: BasePromptTemplate,
|
||||||
document_separator: str,
|
document_separator: str,
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
) -> str:
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
) -> Union[str, tuple[str, list[Document]]]:
|
||||||
docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
|
docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
|
||||||
return document_separator.join(
|
content = document_separator.join(
|
||||||
[await aformat_document(doc, document_prompt) for doc in docs]
|
[await aformat_document(doc, document_prompt) for doc in docs]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response_format == "content_and_artifact":
|
||||||
|
return (content, docs)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
def create_retriever_tool(
|
def create_retriever_tool(
|
||||||
retriever: BaseRetriever,
|
retriever: BaseRetriever,
|
||||||
@ -55,6 +67,7 @@ def create_retriever_tool(
|
|||||||
*,
|
*,
|
||||||
document_prompt: Optional[BasePromptTemplate] = None,
|
document_prompt: Optional[BasePromptTemplate] = None,
|
||||||
document_separator: str = "\n\n",
|
document_separator: str = "\n\n",
|
||||||
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""Create a tool to do retrieval of documents.
|
"""Create a tool to do retrieval of documents.
|
||||||
|
|
||||||
@ -66,6 +79,11 @@ def create_retriever_tool(
|
|||||||
model, so should be descriptive.
|
model, so should be descriptive.
|
||||||
document_prompt: The prompt to use for the document. Defaults to None.
|
document_prompt: The prompt to use for the document. Defaults to None.
|
||||||
document_separator: The separator to use between documents. Defaults to "\n\n".
|
document_separator: The separator to use between documents. Defaults to "\n\n".
|
||||||
|
response_format: The tool response format. If "content" then the output of
|
||||||
|
the tool is interpreted as the contents of a ToolMessage. If
|
||||||
|
"content_and_artifact" then the output is expected to be a two-tuple
|
||||||
|
corresponding to the (content, artifact) of a ToolMessage (artifact
|
||||||
|
being a list of documents in this case). Defaults to "content".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tool class to pass to an agent.
|
Tool class to pass to an agent.
|
||||||
@ -76,12 +94,14 @@ def create_retriever_tool(
|
|||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
document_separator=document_separator,
|
document_separator=document_separator,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
afunc = partial(
|
afunc = partial(
|
||||||
_aget_relevant_documents,
|
_aget_relevant_documents,
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
document_prompt=document_prompt,
|
document_prompt=document_prompt,
|
||||||
document_separator=document_separator,
|
document_separator=document_separator,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
return Tool(
|
return Tool(
|
||||||
name=name,
|
name=name,
|
||||||
@ -89,4 +109,5 @@ def create_retriever_tool(
|
|||||||
func=func,
|
func=func,
|
||||||
coroutine=afunc,
|
coroutine=afunc,
|
||||||
args_schema=RetrieverInput,
|
args_schema=RetrieverInput,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
|
@ -30,8 +30,13 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.callbacks.manager import (
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.messages import ToolCall, ToolMessage
|
||||||
from langchain_core.messages.tool import ToolOutputMixin
|
from langchain_core.messages.tool import ToolOutputMixin
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
@ -2118,6 +2123,57 @@ def test_tool_annotations_preserved() -> None:
|
|||||||
assert schema.__annotations__ == expected_type_hints
|
assert schema.__annotations__ == expected_type_hints
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_retriever_tool() -> None:
|
||||||
|
class MyRetriever(BaseRetriever):
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
|
) -> list[Document]:
|
||||||
|
return [Document(page_content=f"foo {query}"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
retriever = MyRetriever()
|
||||||
|
retriever_tool = tools.create_retriever_tool(
|
||||||
|
retriever, "retriever_tool_content", "Retriever Tool Content"
|
||||||
|
)
|
||||||
|
assert isinstance(retriever_tool, BaseTool)
|
||||||
|
assert retriever_tool.name == "retriever_tool_content"
|
||||||
|
assert retriever_tool.description == "Retriever Tool Content"
|
||||||
|
assert retriever_tool.invoke("bar") == "foo bar\n\nbar"
|
||||||
|
assert retriever_tool.invoke(
|
||||||
|
ToolCall(
|
||||||
|
name="retriever_tool_content",
|
||||||
|
args={"query": "bar"},
|
||||||
|
id="123",
|
||||||
|
type="tool_call",
|
||||||
|
)
|
||||||
|
) == ToolMessage(
|
||||||
|
"foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content"
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever_tool_artifact = tools.create_retriever_tool(
|
||||||
|
retriever,
|
||||||
|
"retriever_tool_artifact",
|
||||||
|
"Retriever Tool Artifact",
|
||||||
|
response_format="content_and_artifact",
|
||||||
|
)
|
||||||
|
assert isinstance(retriever_tool_artifact, BaseTool)
|
||||||
|
assert retriever_tool_artifact.name == "retriever_tool_artifact"
|
||||||
|
assert retriever_tool_artifact.description == "Retriever Tool Artifact"
|
||||||
|
assert retriever_tool_artifact.invoke("bar") == "foo bar\n\nbar"
|
||||||
|
assert retriever_tool_artifact.invoke(
|
||||||
|
ToolCall(
|
||||||
|
name="retriever_tool_artifact",
|
||||||
|
args={"query": "bar"},
|
||||||
|
id="123",
|
||||||
|
type="tool_call",
|
||||||
|
)
|
||||||
|
) == ToolMessage(
|
||||||
|
"foo bar\n\nbar",
|
||||||
|
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
|
||||||
|
tool_call_id="123",
|
||||||
|
name="retriever_tool_artifact",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
Loading…
Reference in New Issue
Block a user