mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Nc/small fixes 21aug (#9542)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
parent
a7eba8b006
commit
28e1ee4891
@ -715,11 +715,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||||
"""Callback manager for chain run."""
|
"""Callback manager for chain run."""
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||||
"""Run when chain ends running.
|
"""Run when chain ends running.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs (Dict[str, Any]): The outputs of the chain.
|
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||||
"""
|
"""
|
||||||
_handle_event(
|
_handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
@ -797,11 +797,13 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
|||||||
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||||
"""Async callback manager for chain run."""
|
"""Async callback manager for chain run."""
|
||||||
|
|
||||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
async def on_chain_end(
|
||||||
|
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
"""Run when chain ends running.
|
"""Run when chain ends running.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs (Dict[str, Any]): The outputs of the chain.
|
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
|
||||||
"""
|
"""
|
||||||
await _ahandle_event(
|
await _ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
@ -1144,7 +1146,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
inputs: Dict[str, Any],
|
inputs: Union[Dict[str, Any], Any],
|
||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> CallbackManagerForChainRun:
|
) -> CallbackManagerForChainRun:
|
||||||
@ -1152,7 +1154,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
serialized (Dict[str, Any]): The serialized chain.
|
serialized (Dict[str, Any]): The serialized chain.
|
||||||
inputs (Dict[str, Any]): The inputs to the chain.
|
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
|
||||||
run_id (UUID, optional): The ID of the run. Defaults to None.
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1433,7 +1435,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
inputs: Dict[str, Any],
|
inputs: Union[Dict[str, Any], Any],
|
||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncCallbackManagerForChainRun:
|
) -> AsyncCallbackManagerForChainRun:
|
||||||
@ -1441,7 +1443,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
serialized (Dict[str, Any]): The serialized chain.
|
serialized (Dict[str, Any]): The serialized chain.
|
||||||
inputs (Dict[str, Any]): The inputs to the chain.
|
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
|
||||||
run_id (UUID, optional): The ID of the run. Defaults to None.
|
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -231,6 +231,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for a chain run."""
|
"""Start a trace for a chain run."""
|
||||||
@ -243,7 +244,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
id=run_id,
|
id=run_id,
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
inputs=inputs,
|
inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
events=[{"name": "start", "time": start_time}],
|
events=[{"name": "start", "time": start_time}],
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@ -251,6 +252,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
run_type=run_type or "chain",
|
run_type=run_type or "chain",
|
||||||
|
name=name,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
)
|
)
|
||||||
self._start_trace(chain_run)
|
self._start_trace(chain_run)
|
||||||
@ -271,11 +273,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
if chain_run is None:
|
if chain_run is None:
|
||||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
||||||
|
|
||||||
chain_run.outputs = outputs
|
chain_run.outputs = (
|
||||||
|
outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||||
|
)
|
||||||
chain_run.end_time = datetime.utcnow()
|
chain_run.end_time = datetime.utcnow()
|
||||||
chain_run.events.append({"name": "end", "time": chain_run.end_time})
|
chain_run.events.append({"name": "end", "time": chain_run.end_time})
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
chain_run.inputs = inputs
|
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_end(chain_run)
|
self._on_chain_end(chain_run)
|
||||||
|
|
||||||
@ -298,7 +302,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
chain_run.end_time = datetime.utcnow()
|
chain_run.end_time = datetime.utcnow()
|
||||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
chain_run.inputs = inputs
|
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||||
self._end_trace(chain_run)
|
self._end_trace(chain_run)
|
||||||
self._on_chain_error(chain_run)
|
self._on_chain_error(chain_run)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
List,
|
List,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
@ -298,6 +299,15 @@ class ChatPromptValue(PromptValue):
|
|||||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||||
"""Base class for chat prompt templates."""
|
"""Base class for chat prompt templates."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Return a list of attribute names that should be included in the
|
||||||
|
serialized kwargs. These attributes must be accepted by the
|
||||||
|
constructor.
|
||||||
|
"""
|
||||||
|
return {"input_variables": self.input_variables}
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
"""Format the chat template into a string.
|
"""Format the chat template into a string.
|
||||||
|
|
||||||
@ -419,7 +429,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
f"Got: {values['input_variables']}"
|
f"Got: {values['input_variables']}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
values["input_variables"] = list(input_vars)
|
values["input_variables"] = sorted(input_vars)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -266,7 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@ -284,12 +284,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
output_for_tracer = dumpd(output)
|
run_manager.on_chain_end(dumpd(output))
|
||||||
run_manager.on_chain_end(
|
|
||||||
output_for_tracer
|
|
||||||
if isinstance(output_for_tracer, dict)
|
|
||||||
else {"output": output_for_tracer}
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def _acall_with_config(
|
async def _acall_with_config(
|
||||||
@ -318,7 +313,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@ -339,12 +334,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
output_for_tracer = dumpd(output)
|
await run_manager.on_chain_end(dumpd(output))
|
||||||
await run_manager.on_chain_end(
|
|
||||||
output_for_tracer
|
|
||||||
if isinstance(output_for_tracer, dict)
|
|
||||||
else {"output": output_for_tracer}
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _transform_stream_with_config(
|
def _transform_stream_with_config(
|
||||||
@ -425,22 +415,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_input = None
|
final_input = None
|
||||||
final_input_supported = False
|
final_input_supported = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
run_manager.on_chain_error(
|
run_manager.on_chain_error(e, inputs=final_input)
|
||||||
e,
|
|
||||||
inputs=final_input
|
|
||||||
if isinstance(final_input, dict)
|
|
||||||
else {"input": final_input},
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(
|
run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
final_output
|
|
||||||
if isinstance(final_output, dict)
|
|
||||||
else {"output": final_output},
|
|
||||||
inputs=final_input
|
|
||||||
if isinstance(final_input, dict)
|
|
||||||
else {"input": final_input},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _atransform_stream_with_config(
|
async def _atransform_stream_with_config(
|
||||||
self,
|
self,
|
||||||
@ -525,22 +503,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_input = None
|
final_input = None
|
||||||
final_input_supported = False
|
final_input_supported = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await run_manager.on_chain_error(
|
await run_manager.on_chain_error(e, inputs=final_input)
|
||||||
e,
|
|
||||||
inputs=final_input
|
|
||||||
if isinstance(final_input, dict)
|
|
||||||
else {"input": final_input},
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(
|
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
final_output
|
|
||||||
if isinstance(final_output, dict)
|
|
||||||
else {"output": final_output},
|
|
||||||
inputs=final_input
|
|
||||||
if isinstance(final_input, dict)
|
|
||||||
else {"input": final_input},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||||
@ -583,9 +549,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
first_error = None
|
first_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
try:
|
||||||
@ -600,9 +564,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(
|
run_manager.on_chain_end(output)
|
||||||
output if isinstance(output, dict) else {"output": output}
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
@ -629,9 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
first_error = None
|
first_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
@ -647,9 +607,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(
|
await run_manager.on_chain_end(output)
|
||||||
output if isinstance(output, dict) else {"output": output}
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
@ -709,9 +667,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
for rm, output in zip(run_managers, outputs):
|
for rm, output in zip(run_managers, outputs):
|
||||||
rm.on_chain_end(
|
rm.on_chain_end(output)
|
||||||
output if isinstance(output, dict) else {"output": output}
|
|
||||||
)
|
|
||||||
return outputs
|
return outputs
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
@ -749,9 +705,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
for cm, input in zip(callback_managers, inputs)
|
for cm, input in zip(callback_managers, inputs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -776,9 +730,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
rm.on_chain_end(
|
rm.on_chain_end(output)
|
||||||
output if isinstance(output, dict) else {"output": output}
|
|
||||||
)
|
|
||||||
for rm, output in zip(run_managers, outputs)
|
for rm, output in zip(run_managers, outputs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -870,9 +822,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -887,9 +837,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(
|
run_manager.on_chain_end(input)
|
||||||
input if isinstance(input, dict) else {"output": input}
|
|
||||||
)
|
|
||||||
return cast(Output, input)
|
return cast(Output, input)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -912,9 +860,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -929,9 +875,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(
|
await run_manager.on_chain_end(input)
|
||||||
input if isinstance(input, dict) else {"output": input}
|
|
||||||
)
|
|
||||||
return cast(Output, input)
|
return cast(Output, input)
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
@ -960,9 +904,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
]
|
]
|
||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers = [
|
run_managers = [
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
for cm, input in zip(callback_managers, inputs)
|
for cm, input in zip(callback_managers, inputs)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -985,7 +927,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
for rm, input in zip(run_managers, inputs):
|
for rm, input in zip(run_managers, inputs):
|
||||||
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
|
rm.on_chain_end(input)
|
||||||
return cast(List[Output], inputs)
|
return cast(List[Output], inputs)
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
@ -1017,9 +959,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
for cm, input in zip(callback_managers, inputs)
|
for cm, input in zip(callback_managers, inputs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -1043,12 +983,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(
|
*(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
|
||||||
rm.on_chain_end(
|
|
||||||
input if isinstance(input, dict) else {"output": input}
|
|
||||||
)
|
|
||||||
for rm, input in zip(run_managers, inputs)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return cast(List[Output], inputs)
|
return cast(List[Output], inputs)
|
||||||
|
|
||||||
@ -1072,9 +1007,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = 0
|
streaming_start_index = 0
|
||||||
@ -1128,9 +1061,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(
|
run_manager.on_chain_end(final)
|
||||||
final if isinstance(final, dict) else {"output": final}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
@ -1152,9 +1083,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = len(steps) - 1
|
streaming_start_index = len(steps) - 1
|
||||||
@ -1208,9 +1137,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(
|
await run_manager.on_chain_end(final)
|
||||||
final if isinstance(final, dict) else {"output": final}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableMapChunk(Dict[str, Any]):
|
class RunnableMapChunk(Dict[str, Any]):
|
||||||
@ -1277,9 +1204,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
@ -1324,9 +1249,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
||||||
)
|
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
|
File diff suppressed because one or more lines are too long
@ -21,7 +21,12 @@ from langchain.prompts.chat import (
|
|||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
from langchain.schema.output_parser import StrOutputParser
|
from langchain.schema.output_parser import StrOutputParser
|
||||||
from langchain.schema.retriever import BaseRetriever
|
from langchain.schema.retriever import BaseRetriever
|
||||||
from langchain.schema.runnable import (
|
from langchain.schema.runnable import (
|
||||||
@ -794,7 +799,7 @@ def test_map_stream() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": "i"},
|
{"chat": AIMessageChunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
@ -841,7 +846,11 @@ def test_map_stream_iterator_input() -> None:
|
|||||||
else:
|
else:
|
||||||
final_value += chunk
|
final_value += chunk
|
||||||
|
|
||||||
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
|
assert streamed_chunks[0] in [
|
||||||
|
{"passthrough": "i"},
|
||||||
|
{"llm": "i"},
|
||||||
|
{"chat": AIMessageChunk(content="i")},
|
||||||
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
assert final_value is not None
|
assert final_value is not None
|
||||||
@ -885,7 +894,7 @@ async def test_map_astream() -> None:
|
|||||||
assert streamed_chunks[0] in [
|
assert streamed_chunks[0] in [
|
||||||
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
{"llm": "i"},
|
{"llm": "i"},
|
||||||
{"chat": "i"},
|
{"chat": AIMessageChunk(content="i")},
|
||||||
]
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
@ -933,7 +942,11 @@ async def test_map_astream_iterator_input() -> None:
|
|||||||
else:
|
else:
|
||||||
final_value += chunk
|
final_value += chunk
|
||||||
|
|
||||||
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
|
assert streamed_chunks[0] in [
|
||||||
|
{"passthrough": "i"},
|
||||||
|
{"llm": "i"},
|
||||||
|
{"chat": AIMessageChunk(content="i")},
|
||||||
|
]
|
||||||
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
||||||
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
assert final_value is not None
|
assert final_value is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user