core[patch]: fix runnable history and add docs (#22283)

Harrison Chase
2024-05-30 11:26:41 -07:00
committed by GitHub
parent dcec133b85
commit ee32369265
7 changed files with 630 additions and 492 deletions

File diff suppressed because it is too large.

BIN docs/static/img/message_history.png vendored (new file, binary, 39 KiB; not shown)

@@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage
 
+        # If dictionary, try to pluck the single key representing messages
         if isinstance(input_val, dict):
             if self.input_messages_key:
                 key = self.input_messages_key
@@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
                 key = "input"
             input_val = input_val[key]
 
+        # If value is a string, convert to a human message
         if isinstance(input_val, str):
             from langchain_core.messages import HumanMessage
 
             return [HumanMessage(content=input_val)]
+        # If value is a single message, convert to a list
         elif isinstance(input_val, BaseMessage):
             return [input_val]
+        # If value is a list or tuple...
         elif isinstance(input_val, (list, tuple)):
+            # Handle empty case
+            if len(input_val) == 0:
+                return list(input_val)
+            # If it is a list of lists, then return the first list
+            # This occurs for chat models, since we batch inputs
+            if isinstance(input_val[0], list):
+                if len(input_val) != 1:
+                    raise ValueError()
+                return input_val[0]
             return list(input_val)
         else:
             raise ValueError(
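
With these branches in place, `_get_input_messages` accepts a plain string, a single message, a sequence of messages, or a singleton batch (a list containing one list of messages). A minimal sketch of the accepted shapes, assuming `InMemoryChatMessageHistory` is available from `langchain_core.chat_history`; the `store`, `get_session_history`, and `echo` names are illustrative, not part of this commit:

```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store = {}

def get_session_history(session_id):
    # One in-memory history per session id.
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]

# Echo the last message back as an AI message.
echo = RunnableLambda(lambda messages: AIMessage(content=messages[-1].content))
chain = RunnableWithMessageHistory(echo, get_session_history)
config = {"configurable": {"session_id": "demo"}}

chain.invoke("hello", config=config)                         # plain string
chain.invoke(HumanMessage(content="hi"), config=config)      # single message
chain.invoke([HumanMessage(content="hey")], config=config)   # list of messages
```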
@@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage
 
+        # If dictionary, try to pluck the single key representing messages
         if isinstance(output_val, dict):
             if self.output_messages_key:
                 key = self.output_messages_key
@@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
             from langchain_core.messages import AIMessage
 
             return [AIMessage(content=output_val)]
+        # If value is a single message, convert to a list
         elif isinstance(output_val, BaseMessage):
             return [output_val]
         elif isinstance(output_val, (list, tuple)):
@@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
 
         if not self.history_messages_key:
             # return all messages
-            messages += self._get_input_messages(input)
+            input_val = (
+                input if not self.input_messages_key else input[self.input_messages_key]
+            )
+            messages += self._get_input_messages(input_val)
         return messages
 
     async def _aenter_history(
@@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         # Get the input messages
         inputs = load(run.inputs)
         input_messages = self._get_input_messages(inputs)
-
         # If historic messages were prepended to the input messages, remove them to
         # avoid adding duplicate messages to history.
         if not self.history_messages_key:
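
The `_enter_history` change is what actually fixes runnable history for dict inputs: when `input_messages_key` is set and no `history_messages_key` is used, only the configured key is coerced into messages. A hedged sketch reusing `get_session_history` from the previous snippet:

```python
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

# The wrapped runnable receives the full message list
# (history + new input) because history_messages_key is not set.
count = RunnableLambda(
    lambda messages: AIMessage(content=f"{len(messages)} messages seen")
)

chain = RunnableWithMessageHistory(
    count,
    get_session_history,  # defined in the sketch above
    input_messages_key="question",
)
cfg = {"configurable": {"session_id": "demo2"}}
chain.invoke({"question": "first turn"}, config=cfg)   # -> "1 messages seen"
chain.invoke({"question": "second turn"}, config=cfg)  # history is prepended
```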


@@ -48,7 +48,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def __init__(
         self,
         *,
-        _schema_format: Literal["original", "streaming_events"] = "original",
+        _schema_format: Literal[
+            "original", "streaming_events", "original+chat"
+        ] = "original",
         **kwargs: Any,
     ) -> None:
         """Initialize the tracer.
@@ -63,6 +65,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
                 for internal usage. It will likely change in the future, or
                 be deprecated entirely in favor of a dedicated async tracer
                 for streaming events.
+            - 'original+chat' is a format that is the same as 'original'
+              except it does NOT raise an attribute error on on_chat_model_start
             kwargs: Additional keyword arguments that will be passed to
                 the super class.
         """
@@ -163,7 +167,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run."""
-        if self._schema_format != "streaming_events":
+        if self._schema_format not in ("streaming_events", "original+chat"):
             # Please keep this un-implemented for backwards compatibility.
             # When it's unimplemented, old tracers that use the "original" format
             # fall back on the on_llm_start method implementation if they
@@ -360,7 +364,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
 
     def _get_chain_inputs(self, inputs: Any) -> Any:
         """Get the inputs for a chain run."""
-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             return inputs if isinstance(inputs, dict) else {"input": inputs}
         elif self._schema_format == "streaming_events":
             return {
@@ -371,7 +375,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
 
     def _get_chain_outputs(self, outputs: Any) -> Any:
         """Get the outputs for a chain run."""
-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             return outputs if isinstance(outputs, dict) else {"output": outputs}
         elif self._schema_format == "streaming_events":
             return {
@@ -436,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if metadata:
             kwargs.update({"metadata": metadata})
 
-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             inputs = {"input": input_str}
         elif self._schema_format == "streaming_events":
             inputs = {"input": inputs}


@@ -482,7 +482,7 @@ def _get_standardized_inputs(
 
 
 def _get_standardized_outputs(
-    run: Run, schema_format: Literal["original", "streaming_events"]
+    run: Run, schema_format: Literal["original", "streaming_events", "original+chat"]
 ) -> Optional[Any]:
     """Extract standardized output from a run.


@@ -22,7 +22,7 @@ class RootListenersTracer(BaseTracer):
         on_end: Optional[Listener],
         on_error: Optional[Listener],
     ) -> None:
-        super().__init__()
+        super().__init__(_schema_format="original+chat")
 
         self.config = config
         self._arg_on_start = on_start
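
`RootListenersTracer` backs `.with_listeners()`, so this one-line change is what lets `RunnableWithMessageHistory` hook `_exit_history` onto chat model runs as well. A small usage sketch (the `log_end` callback is illustrative):

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run


def log_end(run: Run) -> None:
    # Receives the finished root run, whatever its run_type.
    print("finished:", run.run_type)


chain = RunnableLambda(lambda x: x.upper()).with_listeners(on_end=log_end)
chain.invoke("hi")  # prints "finished: chain"
```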


@@ -17,6 +17,8 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
 
     def add_message(self, message: BaseMessage) -> None:
         """Add a self-created message to the store"""
+        if not isinstance(message, BaseMessage):
+            raise ValueError
         self.messages.append(message)
 
     def clear(self) -> None:
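
With this check, `add_message` rejects anything that is not a `BaseMessage` instead of silently storing it; plain strings should go through helpers such as `add_user_message`. A short sketch, assuming the in-memory `ChatMessageHistory` from `langchain_community.chat_message_histories`:

```python
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import HumanMessage

history = ChatMessageHistory()
history.add_message(HumanMessage(content="hello"))  # ok

try:
    history.add_message("not a message")  # type: ignore[arg-type]
except ValueError:
    print("plain strings are rejected")

history.add_user_message("hello again")  # helper still wraps strings
```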