Compare commits

...

1 Commit

Author SHA1 Message Date
Eugene Yurtsev
f637aeede2 q 2023-06-04 08:24:03 -04:00
6 changed files with 46 additions and 22 deletions

View File

@@ -2,12 +2,11 @@
import inspect
import json
import warnings
import yaml
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Field, root_validator, validator
from typing import Any, List, Optional, Union, Generic, TypeVar, Dict
import langchain
from langchain.callbacks.base import BaseCallbackManager
@@ -25,7 +24,10 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(BaseModel, ABC):
T = TypeVar("T", bound=Dict)
class Chain(BaseModel, ABC, Generic[T]):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
@@ -92,7 +94,7 @@ class Chain(BaseModel, ABC):
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> T:
"""Run the logic of this chain and return the output."""
async def _acall(

View File

@@ -1,7 +1,8 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, TypeVar, TypedDict
from typing import Generic
from pydantic import Extra
@@ -20,7 +21,14 @@ from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
class LLMChain(Chain):
T = TypeVar("T", bound=Dict[str, Any])
class StandardChain(TypedDict):
text: str
class LLMChain(Chain[T]):
"""Chain to run queries against LLMs.
Example:
@@ -65,7 +73,7 @@ class LLMChain(Chain):
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> T:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]

View File

@@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -204,7 +205,7 @@ def load_qa_chain(
Returns:
A chain to use for question answering.
"""
loader_mapping: Mapping[str, LoadingCallable] = {
loader_mapping: Mapping[str, Chain] = {
"stuff": _load_stuff_chain,
"map_reduce": _load_map_reduce_chain,
"refine": _load_refine_chain,
@@ -215,6 +216,7 @@ def load_qa_chain(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](
chain = load_qa_chain(chain_type)
return chain(
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)

View File

@@ -2,9 +2,8 @@
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from pydantic import Extra
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@@ -14,27 +13,38 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
class Route(NamedTuple):
from typing import TypedDict
class RoutingOutput(TypedDict):
"""Output of a router chain."""
destination: Optional[str]
next_inputs: Dict[str, Any]
class RouterChain(Chain, ABC):
class RouterChain(Chain[RoutingOutput], ABC):
"""Chain that outputs the name of a destination chain and the inputs to it."""
@property
def output_keys(self) -> List[str]:
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
def route(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> RoutingOutput:
result = self(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
return RoutingOutput(
destination=result["destination"], next_inputs=result["next_inputs"]
)
async def aroute(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Route:
) -> RoutingOutput:
result = await self.acall(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
return RoutingOutput(
destination=result["destination"], next_inputs=result["next_inputs"]
)
class MultiRouteChain(Chain):
@@ -86,7 +96,7 @@ class MultiRouteChain(Chain):
)
if not route.destination:
return self.default_chain(route.next_inputs, callbacks=callbacks)
elif route.destination in self.destination_chains:
elif route.destination in self.destination_chains:
return self.destination_chains[route.destination](
route.next_inputs, callbacks=callbacks
)

View File

@@ -11,6 +11,7 @@ from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParse
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.prompts import PromptTemplate
from langchain.chains.llm import StandardChain
class MultiPromptChain(MultiRouteChain):
"""A multi-route chain that uses an LLM router chain to choose amongst prompts."""

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Type, TypedDict
from pydantic import BaseModel, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.chains.llm import LLMChain, StandardChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.base import BasePromptTemplate
@@ -33,7 +33,8 @@ class SummarizerMixin(BaseModel):
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt)
chain = LLMChain[StandardChain](llm=self.llm, prompt=self.prompt)
return chain.predict(summary=existing_summary, new_lines=new_lines)