mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
WIP Runnable Chain
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
|
||||
|
||||
import yaml
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -19,7 +19,8 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.configurable import ConfigurableFieldSpec
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
@@ -666,3 +667,103 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
|
||||
class RunnableChain(Chain, ABC):
|
||||
@abstractmethod
|
||||
def as_runnable(self) -> Runnable:
|
||||
...
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Dict[str, Any]]:
|
||||
return self.as_runnable().InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Dict[str, Any]]:
|
||||
return self.as_runnable().OutputType
|
||||
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> Type[BaseModel]:
|
||||
return self.as_runnable().get_input_schema(config)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.as_runnable().get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.as_runnable().config_specs
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return self.as_runnable().invoke(input, config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.as_runnable().ainvoke(input, config, **kwargs)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
yield from self.as_runnable().stream(input, config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async for item in self.as_runnable().astream(input, config, **kwargs):
|
||||
yield item
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Dict[str, Any]],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.as_runnable().batch(
|
||||
inputs, config, **kwargs, return_exceptions=return_exceptions
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Dict[str, Any]],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return await self.as_runnable().abatch(
|
||||
inputs, config, **kwargs, return_exceptions=return_exceptions
|
||||
)
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
yield from self.as_runnable().transform(input, config, **kwargs)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async for chunk in super().atransform(input, config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
@@ -31,10 +31,10 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.base import RunnableChain
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
class LLMChain(RunnableChain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
Example:
|
||||
@@ -76,6 +76,13 @@ class LLMChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
return (
|
||||
self.prompt
|
||||
| self.llm.bind(**self.llm_kwargs)
|
||||
| {self.output_key: self.output_parser}
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
Reference in New Issue
Block a user