Compare commits

...

8 Commits

Author SHA1 Message Date
Dev 2049
2b854914a8 fmt 2023-05-01 22:06:14 -07:00
Dev 2049
d1122b103c wip 2023-05-01 22:06:04 -07:00
Dev 2049
e84873c6ef Merge branch 'master' into fork-chains 2023-05-01 21:31:01 -07:00
Shreya Rajpal
89f27dc1e9 Fix flake8 issues 2022-12-28 20:28:52 +05:30
Shreya Rajpal
7081a0c1d9 Fix formatting issues 2022-12-28 20:18:43 +05:30
Shreya Rajpal
a2c5680d9e Fix lint issues 2022-12-28 20:05:51 +05:30
Shreya Rajpal
bb5427d2cd made output conform 2022-12-23 21:05:59 +05:30
Shreya Rajpal
3fcbae90ed Add support for fork chains 2022-12-23 20:13:32 +05:30
2 changed files with 60 additions and 5 deletions

View File

@@ -83,11 +83,9 @@ class Chain(BaseModel, ABC):
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, str]) -> None:
if set(outputs) != set(self.output_keys):
raise ValueError(
f"Did not get output keys that were expected. "
f"Got: {set(outputs)}. Expected: {set(self.output_keys)}."
)
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(

View File

@@ -0,0 +1,57 @@
"""Use a single chain to route an input to one of multiple candidate chains."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
class RouterChain(Chain):
"""Use a single chain to route an input to one of multiple candidate chains."""
router_chain: Chain
"""Chain that routes inputs to destination chains."""
destination_chains: Dict[str, Chain]
"""Chains that return final answer to inputs."""
default_chain: Chain
"""Default chain to use when routing fails."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.router_chain.input_keys
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return []
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
router_output = self.router_chain(**inputs, callbacks=callbacks)
destination = router_output["destination"]
next_inputs = router_output["next_inputs"]
if destination in self.destination_chains:
return self.destination_chains[destination](
next_inputs, callbacks=callbacks
)
else:
return self.default_chain(next_inputs, callbacks=callbacks)