WIP Runnable Chain

This commit is contained in:
Nuno Campos
2023-12-20 12:55:54 -08:00
parent 42822484ef
commit ad1ab2b566
2 changed files with 112 additions and 4 deletions

View File

@@ -6,7 +6,7 @@ import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
import yaml
from langchain_core.load.dump import dumpd
@@ -19,7 +19,8 @@ from langchain_core.pydantic_v1 import (
root_validator,
validator,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.configurable import ConfigurableFieldSpec
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@@ -666,3 +667,103 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
class RunnableChain(Chain, ABC):
@abstractmethod
def as_runnable(self) -> Runnable:
...
@property
def InputType(self) -> Type[Dict[str, Any]]:
return self.as_runnable().InputType
@property
def OutputType(self) -> Type[Dict[str, Any]]:
return self.as_runnable().OutputType
def get_input_schema(self, config: RunnableConfig | None = None) -> Type[BaseModel]:
return self.as_runnable().get_input_schema(config)
def get_output_schema(
self, config: RunnableConfig | None = None
) -> Type[BaseModel]:
return self.as_runnable().get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.as_runnable().config_specs
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self.as_runnable().invoke(input, config, **kwargs)
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return await self.as_runnable().ainvoke(input, config, **kwargs)
def stream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
yield from self.as_runnable().stream(input, config, **kwargs)
async def astream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async for item in self.as_runnable().astream(input, config, **kwargs):
yield item
def batch(
self,
inputs: List[Dict[str, Any]],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Dict[str, Any]]:
return self.as_runnable().batch(
inputs, config, **kwargs, return_exceptions=return_exceptions
)
async def abatch(
self,
inputs: List[Dict[str, Any]],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Dict[str, Any]]:
return await self.as_runnable().abatch(
inputs, config, **kwargs, return_exceptions=return_exceptions
)
def transform(
self,
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Dict[str, Any]]:
yield from self.as_runnable().transform(input, config, **kwargs)
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Dict[str, Any]]:
async for chunk in super().atransform(input, config, **kwargs):
yield chunk

View File

@@ -31,10 +31,10 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.base import RunnableChain
class LLMChain(Chain):
class LLMChain(RunnableChain):
"""Chain to run queries against LLMs.
Example:
@@ -76,6 +76,13 @@ class LLMChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
def as_runnable(self) -> Runnable:
return (
self.prompt
| self.llm.bind(**self.llm_kwargs)
| {self.output_key: self.output_parser}
)
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.