core[patch]: Add missing cache for create_model (#26173)
Caching the generated root models makes a big difference for performance when create_model is called repeatedly with the same arguments.
@@ -712,7 +712,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
 NO_DEFAULT = object()
 
 
-def create_base_class(
+def _create_root_model(
     name: str, type_: Any, default_: object = NO_DEFAULT
 ) -> Type[BaseModel]:
     """Create a base class."""
@@ -765,6 +765,15 @@ def create_base_class
     return cast(Type[BaseModel], custom_root_type)
 
 
+@lru_cache(maxsize=256)
+def _create_root_model_cached(
+    __model_name: str,
+    type_: Any,
+    default_: object = NO_DEFAULT,
+) -> Type[BaseModel]:
+    return _create_root_model(__model_name, type_, default_)
+
+
 def create_model(
     __model_name: str,
     **field_definitions: Any,
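The hunk above is the core of the change: _create_root_model_cached is a thin wrapper around _create_root_model decorated with functools.lru_cache(maxsize=256), so repeated requests for the same (name, type, default) combination reuse the already-generated pydantic class instead of rebuilding it. Below is a minimal, self-contained sketch of the same pattern; the names _build_model and _build_model_cached are illustrative, not LangChain APIs.

# Illustrative sketch only (assumed names, not LangChain code): caching an
# expensive class factory with functools.lru_cache.
from functools import lru_cache
from typing import Any, Type

from pydantic import BaseModel, create_model as pydantic_create_model


def _build_model(name: str, type_: Any) -> Type[BaseModel]:
    # Expensive part: pydantic generates a brand-new class (fields, validators,
    # schema machinery) on every call.
    return pydantic_create_model(name, root=(type_, ...))


@lru_cache(maxsize=256)
def _build_model_cached(name: str, type_: Any) -> Type[BaseModel]:
    # lru_cache keys on the (hashable) arguments, so an identical request
    # returns the class built earlier instead of regenerating it.
    return _build_model(name, type_)


# Cached calls return the very same class object; uncached calls do not.
assert _build_model_cached("Foo", int) is _build_model_cached("Foo", int)
assert _build_model("Foo", int) is not _build_model("Foo", int)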
@@ -789,9 +798,15 @@ def create_model(
 
         arg = field_definitions["__root__"]
         if isinstance(arg, tuple):
-            named_root_model = create_base_class(__model_name, arg[0], arg[1])
+            kwargs = {"type_": arg[0], "default_": arg[1]}
         else:
-            named_root_model = create_base_class(__model_name, arg)
+            kwargs = {"type_": arg}
+
+        try:
+            named_root_model = _create_root_model_cached(__model_name, **kwargs)
+        except TypeError:
+            # something in the arguments to _create_root_model_cached is not hashable
+            named_root_model = _create_root_model(__model_name, **kwargs)
         return named_root_model
     try:
         return _create_model_cached(__model_name, **field_definitions)
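The try/except in create_model is needed because lru_cache hashes its arguments up front: if the __root__ definition contains anything unhashable, the cached wrapper raises TypeError before its body even runs, and the code falls back to the uncached _create_root_model. A small standalone sketch of that fallback pattern (illustrative names, not LangChain code):

# Illustrative fallback pattern: use the cache when the arguments are
# hashable, fall back to a direct call when they are not.
from functools import lru_cache
from typing import Any


@lru_cache(maxsize=256)
def _expensive_cached(value: Any) -> str:
    return f"built for {value!r}"


def build(value: Any) -> str:
    try:
        # lru_cache hashes `value` before calling the wrapped function, so an
        # unhashable argument (e.g. a list) raises TypeError here.
        return _expensive_cached(value)
    except TypeError:
        return f"built for {value!r}"


print(build(("a", 1)))  # hashable tuple: served through the cache
print(build(["a", 1]))  # unhashable list: uncached fallback path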
@@ -0,0 +1,21 @@
+import time
+from itertools import cycle
+
+from langchain_core.language_models import GenericFakeChatModel
+
+
+def test_benchmark_model() -> None:
+    """Benchmark invoking a fake chat model in a tight loop."""
+    tic = time.time()
+
+    model = GenericFakeChatModel(
+        messages=cycle(["hello", "world", "!"]),
+    )
+
+    for _ in range(1_000):
+        model.invoke("foo")
+    toc = time.time()
+
+    # Verify that the time taken to run the loop is less than 1 second
+
+    assert (toc - tic) < 1
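The new test is a coarse wall-clock benchmark: it builds a GenericFakeChatModel and invokes it 1,000 times, asserting the whole loop finishes in under a second; bundling it with this commit suggests the invoke path exercises create_model often enough for the cache to matter. For a more direct look at the cache itself, a micro-benchmark along the following lines could be used. The import path langchain_core.utils.pydantic is an assumption based on where the code in the diff above appears to live; verify it against your checkout.

# Hedged micro-benchmark sketch (not part of the PR). The module path is an
# assumption; adjust the import if create_model lives elsewhere in your tree.
import timeit

from langchain_core.utils.pydantic import create_model


def build() -> None:
    # Same name and __root__ definition every time, so after the first call
    # the lru_cache-backed path should be hit.
    create_model("Output", __root__=(int, 0))


print(timeit.timeit(build, number=10_000))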