mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00
core[patch]: fix runnable history and add docs (#22283)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
BIN
docs/static/img/message_history.png
vendored
Normal file
BIN
docs/static/img/message_history.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
@@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
) -> List[BaseMessage]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
# If dictionary, try to pluck the single key representing messages
|
||||
if isinstance(input_val, dict):
|
||||
if self.input_messages_key:
|
||||
key = self.input_messages_key
|
||||
@@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
key = "input"
|
||||
input_val = input_val[key]
|
||||
|
||||
# If value is a string, convert to a human message
|
||||
if isinstance(input_val, str):
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
return [HumanMessage(content=input_val)]
|
||||
# If value is a single message, convert to a list
|
||||
elif isinstance(input_val, BaseMessage):
|
||||
return [input_val]
|
||||
# If value is a list or tuple...
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
# Handle empty case
|
||||
if len(input_val) == 0:
|
||||
return list(input_val)
|
||||
# If is a list of list, then return the first value
|
||||
# This occurs for chat models - since we batch inputs
|
||||
if isinstance(input_val[0], list):
|
||||
if len(input_val) != 1:
|
||||
raise ValueError()
|
||||
return input_val[0]
|
||||
return list(input_val)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
) -> List[BaseMessage]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
# If dictionary, try to pluck the single key representing messages
|
||||
if isinstance(output_val, dict):
|
||||
if self.output_messages_key:
|
||||
key = self.output_messages_key
|
||||
@@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return [AIMessage(content=output_val)]
|
||||
# If value is a single message, convert to a list
|
||||
elif isinstance(output_val, BaseMessage):
|
||||
return [output_val]
|
||||
elif isinstance(output_val, (list, tuple)):
|
||||
@@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
if not self.history_messages_key:
|
||||
# return all messages
|
||||
messages += self._get_input_messages(input)
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
)
|
||||
messages += self._get_input_messages(input_val)
|
||||
return messages
|
||||
|
||||
async def _aenter_history(
|
||||
@@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
input_messages = self._get_input_messages(inputs)
|
||||
|
||||
# If historic messages were prepended to the input messages, remove them to
|
||||
# avoid adding duplicate messages to history.
|
||||
if not self.history_messages_key:
|
||||
|
@@ -48,7 +48,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
_schema_format: Literal["original", "streaming_events"] = "original",
|
||||
_schema_format: Literal[
|
||||
"original", "streaming_events", "original+chat"
|
||||
] = "original",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the tracer.
|
||||
@@ -63,6 +65,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
for internal usage. It will likely change in the future, or
|
||||
be deprecated entirely in favor of a dedicated async tracer
|
||||
for streaming events.
|
||||
- 'original+chat' is a format that is the same as 'original'
|
||||
except it does NOT raise an attribute error on_chat_model_start
|
||||
kwargs: Additional keyword arguments that will be passed to
|
||||
the super class.
|
||||
"""
|
||||
@@ -163,7 +167,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self._schema_format != "streaming_events":
|
||||
if self._schema_format not in ("streaming_events", "original+chat"):
|
||||
# Please keep this un-implemented for backwards compatibility.
|
||||
# When it's unimplemented old tracers that use the "original" format
|
||||
# fallback on the on_llm_start method implementation if they
|
||||
@@ -360,7 +364,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def _get_chain_inputs(self, inputs: Any) -> Any:
|
||||
"""Get the inputs for a chain run."""
|
||||
if self._schema_format == "original":
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
@@ -371,7 +375,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def _get_chain_outputs(self, outputs: Any) -> Any:
|
||||
"""Get the outputs for a chain run."""
|
||||
if self._schema_format == "original":
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
@@ -436,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
|
||||
if self._schema_format == "original":
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
inputs = {"input": input_str}
|
||||
elif self._schema_format == "streaming_events":
|
||||
inputs = {"input": inputs}
|
||||
|
@@ -482,7 +482,7 @@ def _get_standardized_inputs(
|
||||
|
||||
|
||||
def _get_standardized_outputs(
|
||||
run: Run, schema_format: Literal["original", "streaming_events"]
|
||||
run: Run, schema_format: Literal["original", "streaming_events", "original+chat"]
|
||||
) -> Optional[Any]:
|
||||
"""Extract standardized output from a run.
|
||||
|
||||
|
@@ -22,7 +22,7 @@ class RootListenersTracer(BaseTracer):
|
||||
on_end: Optional[Listener],
|
||||
on_error: Optional[Listener],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
super().__init__(_schema_format="original+chat")
|
||||
|
||||
self.config = config
|
||||
self._arg_on_start = on_start
|
||||
|
@@ -17,6 +17,8 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a self-created message to the store"""
|
||||
if not isinstance(message, BaseMessage):
|
||||
raise ValueError
|
||||
self.messages.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
|
Reference in New Issue
Block a user