This commit is contained in:
Eugene Yurtsev 2025-07-30 15:34:31 -04:00
parent 38f52bcf26
commit e812cc3e98

View File

@ -1,21 +1,50 @@
"""Inline Summarization Chain"""
"""Inline Summarization Chain
TODO(Eugene): see if we can add annotations / citations.
"""
from __future__ import annotations
from typing import NotRequired, cast
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langgraph.pregel import Pregel
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import MessageLikeRepresentation, AIMessage
from langchain_core.messages import AIMessage, MessageLikeRepresentation
from langgraph.graph import END, StateGraph
from langgraph.pregel import Pregel
from typing_extensions import TypedDict
from langchain._internal.utils import RunnableCallable
class InlineSummarizationState(TypedDict):
"""State for inline summarization."""
documents: list[Document]
"""List of documents to summarize."""
summary: NotRequired[str]
"""Summary of the documents, available after summarization."""
class InputSchema(TypedDict):
"""Input for the inline summarization chain."""
documents: list[Document]
"""List of documents to summarize."""
class OutputSchema(TypedDict):
"""Output of the inline summarization chain."""
summary: str
"""Summary of the documents."""
class SummarizationNodeUpdate(TypedDict):
"""Update for the summarization node."""
summary: str
"""Summary of the documents."""
class InlineSummarizer:
@ -54,28 +83,40 @@ class InlineSummarizer:
},
]
def _summarize_node(self, state: InlineSummarizationState) -> TypedDict(
"Update", {"summary": str}
):
def create_summarization_node(
self,
) -> RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate]:
"""Creates a node for inline summarization."""
def _summarize_node(state: InlineSummarizationState) -> SummarizationNodeUpdate:
"""Builds a LangGraph for inline summarization."""
prompt = self._get_prompt(state)
response = cast(AIMessage, self.model.invoke(prompt))
response = cast("AIMessage", self.model.invoke(prompt))
return {"summary": response.text()}
async def _asummarize_node(self, state: InlineSummarizationState) -> TypedDict(
"Update", {"summary": str}
):
async def _asummarize_node(
state: InlineSummarizationState,
) -> SummarizationNodeUpdate:
"""Asynchronous version of the summarize node."""
prompt = self._get_prompt(state)
response = cast(AIMessage, await self.model.ainvoke(prompt))
response = cast("AIMessage", await self.model.ainvoke(prompt))
return {"summary": response.text()}
return RunnableCallable[InlineSummarizationState, SummarizationNodeUpdate](
_summarize_node,
_asummarize_node,
)
def build(self) -> Pregel:
"""Builds the LangGraph for inline summarization."""
builder = StateGraph(InlineSummarizationState)
builder.add_node("summarize_inline", self._summarize_node)
builder.set_entry_point("summarize_inline")
builder.add_edge("summarize_inline", END)
builder = StateGraph(
InlineSummarizationState,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_node("summarize", self.create_summarization_node())
builder.set_entry_point("summarize")
builder.add_edge("summarize", END)
return builder