diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index e46a45103da..e1316f0dd08 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -712,7 +712,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True) NO_DEFAULT = object() -def create_base_class( +def _create_root_model( name: str, type_: Any, default_: object = NO_DEFAULT ) -> Type[BaseModel]: """Create a base class.""" @@ -765,6 +765,15 @@ def create_base_class( return cast(Type[BaseModel], custom_root_type) +@lru_cache(maxsize=256) +def _create_root_model_cached( + __model_name: str, + type_: Any, + default_: object = NO_DEFAULT, +) -> Type[BaseModel]: + return _create_root_model(__model_name, type_, default_) + + def create_model( __model_name: str, **field_definitions: Any, @@ -789,9 +798,15 @@ def create_model( arg = field_definitions["__root__"] if isinstance(arg, tuple): - named_root_model = create_base_class(__model_name, arg[0], arg[1]) + kwargs = {"type_": arg[0], "default_": arg[1]} else: - named_root_model = create_base_class(__model_name, arg) + kwargs = {"type_": arg} + + try: + named_root_model = _create_root_model_cached(__model_name, **kwargs) + except TypeError: + # something in the arguments into _create_root_model_cached is not hashable + named_root_model = _create_root_model(__model_name, **kwargs) return named_root_model try: return _create_model_cached(__model_name, **field_definitions) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_benchmark.py b/libs/core/tests/unit_tests/language_models/chat_models/test_benchmark.py new file mode 100644 index 00000000000..b455017325e --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_benchmark.py @@ -0,0 +1,21 @@ +import time +from itertools import cycle + +from langchain_core.language_models import GenericFakeChatModel + + +def test_benchmark_model() -> None: + """Add rate limiter.""" + tic = time.time() + + model = GenericFakeChatModel( + messages=cycle(["hello", "world", "!"]), + ) + + for _ in range(1_000): + model.invoke("foo") + toc = time.time() + + # Verify that the time taken to run the loop is less than 1 seconds + + assert (toc - tic) < 1