mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
core[patch]: fix runnable history and add docs (#22283)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
BIN
docs/static/img/message_history.png
vendored
Normal file
BIN
docs/static/img/message_history.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
@@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
# If dictionary, try to pluck the single key representing messages
|
||||||
if isinstance(input_val, dict):
|
if isinstance(input_val, dict):
|
||||||
if self.input_messages_key:
|
if self.input_messages_key:
|
||||||
key = self.input_messages_key
|
key = self.input_messages_key
|
||||||
@@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
key = "input"
|
key = "input"
|
||||||
input_val = input_val[key]
|
input_val = input_val[key]
|
||||||
|
|
||||||
|
# If value is a string, convert to a human message
|
||||||
if isinstance(input_val, str):
|
if isinstance(input_val, str):
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
return [HumanMessage(content=input_val)]
|
return [HumanMessage(content=input_val)]
|
||||||
|
# If value is a single message, convert to a list
|
||||||
elif isinstance(input_val, BaseMessage):
|
elif isinstance(input_val, BaseMessage):
|
||||||
return [input_val]
|
return [input_val]
|
||||||
|
# If value is a list or tuple...
|
||||||
elif isinstance(input_val, (list, tuple)):
|
elif isinstance(input_val, (list, tuple)):
|
||||||
|
# Handle empty case
|
||||||
|
if len(input_val) == 0:
|
||||||
|
return list(input_val)
|
||||||
|
# If is a list of list, then return the first value
|
||||||
|
# This occurs for chat models - since we batch inputs
|
||||||
|
if isinstance(input_val[0], list):
|
||||||
|
if len(input_val) != 1:
|
||||||
|
raise ValueError()
|
||||||
|
return input_val[0]
|
||||||
return list(input_val)
|
return list(input_val)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
# If dictionary, try to pluck the single key representing messages
|
||||||
if isinstance(output_val, dict):
|
if isinstance(output_val, dict):
|
||||||
if self.output_messages_key:
|
if self.output_messages_key:
|
||||||
key = self.output_messages_key
|
key = self.output_messages_key
|
||||||
@@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
return [AIMessage(content=output_val)]
|
return [AIMessage(content=output_val)]
|
||||||
|
# If value is a single message, convert to a list
|
||||||
elif isinstance(output_val, BaseMessage):
|
elif isinstance(output_val, BaseMessage):
|
||||||
return [output_val]
|
return [output_val]
|
||||||
elif isinstance(output_val, (list, tuple)):
|
elif isinstance(output_val, (list, tuple)):
|
||||||
@@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
if not self.history_messages_key:
|
if not self.history_messages_key:
|
||||||
# return all messages
|
# return all messages
|
||||||
messages += self._get_input_messages(input)
|
input_val = (
|
||||||
|
input if not self.input_messages_key else input[self.input_messages_key]
|
||||||
|
)
|
||||||
|
messages += self._get_input_messages(input_val)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def _aenter_history(
|
async def _aenter_history(
|
||||||
@@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
# Get the input messages
|
# Get the input messages
|
||||||
inputs = load(run.inputs)
|
inputs = load(run.inputs)
|
||||||
input_messages = self._get_input_messages(inputs)
|
input_messages = self._get_input_messages(inputs)
|
||||||
|
|
||||||
# If historic messages were prepended to the input messages, remove them to
|
# If historic messages were prepended to the input messages, remove them to
|
||||||
# avoid adding duplicate messages to history.
|
# avoid adding duplicate messages to history.
|
||||||
if not self.history_messages_key:
|
if not self.history_messages_key:
|
||||||
|
@@ -48,7 +48,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
_schema_format: Literal["original", "streaming_events"] = "original",
|
_schema_format: Literal[
|
||||||
|
"original", "streaming_events", "original+chat"
|
||||||
|
] = "original",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the tracer.
|
"""Initialize the tracer.
|
||||||
@@ -63,6 +65,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
for internal usage. It will likely change in the future, or
|
for internal usage. It will likely change in the future, or
|
||||||
be deprecated entirely in favor of a dedicated async tracer
|
be deprecated entirely in favor of a dedicated async tracer
|
||||||
for streaming events.
|
for streaming events.
|
||||||
|
- 'original+chat' is a format that is the same as 'original'
|
||||||
|
except it does NOT raise an attribute error on_chat_model_start
|
||||||
kwargs: Additional keyword arguments that will be passed to
|
kwargs: Additional keyword arguments that will be passed to
|
||||||
the super class.
|
the super class.
|
||||||
"""
|
"""
|
||||||
@@ -163,7 +167,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Start a trace for an LLM run."""
|
"""Start a trace for an LLM run."""
|
||||||
if self._schema_format != "streaming_events":
|
if self._schema_format not in ("streaming_events", "original+chat"):
|
||||||
# Please keep this un-implemented for backwards compatibility.
|
# Please keep this un-implemented for backwards compatibility.
|
||||||
# When it's unimplemented old tracers that use the "original" format
|
# When it's unimplemented old tracers that use the "original" format
|
||||||
# fallback on the on_llm_start method implementation if they
|
# fallback on the on_llm_start method implementation if they
|
||||||
@@ -360,7 +364,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def _get_chain_inputs(self, inputs: Any) -> Any:
|
def _get_chain_inputs(self, inputs: Any) -> Any:
|
||||||
"""Get the inputs for a chain run."""
|
"""Get the inputs for a chain run."""
|
||||||
if self._schema_format == "original":
|
if self._schema_format in ("original", "original+chat"):
|
||||||
return inputs if isinstance(inputs, dict) else {"input": inputs}
|
return inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||||
elif self._schema_format == "streaming_events":
|
elif self._schema_format == "streaming_events":
|
||||||
return {
|
return {
|
||||||
@@ -371,7 +375,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def _get_chain_outputs(self, outputs: Any) -> Any:
|
def _get_chain_outputs(self, outputs: Any) -> Any:
|
||||||
"""Get the outputs for a chain run."""
|
"""Get the outputs for a chain run."""
|
||||||
if self._schema_format == "original":
|
if self._schema_format in ("original", "original+chat"):
|
||||||
return outputs if isinstance(outputs, dict) else {"output": outputs}
|
return outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||||
elif self._schema_format == "streaming_events":
|
elif self._schema_format == "streaming_events":
|
||||||
return {
|
return {
|
||||||
@@ -436,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
if metadata:
|
if metadata:
|
||||||
kwargs.update({"metadata": metadata})
|
kwargs.update({"metadata": metadata})
|
||||||
|
|
||||||
if self._schema_format == "original":
|
if self._schema_format in ("original", "original+chat"):
|
||||||
inputs = {"input": input_str}
|
inputs = {"input": input_str}
|
||||||
elif self._schema_format == "streaming_events":
|
elif self._schema_format == "streaming_events":
|
||||||
inputs = {"input": inputs}
|
inputs = {"input": inputs}
|
||||||
|
@@ -482,7 +482,7 @@ def _get_standardized_inputs(
|
|||||||
|
|
||||||
|
|
||||||
def _get_standardized_outputs(
|
def _get_standardized_outputs(
|
||||||
run: Run, schema_format: Literal["original", "streaming_events"]
|
run: Run, schema_format: Literal["original", "streaming_events", "original+chat"]
|
||||||
) -> Optional[Any]:
|
) -> Optional[Any]:
|
||||||
"""Extract standardized output from a run.
|
"""Extract standardized output from a run.
|
||||||
|
|
||||||
|
@@ -22,7 +22,7 @@ class RootListenersTracer(BaseTracer):
|
|||||||
on_end: Optional[Listener],
|
on_end: Optional[Listener],
|
||||||
on_error: Optional[Listener],
|
on_error: Optional[Listener],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__(_schema_format="original+chat")
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self._arg_on_start = on_start
|
self._arg_on_start = on_start
|
||||||
|
@@ -17,6 +17,8 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
|||||||
|
|
||||||
def add_message(self, message: BaseMessage) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
"""Add a self-created message to the store"""
|
"""Add a self-created message to the store"""
|
||||||
|
if not isinstance(message, BaseMessage):
|
||||||
|
raise ValueError
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
Reference in New Issue
Block a user