mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 02:29:17 +00:00
add cm
This commit is contained in:
parent
f9a845b382
commit
4d7cd6db5f
@ -8,7 +8,6 @@ from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
|
||||
@ -69,7 +68,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_id=config.get("run_id"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
@ -92,7 +90,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_id=config.get("run_id"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
@ -240,7 +237,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
@ -283,7 +279,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
@ -311,7 +306,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
@ -354,7 +348,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
|
@ -4,7 +4,6 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
@ -165,7 +164,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
@ -195,7 +193,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
query,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
@ -223,7 +220,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
@ -253,7 +249,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
query,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
|
@ -266,7 +266,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
@ -309,7 +308,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
@ -368,7 +366,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
@ -450,7 +447,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
@ -528,7 +524,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
@ -562,7 +558,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
first_error = None
|
||||
@ -613,7 +609,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
@ -675,7 +670,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
@ -784,7 +778,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -814,7 +808,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -860,7 +854,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
@ -917,7 +910,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
@ -957,7 +949,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
@ -1026,7 +1018,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
@ -1159,7 +1151,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
@ -1200,7 +1192,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
|
@ -4,7 +4,6 @@ from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@ -39,11 +38,6 @@ class RunnableConfig(TypedDict, total=False):
|
||||
Name for the tracer run for this call. Defaults to the name of the class.
|
||||
"""
|
||||
|
||||
run_id: UUID
|
||||
"""
|
||||
Unique ID for the tracer run for this call. Defaults to uuid4().
|
||||
"""
|
||||
|
||||
_locals: Dict[str, Any]
|
||||
"""
|
||||
Local variables
|
||||
|
@ -8,7 +8,6 @@ from abc import abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
@ -298,7 +297,6 @@ class ChildTool(BaseTool):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
@ -322,7 +320,6 @@ class ChildTool(BaseTool):
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
@ -373,7 +370,6 @@ class ChildTool(BaseTool):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
@ -396,7 +392,6 @@ class ChildTool(BaseTool):
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user