diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index f02a76eaa5a..ca3548a752d 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -6,7 +6,7 @@ import logging import warnings from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union import yaml from langchain_core.load.dump import dumpd @@ -19,7 +19,8 @@ from langchain_core.pydantic_v1 import ( root_validator, validator, ) -from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable +from langchain_core.runnables.configurable import ConfigurableFieldSpec from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -666,3 +667,103 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): ) -> List[Dict[str, str]]: """Call the chain on all inputs in the list.""" return [self(inputs, callbacks=callbacks) for inputs in input_list] + + +class RunnableChain(Chain, ABC): + @abstractmethod + def as_runnable(self) -> Runnable: + ... + + @property + def InputType(self) -> Type[Dict[str, Any]]: + return self.as_runnable().InputType + + @property + def OutputType(self) -> Type[Dict[str, Any]]: + return self.as_runnable().OutputType + + def get_input_schema(self, config: RunnableConfig | None = None) -> Type[BaseModel]: + return self.as_runnable().get_input_schema(config) + + def get_output_schema( + self, config: RunnableConfig | None = None + ) -> Type[BaseModel]: + return self.as_runnable().get_output_schema(config) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return self.as_runnable().config_specs + + def invoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + return self.as_runnable().invoke(input, config, **kwargs) + + async def ainvoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + return await self.as_runnable().ainvoke(input, config, **kwargs) + + def stream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + yield from self.as_runnable().stream(input, config, **kwargs) + + async def astream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async for item in self.as_runnable().astream(input, config, **kwargs): + yield item + + def batch( + self, + inputs: List[Dict[str, Any]], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[Dict[str, Any]]: + return self.as_runnable().batch( + inputs, config, **kwargs, return_exceptions=return_exceptions + ) + + async def abatch( + self, + inputs: List[Dict[str, Any]], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[Dict[str, Any]]: + return await self.as_runnable().abatch( + inputs, config, **kwargs, return_exceptions=return_exceptions + ) + + def transform( + self, + input: Iterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Dict[str, Any]]: + yield from self.as_runnable().transform(input, config, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in super().atransform(input, config, **kwargs): + yield chunk diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index b7023dbc6a1..3834c79c41d 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -31,10 +31,10 @@ from langchain.callbacks.manager import ( CallbackManagerForChainRun, Callbacks, ) -from langchain.chains.base import Chain +from langchain.chains.base import RunnableChain -class LLMChain(Chain): +class LLMChain(RunnableChain): """Chain to run queries against LLMs. Example: @@ -76,6 +76,13 @@ class LLMChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + def as_runnable(self) -> Runnable: + return ( + self.prompt + | self.llm.bind(**self.llm_kwargs) + | {self.output_key: self.output_parser} + ) + @property def input_keys(self) -> List[str]: """Will be whatever keys the prompt expects.