Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 23:29:21 +00:00)
Nc/small fixes 21aug (#9542)
<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change,
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
  - Tag maintainer: for a quicker response, tag the relevant maintainer (see below),
  - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally.

See contribution guidelines for more information on how to write/run tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. These live in the docs/extras directory.

If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
-->
Commit 28e1ee4891 (parent a7eba8b006)
```diff
@@ -715,11 +715,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
 class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
     """Callback manager for chain run."""
 
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
         """Run when chain ends running.
 
         Args:
-            outputs (Dict[str, Any]): The outputs of the chain.
+            outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
         """
         _handle_event(
             self.handlers,
@@ -797,11 +797,13 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
 class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
     """Async callback manager for chain run."""
 
-    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    async def on_chain_end(
+        self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
+    ) -> None:
         """Run when chain ends running.
 
         Args:
-            outputs (Dict[str, Any]): The outputs of the chain.
+            outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
         """
         await _ahandle_event(
             self.handlers,
@@ -1144,7 +1146,7 @@ class CallbackManager(BaseCallbackManager):
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        inputs: Union[Dict[str, Any], Any],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> CallbackManagerForChainRun:
@@ -1152,7 +1154,7 @@ class CallbackManager(BaseCallbackManager):
 
         Args:
             serialized (Dict[str, Any]): The serialized chain.
-            inputs (Dict[str, Any]): The inputs to the chain.
+            inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
             run_id (UUID, optional): The ID of the run. Defaults to None.
 
         Returns:
@@ -1433,7 +1435,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_chain_start(
         self,
         serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        inputs: Union[Dict[str, Any], Any],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> AsyncCallbackManagerForChainRun:
@@ -1441,7 +1443,7 @@ class AsyncCallbackManager(BaseCallbackManager):
 
         Args:
             serialized (Dict[str, Any]): The serialized chain.
-            inputs (Dict[str, Any]): The inputs to the chain.
+            inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
             run_id (UUID, optional): The ID of the run. Defaults to None.
 
         Returns:
```
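The hunks above loosen the type hints (and matching docstrings): `on_chain_start` and `on_chain_end` on the callback managers now accept arbitrary chain inputs/outputs instead of requiring a `Dict[str, Any]`. A minimal, self-contained sketch (a stand-in function, not the LangChain implementation) of what the relaxed annotation permits at call sites:

```python
from typing import Any, Dict, Union


def on_chain_end(outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
    """Stand-in for the relaxed callback signature: any output type is accepted."""
    print(f"chain ended with outputs of type {type(outputs).__name__}: {outputs!r}")


# Both call styles are now valid under the annotation: a dict payload or a bare value.
on_chain_end({"output": "final answer"})
on_chain_end("final answer")
```

For type-checking purposes this union accepts any value; the explicit `Dict[str, Any]` arm mainly documents the common case.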
```diff
@@ -231,6 +231,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_type: Optional[str] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Start a trace for a chain run."""
@@ -243,7 +244,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs=inputs,
+            inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,
@@ -251,6 +252,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_execution_order=execution_order,
             child_runs=[],
             run_type=run_type or "chain",
+            name=name,
             tags=tags or [],
         )
         self._start_trace(chain_run)
@@ -271,11 +273,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if chain_run is None:
             raise TracerException(f"No chain Run found to be traced for {run_id}")
 
-        chain_run.outputs = outputs
+        chain_run.outputs = (
+            outputs if isinstance(outputs, dict) else {"output": outputs}
+        )
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "end", "time": chain_run.end_time})
         if inputs is not None:
-            chain_run.inputs = inputs
+            chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)
 
@@ -298,7 +302,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "error", "time": chain_run.end_time})
         if inputs is not None:
-            chain_run.inputs = inputs
+            chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
 
```
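With the callback signatures relaxed, `BaseTracer` becomes the single place where non-dict inputs and outputs are normalized into the run schema. A small, self-contained sketch of the wrapping rule introduced here (the helper names are hypothetical; the diff inlines the expressions):

```python
from typing import Any, Dict


def normalize_inputs(inputs: Any) -> Dict[str, Any]:
    # Dicts pass through unchanged; anything else is wrapped under "input".
    return inputs if isinstance(inputs, dict) else {"input": inputs}


def normalize_outputs(outputs: Any) -> Dict[str, Any]:
    # Same rule for outputs, wrapped under "output".
    return outputs if isinstance(outputs, dict) else {"output": outputs}


assert normalize_inputs({"question": "hi"}) == {"question": "hi"}
assert normalize_inputs("hi") == {"input": "hi"}
assert normalize_outputs(["a", "b"]) == {"output": ["a", "b"]}
```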
```diff
@@ -6,6 +6,7 @@ from pathlib import Path
 from typing import (
     Any,
     Callable,
+    Dict,
     List,
     Sequence,
     Set,
@@ -298,6 +299,15 @@ class ChatPromptValue(PromptValue):
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""
 
+    @property
+    def lc_attributes(self) -> Dict:
+        """
+        Return a list of attribute names that should be included in the
+        serialized kwargs. These attributes must be accepted by the
+        constructor.
+        """
+        return {"input_variables": self.input_variables}
+
     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.
 
@@ -419,7 +429,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
                     f"Got: {values['input_variables']}"
                 )
         else:
-            values["input_variables"] = list(input_vars)
+            values["input_variables"] = sorted(input_vars)
         return values
 
     @classmethod
```
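Two small changes here: `BaseChatPromptTemplate` now exposes `input_variables` through `lc_attributes` so it round-trips through serialization, and the inferred variables are materialized with `sorted(...)` instead of `list(...)`. The sketch below illustrates why the sort matters: set iteration order is not stable across interpreter runs, so sorting gives a deterministic list, which is what you want for a value that now gets serialized and compared. The variable names are made up for illustration:

```python
# Inferred input variables are collected into a set while walking the messages.
input_vars = {"question", "name", "context"}

print(list(input_vars))    # order depends on string hashing; can differ between runs
print(sorted(input_vars))  # always ['context', 'name', 'question']
```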
```diff
@@ -266,7 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
         )
         run_manager = callback_manager.on_chain_start(
             dumpd(self),
-            input if isinstance(input, dict) else {"input": input},
+            input,
             run_type=run_type,
         )
         try:
@@ -284,12 +284,7 @@ class Runnable(Generic[Input, Output], ABC):
             run_manager.on_chain_error(e)
             raise
         else:
-            output_for_tracer = dumpd(output)
-            run_manager.on_chain_end(
-                output_for_tracer
-                if isinstance(output_for_tracer, dict)
-                else {"output": output_for_tracer}
-            )
+            run_manager.on_chain_end(dumpd(output))
             return output
 
     async def _acall_with_config(
@@ -318,7 +313,7 @@ class Runnable(Generic[Input, Output], ABC):
         )
         run_manager = await callback_manager.on_chain_start(
             dumpd(self),
-            input if isinstance(input, dict) else {"input": input},
+            input,
             run_type=run_type,
         )
         try:
@@ -339,12 +334,7 @@ class Runnable(Generic[Input, Output], ABC):
             await run_manager.on_chain_error(e)
             raise
         else:
-            output_for_tracer = dumpd(output)
-            await run_manager.on_chain_end(
-                output_for_tracer
-                if isinstance(output_for_tracer, dict)
-                else {"output": output_for_tracer}
-            )
+            await run_manager.on_chain_end(dumpd(output))
             return output
 
     def _transform_stream_with_config(
@@ -425,22 +415,10 @@ class Runnable(Generic[Input, Output], ABC):
                         final_input = None
                         final_input_supported = False
         except Exception as e:
-            run_manager.on_chain_error(
-                e,
-                inputs=final_input
-                if isinstance(final_input, dict)
-                else {"input": final_input},
-            )
+            run_manager.on_chain_error(e, inputs=final_input)
             raise
         else:
-            run_manager.on_chain_end(
-                final_output
-                if isinstance(final_output, dict)
-                else {"output": final_output},
-                inputs=final_input
-                if isinstance(final_input, dict)
-                else {"input": final_input},
-            )
+            run_manager.on_chain_end(final_output, inputs=final_input)
 
     async def _atransform_stream_with_config(
         self,
@@ -525,22 +503,10 @@ class Runnable(Generic[Input, Output], ABC):
                         final_input = None
                         final_input_supported = False
         except Exception as e:
-            await run_manager.on_chain_error(
-                e,
-                inputs=final_input
-                if isinstance(final_input, dict)
-                else {"input": final_input},
-            )
+            await run_manager.on_chain_error(e, inputs=final_input)
             raise
         else:
-            await run_manager.on_chain_end(
-                final_output
-                if isinstance(final_output, dict)
-                else {"output": final_output},
-                inputs=final_input
-                if isinstance(final_input, dict)
-                else {"input": final_input},
-            )
+            await run_manager.on_chain_end(final_output, inputs=final_input)
 
 
 class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
@@ -583,9 +549,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = callback_manager.on_chain_start(dumpd(self), input)
         first_error = None
         for runnable in self.runnables:
             try:
@@ -600,9 +564,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                 run_manager.on_chain_error(e)
                 raise e
             else:
-                run_manager.on_chain_end(
-                    output if isinstance(output, dict) else {"output": output}
-                )
+                run_manager.on_chain_end(output)
                 return output
         if first_error is None:
             raise ValueError("No error stored at end of fallbacks.")
@@ -629,9 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = await callback_manager.on_chain_start(dumpd(self), input)
 
         first_error = None
         for runnable in self.runnables:
@@ -647,9 +607,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                 await run_manager.on_chain_error(e)
                 raise e
             else:
-                await run_manager.on_chain_end(
-                    output if isinstance(output, dict) else {"output": output}
-                )
+                await run_manager.on_chain_end(output)
                 return output
         if first_error is None:
             raise ValueError("No error stored at end of fallbacks.")
@@ -709,9 +667,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
                 raise e
             else:
                 for rm, output in zip(run_managers, outputs):
-                    rm.on_chain_end(
-                        output if isinstance(output, dict) else {"output": output}
-                    )
+                    rm.on_chain_end(output)
                 return outputs
         if first_error is None:
             raise ValueError("No error stored at end of fallbacks.")
@@ -749,9 +705,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         # start the root runs, one per input
         run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
             *(
-                cm.on_chain_start(
-                    dumpd(self), input if isinstance(input, dict) else {"input": input}
-                )
+                cm.on_chain_start(dumpd(self), input)
                 for cm, input in zip(callback_managers, inputs)
             )
         )
@@ -776,9 +730,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         else:
             await asyncio.gather(
                 *(
-                    rm.on_chain_end(
-                        output if isinstance(output, dict) else {"output": output}
-                    )
+                    rm.on_chain_end(output)
                     for rm, output in zip(run_managers, outputs)
                 )
             )
@@ -870,9 +822,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = callback_manager.on_chain_start(dumpd(self), input)
 
         # invoke all steps in sequence
         try:
@@ -887,9 +837,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             run_manager.on_chain_error(e)
             raise
         else:
-            run_manager.on_chain_end(
-                input if isinstance(input, dict) else {"output": input}
-            )
+            run_manager.on_chain_end(input)
             return cast(Output, input)
 
     async def ainvoke(
@@ -912,9 +860,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = await callback_manager.on_chain_start(dumpd(self), input)
 
         # invoke all steps in sequence
         try:
@@ -929,9 +875,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             await run_manager.on_chain_error(e)
             raise
         else:
-            await run_manager.on_chain_end(
-                input if isinstance(input, dict) else {"output": input}
-            )
+            await run_manager.on_chain_end(input)
             return cast(Output, input)
 
     def batch(
@@ -960,9 +904,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         ]
         # start the root runs, one per input
         run_managers = [
-            cm.on_chain_start(
-                dumpd(self), input if isinstance(input, dict) else {"input": input}
-            )
+            cm.on_chain_start(dumpd(self), input)
             for cm, input in zip(callback_managers, inputs)
         ]
 
@@ -985,7 +927,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             raise
         else:
             for rm, input in zip(run_managers, inputs):
-                rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
+                rm.on_chain_end(input)
             return cast(List[Output], inputs)
 
     async def abatch(
@@ -1017,9 +959,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         # start the root runs, one per input
         run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
             *(
-                cm.on_chain_start(
-                    dumpd(self), input if isinstance(input, dict) else {"input": input}
-                )
+                cm.on_chain_start(dumpd(self), input)
                 for cm, input in zip(callback_managers, inputs)
             )
         )
@@ -1043,12 +983,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             raise
         else:
             await asyncio.gather(
-                *(
-                    rm.on_chain_end(
-                        input if isinstance(input, dict) else {"output": input}
-                    )
-                    for rm, input in zip(run_managers, inputs)
-                )
+                *(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
             )
             return cast(List[Output], inputs)
 
@@ -1072,9 +1007,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = callback_manager.on_chain_start(dumpd(self), input)
 
         steps = [self.first] + self.middle + [self.last]
         streaming_start_index = 0
@@ -1128,9 +1061,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             run_manager.on_chain_error(e)
             raise
         else:
-            run_manager.on_chain_end(
-                final if isinstance(final, dict) else {"output": final}
-            )
+            run_manager.on_chain_end(final)
 
     async def astream(
         self,
@@ -1152,9 +1083,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = await callback_manager.on_chain_start(dumpd(self), input)
 
         steps = [self.first] + self.middle + [self.last]
         streaming_start_index = len(steps) - 1
@@ -1208,9 +1137,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             await run_manager.on_chain_error(e)
             raise
         else:
-            await run_manager.on_chain_end(
-                final if isinstance(final, dict) else {"output": final}
-            )
+            await run_manager.on_chain_end(final)
 
 
 class RunnableMapChunk(Dict[str, Any]):
@@ -1277,9 +1204,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = callback_manager.on_chain_start(dumpd(self), input)
 
         # gather results from all steps
         try:
@@ -1324,9 +1249,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             local_metadata=None,
        )
         # start the root run
-        run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
-        )
+        run_manager = await callback_manager.on_chain_start(dumpd(self), input)
 
         # gather results from all steps
         try:
```
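The net effect of the `Runnable`, `RunnableWithFallbacks`, `RunnableSequence`, and `RunnableMap` hunks is deduplication: every call site used to inline `x if isinstance(x, dict) else {"input": x}` (or the `"output"` variant) before handing values to the callback manager; now the raw value is passed through and the tracer applies the wrapping once (see the `BaseTracer` hunks above). A self-contained sketch of the before/after call-site shape, using a hypothetical stand-in for the run manager rather than the real class:

```python
from typing import Any


class RunManagerSketch:
    """Hypothetical stand-in for a chain run manager after this PR."""

    def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
        # Normalization now lives tracer-side, not at every call site.
        record = outputs if isinstance(outputs, dict) else {"output": outputs}
        print("recorded:", record)


run_manager = RunManagerSketch()
output = "hello"

# Before this PR, call sites repeated the wrapping expression themselves:
run_manager.on_chain_end(output if isinstance(output, dict) else {"output": output})

# After this PR, call sites simply pass the value along:
run_manager.on_chain_end(output)
```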
File diff suppressed because one or more lines are too long
```diff
@@ -21,7 +21,12 @@ from langchain.prompts.chat import (
     SystemMessagePromptTemplate,
 )
 from langchain.schema.document import Document
-from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
+from langchain.schema.messages import (
+    AIMessage,
+    AIMessageChunk,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.runnable import (
@@ -794,7 +799,7 @@ def test_map_stream() -> None:
     assert streamed_chunks[0] in [
         {"passthrough": prompt.invoke({"question": "What is your name?"})},
         {"llm": "i"},
-        {"chat": "i"},
+        {"chat": AIMessageChunk(content="i")},
     ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
     assert all(len(c.keys()) == 1 for c in streamed_chunks)
@@ -841,7 +846,11 @@ def test_map_stream_iterator_input() -> None:
         else:
             final_value += chunk
 
-    assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
+    assert streamed_chunks[0] in [
+        {"passthrough": "i"},
+        {"llm": "i"},
+        {"chat": AIMessageChunk(content="i")},
+    ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
     assert all(len(c.keys()) == 1 for c in streamed_chunks)
     assert final_value is not None
@@ -885,7 +894,7 @@ async def test_map_astream() -> None:
     assert streamed_chunks[0] in [
         {"passthrough": prompt.invoke({"question": "What is your name?"})},
         {"llm": "i"},
-        {"chat": "i"},
+        {"chat": AIMessageChunk(content="i")},
     ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
     assert all(len(c.keys()) == 1 for c in streamed_chunks)
@@ -933,7 +942,11 @@ async def test_map_astream_iterator_input() -> None:
         else:
             final_value += chunk
 
-    assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
+    assert streamed_chunks[0] in [
+        {"passthrough": "i"},
+        {"llm": "i"},
+        {"chat": AIMessageChunk(content="i")},
+    ]
     assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
     assert all(len(c.keys()) == 1 for c in streamed_chunks)
     assert final_value is not None
```
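The test updates follow from the same change: when a chat model streams through a `RunnableMap`, the chunks reaching the consumer are `AIMessageChunk` objects rather than bare strings, so the expected first item is now `{"chat": AIMessageChunk(content="i")}`. A quick illustration of the chunk type, assuming a langchain version that provides `langchain.schema.messages` as imported in the test file:

```python
from langchain.schema.messages import AIMessageChunk

# Streaming chat models emit message chunks, not plain strings.
chunk = AIMessageChunk(content="i")
assert chunk == AIMessageChunk(content="i")

# Chunks support concatenation, which is how a streamed message is reassembled.
assert (chunk + AIMessageChunk(content=" am")).content == "i am"
```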