mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
Remove str() from RunnableConfigurableAlternatives (#11446)
This commit is contained in:
parent
656480feb6
commit
79011f835f
@ -10,6 +10,7 @@ from typing import (
|
|||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -71,7 +72,7 @@ class BaseGenerationOutputParser(
|
|||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[T]:
|
def OutputType(self) -> Type[T]:
|
||||||
# even though mypy complains this isn't valid,
|
# even though mypy complains this isn't valid,
|
||||||
# it is good enough for pydantic to build the schema from
|
# it is good enough for pydantic to build the schema from
|
||||||
return T # type: ignore[misc]
|
return T # type: ignore[misc]
|
||||||
@ -154,7 +155,7 @@ class BaseOutputParser(
|
|||||||
return Union[str, AnyMessage]
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[T]:
|
def OutputType(self) -> Type[T]:
|
||||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||||
type_args = get_args(cls)
|
type_args = get_args(cls)
|
||||||
if type_args and len(type_args) == 1:
|
if type_args and len(type_args) == 1:
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> type[BaseModel]:
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"PromptInput",
|
"PromptInput",
|
||||||
|
@ -1459,7 +1459,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
return Any
|
return Any
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> type[BaseModel]:
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
if all(
|
if all(
|
||||||
s.input_schema.schema().get("type", "object") == "object"
|
s.input_schema.schema().get("type", "object") == "object"
|
||||||
for s in self.steps.values()
|
for s in self.steps.values()
|
||||||
@ -1478,7 +1478,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
return super().input_schema
|
return super().input_schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_schema(self) -> type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableMapOutput",
|
"RunnableMapOutput",
|
||||||
@ -2065,7 +2065,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
return List[self.bound.InputType] # type: ignore[name-defined]
|
return List[self.bound.InputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> type[BaseModel]:
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
return create_model(
|
return create_model(
|
||||||
"RunnableEachInput",
|
"RunnableEachInput",
|
||||||
__root__=(
|
__root__=(
|
||||||
@ -2075,11 +2075,11 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[List[Output]]:
|
def OutputType(self) -> Type[List[Output]]:
|
||||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_schema(self) -> type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return create_model(
|
return create_model(
|
||||||
"RunnableEachOutput",
|
"RunnableEachOutput",
|
||||||
__root__=(
|
__root__=(
|
||||||
@ -2152,11 +2152,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> Type[Input]:
|
||||||
return self.bound.InputType
|
return self.bound.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> Type[Output]:
|
||||||
return self.bound.OutputType
|
return self.bound.OutputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -274,7 +274,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
|
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||||
if which == self.default_key:
|
if which == self.default_key:
|
||||||
return self.default
|
return self.default
|
||||||
elif which in self.alternatives:
|
elif which in self.alternatives:
|
||||||
|
@ -133,7 +133,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
return cls.__module__.split(".")[:-1]
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> type[BaseModel]:
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
map_input_schema = self.mapper.input_schema
|
map_input_schema = self.mapper.input_schema
|
||||||
if not map_input_schema.__custom_root_type__:
|
if not map_input_schema.__custom_root_type__:
|
||||||
# ie. it's a dict
|
# ie. it's a dict
|
||||||
@ -142,7 +142,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
return super().input_schema
|
return super().input_schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_schema(self) -> type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
map_input_schema = self.mapper.input_schema
|
map_input_schema = self.mapper.input_schema
|
||||||
map_output_schema = self.mapper.output_schema
|
map_output_schema = self.mapper.output_schema
|
||||||
if (
|
if (
|
||||||
|
Loading…
Reference in New Issue
Block a user