Enable run mutation (#31090)

This lets you more easily modify a run in-flight
This commit is contained in:
William FH 2025-05-01 17:00:51 -07:00 committed by GitHub
parent 0b79fc1733
commit 167afa5102
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 12 deletions

View File

@ -278,7 +278,12 @@ class _TracerCore(ABC):
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.outputs = response.model_dump()
if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
else:
llm_run.outputs = cast("dict[str, Any]", llm_run.outputs)
if not llm_run.extra.get("__omit_auto_outputs", False):
llm_run.outputs.update(response.model_dump())
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
@ -297,7 +302,12 @@ class _TracerCore(ABC):
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.error = self._get_stacktrace(error)
if response:
llm_run.outputs = response.model_dump()
if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
else:
llm_run.outputs = cast("dict[str, Any]", llm_run.outputs)
if not llm_run.extra.get("__omit_auto_outputs", False):
llm_run.outputs.update(response.model_dump())
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
@ -370,7 +380,12 @@ class _TracerCore(ABC):
) -> Run:
"""Update a chain run with outputs and end time."""
chain_run = self._get_run(run_id)
chain_run.outputs = self._get_chain_outputs(outputs)
if getattr(chain_run, "outputs", None) is None:
chain_run.outputs = {}
if not chain_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", chain_run.outputs).update(
self._get_chain_outputs(outputs)
)
chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
@ -438,7 +453,10 @@ class _TracerCore(ABC):
) -> Run:
"""Update a tool run with outputs and end time."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
if getattr(tool_run, "outputs", None) is None:
tool_run.outputs = {}
if not tool_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", tool_run.outputs).update({"output": output})
tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "end", "time": tool_run.end_time})
return tool_run
@ -491,7 +509,12 @@ class _TracerCore(ABC):
) -> Run:
"""Update a retrieval run with outputs and end time."""
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents}
if getattr(retrieval_run, "outputs", None) is None:
retrieval_run.outputs = {}
if not retrieval_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", retrieval_run.outputs).update(
{"documents": documents}
)
retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
return retrieval_run

View File

@ -1,10 +1,9 @@
version = 1
requires-python = ">=3.9, <4.0"
requires-python = ">=3.9"
resolution-markers = [
"python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
"python_full_version >= '3.10' and python_full_version < '3.12.4'",
"python_full_version < '3.10'",
]
@ -587,8 +586,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
"python_full_version >= '3.10' and python_full_version < '3.12.4'",
]
dependencies = [
{ name = "colorama", marker = "python_full_version >= '3.10' and sys_platform == 'win32'" },
@ -1447,8 +1445,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
"python_full_version >= '3.10' and python_full_version < '3.12.4'",
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/78/31103410a57bc2c2b93a3597340a8119588571f6a4539067546cb9a0bfac/numpy-2.2.4.tar.gz", hash = "sha256:9ba03692a45d3eef66559efe1d1096c4b9b75c0986b5dff5530c378fb8331d4f", size = 20270701 }
wheels = [