diff --git a/libs/core/langchain_core/language_models/__init__.py b/libs/core/langchain_core/language_models/__init__.py index 4c2e597dddf..b86de1cdb13 100644 --- a/libs/core/langchain_core/language_models/__init__.py +++ b/libs/core/langchain_core/language_models/__init__.py @@ -1,6 +1,7 @@ from langchain_core.language_models.base import ( BaseLanguageModel, LanguageModelInput, + LanguageModelLike, LanguageModelOutput, get_tokenizer, ) @@ -16,4 +17,5 @@ __all__ = [ "LanguageModelInput", "get_tokenizer", "LanguageModelOutput", + "LanguageModelLike", ] diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index f48f83b37f4..b590e0df84f 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -17,7 +17,7 @@ from typing_extensions import TypeAlias from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string from langchain_core.prompt_values import PromptValue -from langchain_core.runnables import RunnableSerializable +from langchain_core.runnables import Runnable, RunnableSerializable from langchain_core.utils import get_pydantic_field_names if TYPE_CHECKING: @@ -49,11 +49,13 @@ def _get_token_ids_default_method(text: str) -> List[int]: LanguageModelInput = Union[PromptValue, str, List[BaseMessage]] -LanguageModelOutput = TypeVar("LanguageModelOutput") +LanguageModelOutput = Union[BaseMessage, str] +LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] +LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str) class BaseLanguageModel( - RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC + RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC ): """Abstract base class for interfacing with language models. diff --git a/libs/core/tests/unit_tests/language_models/test_imports.py b/libs/core/tests/unit_tests/language_models/test_imports.py index 348111f8511..65d627c56fd 100644 --- a/libs/core/tests/unit_tests/language_models/test_imports.py +++ b/libs/core/tests/unit_tests/language_models/test_imports.py @@ -9,6 +9,7 @@ EXPECTED_ALL = [ "LanguageModelInput", "LanguageModelOutput", "get_tokenizer", + "LanguageModelLike", ]