Enable run mutation (#31090)

This lets you more easily modify a run in-flight
This commit is contained in:
William FH 2025-05-01 17:00:51 -07:00 committed by GitHub
parent 0b79fc1733
commit 167afa5102
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 12 deletions

View File

@ -278,7 +278,12 @@ class _TracerCore(ABC):
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run: def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.outputs = response.model_dump() if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
else:
llm_run.outputs = cast("dict[str, Any]", llm_run.outputs)
if not llm_run.extra.get("__omit_auto_outputs", False):
llm_run.outputs.update(response.model_dump())
for i, generations in enumerate(response.generations): for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations): for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j] output_generation = llm_run.outputs["generations"][i][j]
@ -297,7 +302,12 @@ class _TracerCore(ABC):
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.error = self._get_stacktrace(error) llm_run.error = self._get_stacktrace(error)
if response: if response:
llm_run.outputs = response.model_dump() if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
else:
llm_run.outputs = cast("dict[str, Any]", llm_run.outputs)
if not llm_run.extra.get("__omit_auto_outputs", False):
llm_run.outputs.update(response.model_dump())
for i, generations in enumerate(response.generations): for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations): for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j] output_generation = llm_run.outputs["generations"][i][j]
@ -370,7 +380,12 @@ class _TracerCore(ABC):
) -> Run: ) -> Run:
"""Update a chain run with outputs and end time.""" """Update a chain run with outputs and end time."""
chain_run = self._get_run(run_id) chain_run = self._get_run(run_id)
chain_run.outputs = self._get_chain_outputs(outputs) if getattr(chain_run, "outputs", None) is None:
chain_run.outputs = {}
if not chain_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", chain_run.outputs).update(
self._get_chain_outputs(outputs)
)
chain_run.end_time = datetime.now(timezone.utc) chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "end", "time": chain_run.end_time}) chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None: if inputs is not None:
@ -438,7 +453,10 @@ class _TracerCore(ABC):
) -> Run: ) -> Run:
"""Update a tool run with outputs and end time.""" """Update a tool run with outputs and end time."""
tool_run = self._get_run(run_id, run_type="tool") tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output} if getattr(tool_run, "outputs", None) is None:
tool_run.outputs = {}
if not tool_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", tool_run.outputs).update({"output": output})
tool_run.end_time = datetime.now(timezone.utc) tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "end", "time": tool_run.end_time}) tool_run.events.append({"name": "end", "time": tool_run.end_time})
return tool_run return tool_run
@ -491,7 +509,12 @@ class _TracerCore(ABC):
) -> Run: ) -> Run:
"""Update a retrieval run with outputs and end time.""" """Update a retrieval run with outputs and end time."""
retrieval_run = self._get_run(run_id, run_type="retriever") retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents} if getattr(retrieval_run, "outputs", None) is None:
retrieval_run.outputs = {}
if not retrieval_run.extra.get("__omit_auto_outputs", False):
cast("dict[str, Any]", retrieval_run.outputs).update(
{"documents": documents}
)
retrieval_run.end_time = datetime.now(timezone.utc) retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
return retrieval_run return retrieval_run

View File

@ -1,10 +1,9 @@
version = 1 version = 1
requires-python = ">=3.9, <4.0" requires-python = ">=3.9"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13'", "python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'", "python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'", "python_full_version >= '3.10' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
"python_full_version < '3.10'", "python_full_version < '3.10'",
] ]
@ -587,8 +586,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13'", "python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'", "python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'", "python_full_version >= '3.10' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
] ]
dependencies = [ dependencies = [
{ name = "colorama", marker = "python_full_version >= '3.10' and sys_platform == 'win32'" }, { name = "colorama", marker = "python_full_version >= '3.10' and sys_platform == 'win32'" },
@ -1447,8 +1445,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13'", "python_full_version >= '3.13'",
"python_full_version >= '3.12.4' and python_full_version < '3.13'", "python_full_version >= '3.12.4' and python_full_version < '3.13'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'", "python_full_version >= '3.10' and python_full_version < '3.12.4'",
"python_full_version >= '3.10' and python_full_version < '3.12'",
] ]
sdist = { url = "https://files.pythonhosted.org/packages/e1/78/31103410a57bc2c2b93a3597340a8119588571f6a4539067546cb9a0bfac/numpy-2.2.4.tar.gz", hash = "sha256:9ba03692a45d3eef66559efe1d1096c4b9b75c0986b5dff5530c378fb8331d4f", size = 20270701 } sdist = { url = "https://files.pythonhosted.org/packages/e1/78/31103410a57bc2c2b93a3597340a8119588571f6a4539067546cb9a0bfac/numpy-2.2.4.tar.gz", hash = "sha256:9ba03692a45d3eef66559efe1d1096c4b9b75c0986b5dff5530c378fb8331d4f", size = 20270701 }
wheels = [ wheels = [