mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
Remove str() from RunnableConfigurableAlternatives (#11446)
This commit is contained in:
parent
656480feb6
commit
79011f835f
@ -10,6 +10,7 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -71,7 +72,7 @@ class BaseGenerationOutputParser(
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> type[T]:
|
||||
def OutputType(self) -> Type[T]:
|
||||
# even though mypy complains this isn't valid,
|
||||
# it is good enough for pydantic to build the schema from
|
||||
return T # type: ignore[misc]
|
||||
@ -154,7 +155,7 @@ class BaseOutputParser(
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> type[T]:
|
||||
def OutputType(self) -> Type[T]:
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 1:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@ -46,7 +46,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"PromptInput",
|
||||
|
@ -1459,7 +1459,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
return Any
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
if all(
|
||||
s.input_schema.schema().get("type", "object") == "object"
|
||||
for s in self.steps.values()
|
||||
@ -1478,7 +1478,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
return super().input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableMapOutput",
|
||||
@ -2065,7 +2065,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
return List[self.bound.InputType] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"RunnableEachInput",
|
||||
__root__=(
|
||||
@ -2075,11 +2075,11 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> type[List[Output]]:
|
||||
def OutputType(self) -> Type[List[Output]]:
|
||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"RunnableEachOutput",
|
||||
__root__=(
|
||||
@ -2152,11 +2152,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def InputType(self) -> type[Input]:
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.bound.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> type[Output]:
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.bound.OutputType
|
||||
|
||||
@property
|
||||
|
@ -274,7 +274,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
config = config or {}
|
||||
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
|
||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||
if which == self.default_key:
|
||||
return self.default
|
||||
elif which in self.alternatives:
|
||||
|
@ -133,7 +133,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
# ie. it's a dict
|
||||
@ -142,7 +142,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return super().input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
map_output_schema = self.mapper.output_schema
|
||||
if (
|
||||
|
Loading…
Reference in New Issue
Block a user