This commit is contained in:
Eugene Yurtsev 2025-07-30 15:34:31 -04:00
parent 38f52bcf26
commit e812cc3e98

View File

@ -1,21 +1,50 @@
"""Inline Summarization Chain""" """Inline Summarization Chain
TODO(Eugene): see if we can add annotations / citations.
"""
from __future__ import annotations
from typing import NotRequired, cast from typing import NotRequired, cast
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langgraph.pregel import Pregel
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import MessageLikeRepresentation, AIMessage from langchain_core.messages import AIMessage, MessageLikeRepresentation
from langgraph.graph import END, StateGraph
from langgraph.pregel import Pregel
from typing_extensions import TypedDict
from langchain._internal.utils import RunnableCallable
class InlineSummarizationState(TypedDict): class InlineSummarizationState(TypedDict):
"""State for inline summarization.""" """State for inline summarization."""
documents: list[Document] documents: list[Document]
"""List of documents to summarize."""
summary: NotRequired[str] summary: NotRequired[str]
"""Summary of the documents, available after summarization."""
class InputSchema(TypedDict):
"""Input for the inline summarization chain."""
documents: list[Document]
"""List of documents to summarize."""
class OutputSchema(TypedDict):
"""Output of the inline summarization chain."""
summary: str
"""Summary of the documents."""
class SummarizationNodeUpdate(TypedDict):
"""Update for the summarization node."""
summary: str
"""Summary of the documents."""
class InlineSummarizer: class InlineSummarizer:
@ -54,28 +83,40 @@ class InlineSummarizer:
}, },
] ]
def _summarize_node(self, state: InlineSummarizationState) -> TypedDict( def create_summarization_node(
"Update", {"summary": str} self,
): ) -> RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate]:
"""Creates a node for inline summarization."""
def _summarize_node(state: InlineSummarizationState) -> SummarizationNodeUpdate:
"""Builds a LangGraph for inline summarization.""" """Builds a LangGraph for inline summarization."""
prompt = self._get_prompt(state) prompt = self._get_prompt(state)
response = cast(AIMessage, self.model.invoke(prompt)) response = cast("AIMessage", self.model.invoke(prompt))
return {"summary": response.text()} return {"summary": response.text()}
async def _asummarize_node(self, state: InlineSummarizationState) -> TypedDict( async def _asummarize_node(
"Update", {"summary": str} state: InlineSummarizationState,
): ) -> SummarizationNodeUpdate:
"""Asynchronous version of the summarize node.""" """Asynchronous version of the summarize node."""
prompt = self._get_prompt(state) prompt = self._get_prompt(state)
response = cast(AIMessage, await self.model.ainvoke(prompt)) response = cast("AIMessage", await self.model.ainvoke(prompt))
return {"summary": response.text()} return {"summary": response.text()}
return RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate](
_summarize_node,
_asummarize_node,
)
def build(self) -> Pregel: def build(self) -> Pregel:
"""Builds the LangGraph for inline summarization.""" """Builds the LangGraph for inline summarization."""
builder = StateGraph(InlineSummarizationState) builder = StateGraph(
builder.add_node("summarize_inline", self._summarize_node) InlineSummarizationState,
builder.set_entry_point("summarize_inline") input_schema=InputSchema,
builder.add_edge("summarize_inline", END) output_schema=OutputSchema,
)
builder.add_node("summarize", self.create_summarization_node())
builder.set_entry_point("summarize")
builder.add_edge("summarize", END)
return builder return builder