mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
core[patch]: Add missing cache for create_model (#26173)
It makes a big difference for performance.
This commit is contained in:
@@ -712,7 +712,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def create_base_class(
|
||||
def _create_root_model(
|
||||
name: str, type_: Any, default_: object = NO_DEFAULT
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
@@ -765,6 +765,15 @@ def create_base_class(
|
||||
return cast(Type[BaseModel], custom_root_type)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
def _create_root_model_cached(
    __model_name: str,
    type_: Any,
    default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
    """Memoized front-end for ``_create_root_model``.

    Results are kept in an LRU cache of up to 256 entries keyed on the
    arguments, so every argument must be hashable; ``lru_cache`` raises
    ``TypeError`` otherwise and the caller is expected to fall back to the
    uncached ``_create_root_model``.

    Args:
        __model_name: Name given to the generated root model.
        type_: Type of the root value.
        default_: Default for the root value; ``NO_DEFAULT`` means none.

    Returns:
        The model class produced by ``_create_root_model``.
    """
    root_model = _create_root_model(
        __model_name,
        type_=type_,
        default_=default_,
    )
    return root_model
|
||||
|
||||
|
||||
def create_model(
|
||||
__model_name: str,
|
||||
**field_definitions: Any,
|
||||
@@ -789,9 +798,15 @@ def create_model(
|
||||
|
||||
arg = field_definitions["__root__"]
|
||||
if isinstance(arg, tuple):
|
||||
named_root_model = create_base_class(__model_name, arg[0], arg[1])
|
||||
kwargs = {"type_": arg[0], "default_": arg[1]}
|
||||
else:
|
||||
named_root_model = create_base_class(__model_name, arg)
|
||||
kwargs = {"type_": arg}
|
||||
|
||||
try:
|
||||
named_root_model = _create_root_model_cached(__model_name, **kwargs)
|
||||
except TypeError:
|
||||
# something in the arguments into _create_root_model_cached is not hashable
|
||||
named_root_model = _create_root_model(__model_name, **kwargs)
|
||||
return named_root_model
|
||||
try:
|
||||
return _create_model_cached(__model_name, **field_definitions)
|
||||
|
@@ -0,0 +1,21 @@
|
||||
import time
|
||||
from itertools import cycle
|
||||
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
|
||||
|
||||
def test_benchmark_model() -> None:
    """Check that 1,000 invocations of a fake chat model finish in under 1s.

    Guards against performance regressions on the ``invoke`` path: the model
    is built and called 1,000 times, and the total elapsed time must stay
    below one second.
    """
    # perf_counter is monotonic and high-resolution, so the measurement is
    # not disturbed by system clock adjustments the way time.time() can be.
    tic = time.perf_counter()

    model = GenericFakeChatModel(
        messages=cycle(["hello", "world", "!"]),
    )

    for _ in range(1_000):
        model.invoke("foo")
    toc = time.perf_counter()

    # Verify that the time taken to run the loop is less than 1 second
    assert (toc - tic) < 1
|
Reference in New Issue
Block a user