diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 15f561127c1..a6cb1c2fee7 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -211,9 +211,14 @@ class ShellSession: with self._lock: self._drain_queue() payload = command if command.endswith("\n") else f"{command}\n" - self._stdin.write(payload) - self._stdin.write(f"printf '{marker} %s\\n' $?\n") - self._stdin.flush() + try: + self._stdin.write(payload) + self._stdin.write(f"printf '{marker} %s\\n' $?\n") + self._stdin.flush() + except (BrokenPipeError, OSError): + # The shell exited before we could write the marker command. + # This happens when commands like 'exit 1' terminate the shell. + return self._collect_output_after_exit(deadline) return self._collect_output(marker, deadline, timeout) @@ -304,6 +309,80 @@ class ShellSession: total_bytes=total_bytes, ) + def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult: + """Collect output after the shell exited unexpectedly. + + Called when a `BrokenPipeError` occurs while writing to stdin, indicating the + shell process terminated (e.g., due to an 'exit' command). + + Args: + deadline: Absolute time by which collection must complete. + + Returns: + `CommandExecutionResult` with collected output and the process exit code. + """ + collected: list[str] = [] + total_lines = 0 + total_bytes = 0 + truncated_by_lines = False + truncated_by_bytes = False + + # Give reader threads a brief moment to enqueue any remaining output. + drain_timeout = 0.1 + drain_deadline = min(time.monotonic() + drain_timeout, deadline) + + while True: + remaining = drain_deadline - time.monotonic() + if remaining <= 0: + break + try: + source, data = self._queue.get(timeout=remaining) + except queue.Empty: + break + + if data is None: + # EOF marker from a reader thread; continue draining. + continue + + total_lines += 1 + encoded = data.encode("utf-8", "replace") + total_bytes += len(encoded) + + if total_lines > self._policy.max_output_lines: + truncated_by_lines = True + continue + + if ( + self._policy.max_output_bytes is not None + and total_bytes > self._policy.max_output_bytes + ): + truncated_by_bytes = True + continue + + if source == "stderr": + stripped = data.rstrip("\n") + collected.append(f"[stderr] {stripped}") + if data.endswith("\n"): + collected.append("\n") + else: + collected.append(data) + + # Get exit code from the terminated process. + exit_code: int | None = None + if self._process: + exit_code = self._process.poll() + + output = "".join(collected) + return CommandExecutionResult( + output=output, + exit_code=exit_code, + timed_out=False, + truncated_by_lines=truncated_by_lines, + truncated_by_bytes=truncated_by_bytes, + total_lines=total_lines, + total_bytes=total_bytes, + ) + def _kill_process(self) -> None: if not self._process: return