fix(langchain): resolve race condition in ShellSession.execute() (#34535)

Addresses a flaky test

When executing `exit 1` as a startup command, the shell process
terminates immediately. The code then tries to write a marker command
(`printf '...'`) to stdin, but the pipe is already broken because the
shell has exited, causing `BrokenPipeError`.
This commit is contained in:
Mason Daugherty
2025-12-29 18:16:08 -06:00
committed by GitHub
parent dcfd9c0e04
commit 721bf15430

View File

@@ -211,9 +211,14 @@ class ShellSession:
with self._lock:
self._drain_queue()
payload = command if command.endswith("\n") else f"{command}\n"
self._stdin.write(payload)
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
self._stdin.flush()
try:
self._stdin.write(payload)
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
self._stdin.flush()
except (BrokenPipeError, OSError):
# The shell exited before we could write the marker command.
# This happens when commands like 'exit 1' terminate the shell.
return self._collect_output_after_exit(deadline)
return self._collect_output(marker, deadline, timeout)
@@ -304,6 +309,80 @@ class ShellSession:
total_bytes=total_bytes,
)
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
"""Collect output after the shell exited unexpectedly.
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
shell process terminated (e.g., due to an 'exit' command).
Args:
deadline: Absolute time by which collection must complete.
Returns:
`CommandExecutionResult` with collected output and the process exit code.
"""
collected: list[str] = []
total_lines = 0
total_bytes = 0
truncated_by_lines = False
truncated_by_bytes = False
# Give reader threads a brief moment to enqueue any remaining output.
drain_timeout = 0.1
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
while True:
remaining = drain_deadline - time.monotonic()
if remaining <= 0:
break
try:
source, data = self._queue.get(timeout=remaining)
except queue.Empty:
break
if data is None:
# EOF marker from a reader thread; continue draining.
continue
total_lines += 1
encoded = data.encode("utf-8", "replace")
total_bytes += len(encoded)
if total_lines > self._policy.max_output_lines:
truncated_by_lines = True
continue
if (
self._policy.max_output_bytes is not None
and total_bytes > self._policy.max_output_bytes
):
truncated_by_bytes = True
continue
if source == "stderr":
stripped = data.rstrip("\n")
collected.append(f"[stderr] {stripped}")
if data.endswith("\n"):
collected.append("\n")
else:
collected.append(data)
# Get exit code from the terminated process.
exit_code: int | None = None
if self._process:
exit_code = self._process.poll()
output = "".join(collected)
return CommandExecutionResult(
output=output,
exit_code=exit_code,
timed_out=False,
truncated_by_lines=truncated_by_lines,
truncated_by_bytes=truncated_by_bytes,
total_lines=total_lines,
total_bytes=total_bytes,
)
def _kill_process(self) -> None:
if not self._process:
return