mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-10 15:06:18 +00:00
Accept run name arg for non-chain runs (#10935)

parent: aac2d4dcef
commit: 3d5e92e3ef
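This change threads an optional run name through the callback system for non-chain runs: LLMs, chat models, retrievers, and tools can now be invoked with a run_name (via the runnable config or as a keyword argument), and tracers stamp that name onto the resulting Run instead of the class default. A minimal usage sketch, assuming the post-commit API; FakeListLLM is the built-in canned-response test model and the names here are illustrative:

# Minimal sketch of the new run_name argument (post-commit API).
# FakeListLLM is langchain's test LLM; any LLM should work the same way.
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["hello", "hello again"])

# Through the runnable interface, run_name rides in the config dict.
llm.invoke("hi", config={"run_name": "greeting-llm"})

# Call sites that use generate() directly pass it as a keyword.
llm.generate(["hi"], run_name="greeting-llm")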
@@ -102,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run."""
@@ -122,6 +123,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_execution_order=execution_order,
             run_type="llm",
             tags=tags or [],
+            name=name,
         )
         self._start_trace(llm_run)
         self._on_llm_start(llm_run)
@@ -335,6 +337,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for a tool run."""
@@ -356,6 +359,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_runs=[],
             run_type="tool",
             tags=tags or [],
+            name=name,
         )
         self._start_trace(tool_run)
         self._on_tool_start(tool_run)
@@ -406,6 +410,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> Run:
         """Run when Retriever starts running."""
@@ -416,7 +421,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         kwargs.update({"metadata": metadata})
         retrieval_run = Run(
             id=run_id,
-            name="Retriever",
+            name=name or "Retriever",
             parent_run_id=parent_run_id,
             serialized=serialized,
             inputs={"query": query},
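The BaseTracer changes above are mechanical: each start hook now accepts name and forwards it into the Run it constructs, and the retriever run only falls back to the literal "Retriever" when no name was supplied. A sketch of observing the propagated name with a custom tracer; the NameCollector subclass and its names list are illustrative, not part of the commit:

# Illustrative sketch: BaseTracer and Run are real classes; the
# NameCollector subclass is hypothetical, for demonstration only.
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run


class NameCollector(BaseTracer):
    """Records the name of every run it persists."""

    def __init__(self) -> None:
        super().__init__()
        self.names: list = []

    def _persist_run(self, run: Run) -> None:
        # run.name now reflects a caller-supplied run_name for
        # LLM, tool, and retriever runs, not only chain runs.
        self.names.append(run.name)

Passing callbacks=[NameCollector()] alongside run_name=... to any of the calls changed below would record the custom name instead of the default.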
@@ -98,6 +98,7 @@ class LangChainTracer(BaseTracer):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Start a trace for an LLM run."""
@@ -118,6 +119,7 @@ class LangChainTracer(BaseTracer):
             child_execution_order=execution_order,
             run_type="llm",
             tags=tags,
+            name=name,
         )
         self._start_trace(chat_model_run)
         self._on_chat_model_start(chat_model_run)
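LangChainTracer gets the matching change for chat models: its on_chat_model_start hook accepts name and writes it onto the run it uploads, so the label shows up in LangSmith. A sketch, assuming a configured LangSmith environment; the import paths reflect this era of the codebase and FakeListChatModel is the built-in test chat model:

# Sketch: a named chat-model run sent through LangChainTracer.
# Assumes LANGCHAIN_API_KEY (and related env vars) are configured.
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chat_models.fake import FakeListChatModel

tracer = LangChainTracer()
chat = FakeListChatModel(responses=["hi there"])
chat.invoke(
    "hello",
    config={"callbacks": [tracer], "run_name": "named-chat-run"},
)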
@@ -139,6 +139,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                     callbacks=config.get("callbacks"),
                     tags=config.get("tags"),
                     metadata=config.get("metadata"),
+                    run_name=config.get("run_name"),
                     **kwargs,
                 ).generations[0][0],
             ).message,
@@ -165,6 +166,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 callbacks=config.get("callbacks"),
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
                 **kwargs,
             )
             return cast(
@@ -197,7 +199,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
-            dumpd(self), [messages], invocation_params=params, options=options
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -244,7 +250,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
-            dumpd(self), [messages], invocation_params=params, options=options
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -298,6 +308,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -314,7 +325,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
+            dumpd(self),
+            messages,
+            invocation_params=params,
+            options=options,
+            name=run_name,
         )
         results = []
         for i, m in enumerate(messages):
@@ -354,6 +369,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -371,7 +387,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         )

         run_managers = await callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
+            dumpd(self),
+            messages,
+            invocation_params=params,
+            options=options,
+            name=run_name,
         )

         results = await asyncio.gather(
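Within BaseChatModel, invoke/ainvoke read run_name from the runnable config, generate/agenerate accept it as a keyword, and both the streaming and batch paths forward it as name= to on_chat_model_start. A short sketch, again using the built-in fake chat model; the names are illustrative:

# Sketch of the two chat-model entry points that now take a run name.
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema.messages import HumanMessage

chat = FakeListChatModel(responses=["pong", "pong"])

# Runnable interface: the name travels inside the config.
chat.invoke("ping", config={"run_name": "ping-check"})

# generate(): one name applied to the whole batch of message lists.
chat.generate([[HumanMessage(content="ping")]], run_name="ping-check")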
@@ -228,6 +228,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     callbacks=config.get("callbacks"),
                     tags=config.get("tags"),
                     metadata=config.get("metadata"),
+                    run_name=config.get("run_name"),
                     **kwargs,
                 )
                 .generations[0][0]
@@ -255,6 +256,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=config.get("callbacks"),
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
                 **kwargs,
             )
             return llm_result.generations[0][0].text
@@ -280,6 +282,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                run_name=[c.get("run_name") for c in config],
                 **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
@@ -328,6 +331,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                run_name=[c.get("run_name") for c in config],
                 **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
@@ -375,7 +379,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_llm_start(
-            dumpd(self), [prompt], invocation_params=params, options=options
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -422,7 +430,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_llm_start(
-            dumpd(self), [prompt], invocation_params=params, options=options
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -544,6 +556,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         *,
         tags: Optional[Union[List[str], List[List[str]]]] = None,
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        run_name: Optional[Union[str, List[str]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -569,11 +582,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             assert metadata is None or (
                 isinstance(metadata, list) and len(metadata) == len(prompts)
             )
+            assert run_name is None or (
+                isinstance(run_name, list) and len(run_name) == len(prompts)
+            )
             callbacks = cast(List[Callbacks], callbacks)
             tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
             metadata_list = cast(
                 List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
             )
+            run_name_list = run_name or cast(
+                List[Optional[str]], ([None] * len(prompts))
+            )
             callback_managers = [
                 CallbackManager.configure(
                     callback,
@@ -599,6 +618,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
+            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

         params = self.dict()
         params["stop"] = stop
@@ -620,9 +640,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
             run_managers = [
                 callback_manager.on_llm_start(
-                    dumpd(self), [prompt], invocation_params=params, options=options
+                    dumpd(self),
+                    [prompt],
+                    invocation_params=params,
+                    options=options,
+                    name=run_name,
                 )[0]
-                for callback_manager, prompt in zip(callback_managers, prompts)
+                for callback_manager, prompt, run_name in zip(
+                    callback_managers, prompts, run_name_list
+                )
             ]
             output = self._generate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
@@ -635,6 +661,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
+                    name=run_name_list[idx],
                 )[0]
                 for idx in missing_prompt_idxs
             ]
@@ -702,6 +729,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         *,
         tags: Optional[Union[List[str], List[List[str]]]] = None,
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        run_name: Optional[Union[str, List[str]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -718,11 +746,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             assert metadata is None or (
                 isinstance(metadata, list) and len(metadata) == len(prompts)
             )
+            assert run_name is None or (
+                isinstance(run_name, list) and len(run_name) == len(prompts)
+            )
             callbacks = cast(List[Callbacks], callbacks)
             tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
             metadata_list = cast(
                 List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
             )
+            run_name_list = run_name or cast(
+                List[Optional[str]], ([None] * len(prompts))
+            )
             callback_managers = [
                 AsyncCallbackManager.configure(
                     callback,
@@ -748,6 +782,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     self.metadata,
                 )
             ] * len(prompts)
+            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

         params = self.dict()
         params["stop"] = stop
@@ -770,9 +805,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         run_managers = await asyncio.gather(
             *[
                 callback_manager.on_llm_start(
-                    dumpd(self), [prompt], invocation_params=params, options=options
+                    dumpd(self),
+                    [prompt],
+                    invocation_params=params,
+                    options=options,
+                    name=run_name,
                 )
-                for callback_manager, prompt in zip(callback_managers, prompts)
+                for callback_manager, prompt, run_name in zip(
+                    callback_managers, prompts, run_name_list
+                )
             ]
         )
         run_managers = [r[0] for r in run_managers]
@@ -788,6 +829,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
+                    name=run_name_list[idx],
                 )
                 for idx in missing_prompt_idxs
             ]
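BaseLLM mirrors the chat-model change, with one extra wrinkle: since generate/agenerate take a batch of prompts, run_name may be a single string (fanned out to every prompt) or, when per-prompt callbacks are used, a list with exactly one name per prompt; the new asserts enforce the length match. A sketch of the single-string form, using the built-in fake LLM:

# Sketch: one run_name fanned out across a batch of prompts.
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["a", "b"])

# Each prompt gets its own Run, all labeled "batch-run"; with
# per-prompt callbacks, a list like run_name=["first", "second"]
# is also accepted, one name per prompt.
llm.generate(["first prompt", "second prompt"], run_name="batch-run")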
@@ -113,6 +113,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
         )

     async def ainvoke(
@@ -131,6 +132,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
         )

     @abstractmethod
@@ -164,6 +166,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         callbacks: Callbacks = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Retrieve documents relevant to a query.
@@ -193,6 +196,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         run_manager = callback_manager.on_retriever_start(
             dumpd(self),
             query,
+            name=run_name,
             **kwargs,
         )
         try:
@@ -220,6 +224,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         callbacks: Callbacks = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Asynchronously get documents relevant to a query.
@@ -249,6 +254,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         run_manager = await callback_manager.on_retriever_start(
             dumpd(self),
             query,
+            name=run_name,
             **kwargs,
         )
         try:
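For retrievers, get_relevant_documents/aget_relevant_documents gain run_name and hand it to on_retriever_start as name=, which is what the BaseTracer hunk above turns into name or "Retriever". A sketch with a trivial custom retriever; the EchoRetriever class is hypothetical, only the run_name plumbing is from this commit:

# Illustrative sketch: EchoRetriever is a made-up minimal retriever.
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class EchoRetriever(BaseRetriever):
    """Returns the query itself as a single document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=query)]


retriever = EchoRetriever()
retriever.get_relevant_documents("what is a run name?", run_name="echo-run")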
@@ -199,6 +199,7 @@ class ChildTool(BaseTool):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )

@@ -218,6 +219,7 @@ class ChildTool(BaseTool):
             callbacks=config.get("callbacks"),
             tags=config.get("tags"),
             metadata=config.get("metadata"),
+            run_name=config.get("run_name"),
             **kwargs,
         )

@@ -297,6 +299,7 @@ class ChildTool(BaseTool):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool."""
@@ -320,6 +323,7 @@ class ChildTool(BaseTool):
             {"name": self.name, "description": self.description},
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
+            name=run_name,
             **kwargs,
         )
         try:
@@ -370,6 +374,7 @@ class ChildTool(BaseTool):
         *,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_name: Optional[str] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool asynchronously."""
@@ -392,6 +397,7 @@ class ChildTool(BaseTool):
             {"name": self.name, "description": self.description},
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
+            name=run_name,
             **kwargs,
         )
         try:
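Tools follow the same pattern: run/arun accept run_name and pass it as name= to on_tool_start, so a traced tool call can carry a label other than the tool's own name. A sketch with the tool decorator; the echo tool itself is illustrative:

# Sketch: naming an individual tool invocation. The echo tool is
# made up; only the run_name keyword comes from this commit.
from langchain.tools import tool


@tool
def echo(text: str) -> str:
    """Echo the input text back."""
    return text


echo.run("hello", run_name="echo-tool-run")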