diff --git a/libs/langchain_v1/langchain/chains/summarization.py b/libs/langchain_v1/langchain/chains/summarization.py index 4d0c2022d3b..4d6e195f863 100644 --- a/libs/langchain_v1/langchain/chains/summarization.py +++ b/libs/langchain_v1/langchain/chains/summarization.py @@ -1,21 +1,50 @@ -"""Inline Summarization Chain""" +"""Inline Summarization Chain + +TODO(Eugene): see if we can add annotations / citations. +""" + +from __future__ import annotations from typing import NotRequired, cast -from typing import TypedDict - -from langgraph.graph import StateGraph, END -from langgraph.pregel import Pregel from langchain_core.documents import Document from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import MessageLikeRepresentation, AIMessage +from langchain_core.messages import AIMessage, MessageLikeRepresentation +from langgraph.graph import END, StateGraph +from langgraph.pregel import Pregel +from typing_extensions import TypedDict + +from langchain._internal.utils import RunnableCallable class InlineSummarizationState(TypedDict): """State for inline summarization.""" documents: list[Document] + """List of documents to summarize.""" summary: NotRequired[str] + """Summary of the documents, available after summarization.""" + + +class InputSchema(TypedDict): + """Input for the inline summarization chain.""" + + documents: list[Document] + """List of documents to summarize.""" + + +class OutputSchema(TypedDict): + """Output of the inline summarization chain.""" + + summary: str + """Summary of the documents.""" + + +class SummarizationNodeUpdate(TypedDict): + """Update for the summarization node.""" + + summary: str + """Summary of the documents.""" class InlineSummarizer: @@ -54,28 +83,40 @@ class InlineSummarizer: }, ] - def _summarize_node(self, state: InlineSummarizationState) -> TypedDict( - "Update", {"summary": str} - ): - """Builds a LangGraph for inline summarization.""" - prompt = self._get_prompt(state) - response = cast(AIMessage, self.model.invoke(prompt)) - return {"summary": response.text()} + def create_summarization_node( + self, + ) -> RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate]: + """Creates a node for inline summarization.""" - async def _asummarize_node(self, state: InlineSummarizationState) -> TypedDict( - "Update", {"summary": str} - ): - """Asynchronous version of the summarize node.""" - prompt = self._get_prompt(state) - response = cast(AIMessage, await self.model.ainvoke(prompt)) - return {"summary": response.text()} + def _summarize_node(state: InlineSummarizationState) -> SummarizationNodeUpdate: + """Builds a LangGraph for inline summarization.""" + prompt = self._get_prompt(state) + response = cast("AIMessage", self.model.invoke(prompt)) + return {"summary": response.text()} + + async def _asummarize_node( + state: InlineSummarizationState, + ) -> SummarizationNodeUpdate: + """Asynchronous version of the summarize node.""" + prompt = self._get_prompt(state) + response = cast("AIMessage", await self.model.ainvoke(prompt)) + return {"summary": response.text()} + + return RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate]( + _summarize_node, + _asummarize_node, + ) def build(self) -> Pregel: """Builds the LangGraph for inline summarization.""" - builder = StateGraph(InlineSummarizationState) - builder.add_node("summarize_inline", self._summarize_node) - builder.set_entry_point("summarize_inline") - builder.add_edge("summarize_inline", END) + builder = StateGraph( + InlineSummarizationState, + input_schema=InputSchema, + output_schema=OutputSchema, + ) + builder.add_node("summarize", self.create_summarization_node()) + builder.set_entry_point("summarize") + builder.add_edge("summarize", END) return builder