Nc/small fixes 21aug (#9542)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-08-21 18:01:20 +01:00 committed by GitHub
parent a7eba8b006
commit 28e1ee4891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 130 deletions

View File

@ -715,11 +715,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
_handle_event(
self.handlers,
@ -797,11 +797,13 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
await _ahandle_event(
self.handlers,
@ -1144,7 +1146,7 @@ class CallbackManager(BaseCallbackManager):
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
inputs: Union[Dict[str, Any], Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForChainRun:
@ -1152,7 +1154,7 @@ class CallbackManager(BaseCallbackManager):
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:
@ -1433,7 +1435,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
inputs: Union[Dict[str, Any], Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
@ -1441,7 +1443,7 @@ class AsyncCallbackManager(BaseCallbackManager):
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
Returns:

View File

@ -231,6 +231,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
@ -243,7 +244,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs=inputs,
inputs=inputs if isinstance(inputs, dict) else {"input": inputs},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
@ -251,6 +252,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
child_runs=[],
run_type=run_type or "chain",
name=name,
tags=tags or [],
)
self._start_trace(chain_run)
@ -271,11 +273,13 @@ class BaseTracer(BaseCallbackHandler, ABC):
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run.outputs = outputs
chain_run.outputs = (
outputs if isinstance(outputs, dict) else {"output": outputs}
)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_end(chain_run)
@ -298,7 +302,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_error(chain_run)

View File

@ -6,6 +6,7 @@ from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Sequence,
Set,
@ -298,6 +299,15 @@ class ChatPromptValue(PromptValue):
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
@property
def lc_attributes(self) -> Dict:
"""
Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the
constructor.
"""
return {"input_variables": self.input_variables}
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
@ -419,7 +429,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
f"Got: {values['input_variables']}"
)
else:
values["input_variables"] = list(input_vars)
values["input_variables"] = sorted(input_vars)
return values
@classmethod

View File

@ -266,7 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
input,
run_type=run_type,
)
try:
@ -284,12 +284,7 @@ class Runnable(Generic[Input, Output], ABC):
run_manager.on_chain_error(e)
raise
else:
output_for_tracer = dumpd(output)
run_manager.on_chain_end(
output_for_tracer
if isinstance(output_for_tracer, dict)
else {"output": output_for_tracer}
)
run_manager.on_chain_end(dumpd(output))
return output
async def _acall_with_config(
@ -318,7 +313,7 @@ class Runnable(Generic[Input, Output], ABC):
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
input,
run_type=run_type,
)
try:
@ -339,12 +334,7 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_error(e)
raise
else:
output_for_tracer = dumpd(output)
await run_manager.on_chain_end(
output_for_tracer
if isinstance(output_for_tracer, dict)
else {"output": output_for_tracer}
)
await run_manager.on_chain_end(dumpd(output))
return output
def _transform_stream_with_config(
@ -425,22 +415,10 @@ class Runnable(Generic[Input, Output], ABC):
final_input = None
final_input_supported = False
except Exception as e:
run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
run_manager.on_chain_error(e, inputs=final_input)
raise
else:
run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
run_manager.on_chain_end(final_output, inputs=final_input)
async def _atransform_stream_with_config(
self,
@ -525,22 +503,10 @@ class Runnable(Generic[Input, Output], ABC):
final_input = None
final_input_supported = False
except Exception as e:
await run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
await run_manager.on_chain_error(e, inputs=final_input)
raise
else:
await run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
await run_manager.on_chain_end(final_output, inputs=final_input)
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
@ -583,9 +549,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = callback_manager.on_chain_start(dumpd(self), input)
first_error = None
for runnable in self.runnables:
try:
@ -600,9 +564,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
@ -629,9 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
first_error = None
for runnable in self.runnables:
@ -647,9 +607,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
await run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
@ -709,9 +667,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
@ -749,9 +705,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
)
)
@ -776,9 +730,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
else:
await asyncio.gather(
*(
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
@ -870,9 +822,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = callback_manager.on_chain_start(dumpd(self), input)
# invoke all steps in sequence
try:
@ -887,9 +837,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
run_manager.on_chain_end(input)
return cast(Output, input)
async def ainvoke(
@ -912,9 +860,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
# invoke all steps in sequence
try:
@ -929,9 +875,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
await run_manager.on_chain_end(input)
return cast(Output, input)
def batch(
@ -960,9 +904,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
]
@ -985,7 +927,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
raise
else:
for rm, input in zip(run_managers, inputs):
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
rm.on_chain_end(input)
return cast(List[Output], inputs)
async def abatch(
@ -1017,9 +959,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
)
)
@ -1043,12 +983,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
raise
else:
await asyncio.gather(
*(
rm.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
for rm, input in zip(run_managers, inputs)
)
*(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
)
return cast(List[Output], inputs)
@ -1072,9 +1007,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = callback_manager.on_chain_start(dumpd(self), input)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0
@ -1128,9 +1061,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
run_manager.on_chain_end(final)
async def astream(
self,
@ -1152,9 +1083,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
@ -1208,9 +1137,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
await run_manager.on_chain_end(final)
class RunnableMapChunk(Dict[str, Any]):
@ -1277,9 +1204,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = callback_manager.on_chain_start(dumpd(self), input)
# gather results from all steps
try:
@ -1324,9 +1249,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
# gather results from all steps
try:

File diff suppressed because one or more lines are too long

View File

@ -21,7 +21,12 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
@ -794,7 +799,7 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": "i"},
{"chat": AIMessageChunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -841,7 +846,11 @@ def test_map_stream_iterator_input() -> None:
else:
final_value += chunk
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
assert streamed_chunks[0] in [
{"passthrough": "i"},
{"llm": "i"},
{"chat": AIMessageChunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
@ -885,7 +894,7 @@ async def test_map_astream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": "i"},
{"chat": AIMessageChunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
@ -933,7 +942,11 @@ async def test_map_astream_iterator_input() -> None:
else:
final_value += chunk
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
assert streamed_chunks[0] in [
{"passthrough": "i"},
{"llm": "i"},
{"chat": AIMessageChunk(content="i")},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None