From a612800ef0ac8ea851cd96f98611d1d668d0e1b6 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 26 Jul 2023 20:16:46 +0100
Subject: [PATCH] Runnable single protocol (#7800)

Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel, Chain,
Retriever, OutputParser

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported things
- [x] Implement dict which calls values in parallel and outputs dict with results
- [x] Merge in `+` for prompts
- [x] Confirm precedence order for operators, ideal would be `+` before `|`, https://docs.python.org/3/reference/expressions.html#operator-precedence
- [x] Add support for openai functions, i.e. Chat Models must return messages
- [x] Implement BaseMessageChunk return type for BaseChatModel, a subclass of BaseMessage which implements __add__ to return BaseMessageChunk, concatenating all str args
- [x] Update implementation of stream/astream for llm and chat models to use the new `_stream`/`_astream` optional methods, with a default implementation in the base class that raises NotImplementedError; use https://stackoverflow.com/a/59762827 to check whether a subclass has overridden them
- [x] Delete the IteratorCallbackHandler (leave the async one because people are using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str or BaseMessage

---------

Co-authored-by: Eugene Yurtsev
---
 .../langchain_experimental/llms/llamaapi.py   |  14 +-
 libs/langchain/langchain/callbacks/base.py    |  14 +-
 libs/langchain/langchain/callbacks/manager.py |  29 +-
 .../langchain/callbacks/streaming_aiter.py    |   2 +-
 libs/langchain/langchain/chains/base.py       |  17 +-
 .../langchain/chat_models/anthropic.py        |  90 ++-
 libs/langchain/langchain/chat_models/base.py  | 196 ++++-
 libs/langchain/langchain/chat_models/fake.py  |   5 +-
 .../langchain/chat_models/jinachat.py         | 137 ++--
 .../langchain/langchain/chat_models/openai.py | 172 +++--
 .../langchain/chat_models/vertexai.py         |  16 +-
 .../document_transformers/openai_functions.py |   3 +-
 libs/langchain/langchain/llms/anthropic.py    | 112 +--
 libs/langchain/langchain/llms/base.py         | 404 +++++++++-
 libs/langchain/langchain/llms/fake.py         |  10 +-
 libs/langchain/langchain/llms/google_palm.py  |  14 +-
 .../llms/huggingface_text_gen_inference.py    | 171 +++--
 libs/langchain/langchain/llms/llamacpp.py     |  32 +-
 libs/langchain/langchain/llms/openai.py       | 294 +++++---
 libs/langchain/langchain/llms/tongyi.py       |  14 +-
 .../langchain/output_parsers/list.py          |   2 +-
 .../output_parsers/openai_functions.py        |   4 +-
 libs/langchain/langchain/prompts/base.py      |   4 +-
 libs/langchain/langchain/prompts/chat.py      |   2 +-
 libs/langchain/langchain/retrievers/arxiv.py  |  10 +-
 libs/langchain/langchain/retrievers/bm25.py   |  10 +-
 .../langchain/retrievers/docarray.py          |  13 +-
 .../retrievers/elastic_search_bm25.py         |  10 +-
 .../google_cloud_enterprise_search.py         |  10 +-
 libs/langchain/langchain/retrievers/kendra.py |  13 +-
 libs/langchain/langchain/retrievers/knn.py    |  10 +-
 .../langchain/retrievers/llama_index.py       |  15 +-
 libs/langchain/langchain/retrievers/metal.py  |  10 +-
 libs/langchain/langchain/retrievers/milvus.py |  14 +-
 .../langchain/retrievers/multi_query.py       |  13 +-
 .../retrievers/pinecone_hybrid_search.py      |  10 +-
 libs/langchain/langchain/retrievers/pubmed.py |  10 +-
 .../langchain/retrievers/self_query/base.py   |  10 +-
 libs/langchain/langchain/retrievers/svm.py    |  10 +-
 libs/langchain/langchain/retrievers/tfidf.py  |  10 +-
 .../retrievers/time_weighted_retriever.py     |  11 +-
 .../langchain/retrievers/vespa_retriever.py   |  10 +-
 .../retrievers/weaviate_hybrid_search.py      |  10 +-
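For orientation, a minimal sketch of how the single Runnable protocol described above is meant to be used once these changes land (illustrative only; `ChatPromptTemplate` and `ChatOpenAI` are existing langchain classes, but this snippet is not taken from the diff):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

# Both objects implement the shared Runnable interface, so `|` composes them
# into a sequence that is itself a Runnable.
prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
model = ChatOpenAI()  # assumes OPENAI_API_KEY is set
chain = prompt | model

# invoke() is the single synchronous entry point; the chat model step returns a message.
message = chain.invoke({"topic": "ducks"})
print(message.content)
```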
 .../langchain/retrievers/wikipedia.py         |  10 +-
 libs/langchain/langchain/retrievers/zilliz.py |  14 +-
 libs/langchain/langchain/schema/__init__.py   |   2 -
 .../langchain/schema/language_model.py        |  20 +-
 libs/langchain/langchain/schema/messages.py   |  65 +-
 libs/langchain/langchain/schema/output.py     |  40 +-
 .../langchain/schema/output_parser.py         |  30 +-
 .../langchain/schema/prompt_template.py       |   8 +-
 libs/langchain/langchain/schema/retriever.py  |  22 +-
 libs/langchain/langchain/schema/runnable.py   | 705 ++++++++++++++++++
 .../chat_models/test_openai.py                |  65 ++
 .../integration_tests/llms/test_openai.py     |  66 +-
 .../llms/test_promptlayer_openai.py           |   7 -
 .../schema/__snapshots__/test_runnable.ambr   | 668 +++++++++++++++++
 .../tests/unit_tests/schema/test_messages.py  |  38 +
 .../tests/unit_tests/schema/test_output.py    |  52 ++
 .../tests/unit_tests/schema/test_runnable.py  | 547 ++++++++++++++
 60 files changed, 3564 insertions(+), 762 deletions(-)
 create mode 100644 libs/langchain/langchain/schema/runnable.py
 create mode 100644 libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr
 create mode 100644 libs/langchain/tests/unit_tests/schema/test_messages.py
 create mode 100644 libs/langchain/tests/unit_tests/schema/test_output.py
 create mode 100644 libs/langchain/tests/unit_tests/schema/test_runnable.py

diff --git a/libs/experimental/langchain_experimental/llms/llamaapi.py b/libs/experimental/langchain_experimental/llms/llamaapi.py
index e5fdff812f2..9d2e79d2fe0 100644
--- a/libs/experimental/langchain_experimental/llms/llamaapi.py
+++ b/libs/experimental/langchain_experimental/llms/llamaapi.py
@@ -9,10 +9,7 @@ from typing import (
     Tuple,
 )
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
     ChatGeneration,
@@ -116,15 +113,6 @@ class ChatLlamaAPI(BaseChatModel):
             generations.append(gen)
         return ChatResult(generations=generations)
 
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        raise NotImplementedError
-
     @property
     def _client_params(self) -> Mapping[str, Any]:
         """Get the parameters used for the client."""
diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py
index 2e15185f898..85ea6c96ae5 100644
--- a/libs/langchain/langchain/callbacks/base.py
+++ b/libs/langchain/langchain/callbacks/base.py
@@ -1,13 +1,14 @@
 """Base callback handler that can be used to handle callbacks in langchain."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from uuid import UUID
 
-from langchain.schema.agent import AgentAction, AgentFinish
-from langchain.schema.document import Document
-from langchain.schema.messages import BaseMessage
-from langchain.schema.output import LLMResult
+if TYPE_CHECKING:
+    from langchain.schema.agent import AgentAction, AgentFinish
+    from langchain.schema.document import Document
+    from langchain.schema.messages import BaseMessage
+    from langchain.schema.output import LLMResult
 
 
 class RetrieverManagerMixin:
@@ -543,3 +544,6 @@ class BaseCallbackManager(CallbackManagerMixin):
         for key in keys:
             self.metadata.pop(key)
             self.inheritable_metadata.pop(key)
+
+
+Callbacks =
Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 94ca4d58029..bc26656f099 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -4,6 +4,7 @@ import asyncio import functools import logging import os +import uuid from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from typing import ( @@ -20,12 +21,13 @@ from typing import ( Union, cast, ) -from uuid import UUID, uuid4 +from uuid import UUID import langchain from langchain.callbacks.base import ( BaseCallbackHandler, BaseCallbackManager, + Callbacks, ChainManagerMixin, LLMManagerMixin, RetrieverManagerMixin, @@ -50,7 +52,6 @@ if TYPE_CHECKING: from langsmith import Client as LangSmithClient logger = logging.getLogger(__name__) -Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None @@ -437,7 +438,7 @@ class BaseRunManager(RunManagerMixin): BaseRunManager: The noop manager. """ return cls( - run_id=uuid4(), + run_id=uuid.uuid4(), handlers=[], inheritable_handlers=[], tags=[], @@ -1024,7 +1025,7 @@ class CallbackManager(BaseCallbackManager): """ managers = [] for prompt in prompts: - run_id_ = uuid4() + run_id_ = uuid.uuid4() _handle_event( self.handlers, "on_llm_start", @@ -1073,7 +1074,7 @@ class CallbackManager(BaseCallbackManager): managers = [] for message_list in messages: - run_id_ = uuid4() + run_id_ = uuid.uuid4() _handle_event( self.handlers, "on_chat_model_start", @@ -1120,7 +1121,7 @@ class CallbackManager(BaseCallbackManager): CallbackManagerForChainRun: The callback manager for the chain run. """ if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() _handle_event( self.handlers, @@ -1166,7 +1167,7 @@ class CallbackManager(BaseCallbackManager): CallbackManagerForToolRun: The callback manager for the tool run. """ if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() _handle_event( self.handlers, @@ -1202,7 +1203,7 @@ class CallbackManager(BaseCallbackManager): ) -> CallbackManagerForRetrieverRun: """Run when retriever starts running.""" if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() _handle_event( self.handlers, @@ -1302,7 +1303,7 @@ class AsyncCallbackManager(BaseCallbackManager): managers = [] for prompt in prompts: - run_id_ = uuid4() + run_id_ = uuid.uuid4() tasks.append( _ahandle_event( @@ -1341,7 +1342,7 @@ class AsyncCallbackManager(BaseCallbackManager): serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any, - ) -> Any: + ) -> List[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running. Args: @@ -1358,7 +1359,7 @@ class AsyncCallbackManager(BaseCallbackManager): managers = [] for message_list in messages: - run_id_ = uuid4() + run_id_ = uuid.uuid4() tasks.append( _ahandle_event( @@ -1410,7 +1411,7 @@ class AsyncCallbackManager(BaseCallbackManager): for the chain run. """ if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() await _ahandle_event( self.handlers, @@ -1458,7 +1459,7 @@ class AsyncCallbackManager(BaseCallbackManager): for the tool run. 
""" if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() await _ahandle_event( self.handlers, @@ -1494,7 +1495,7 @@ class AsyncCallbackManager(BaseCallbackManager): ) -> AsyncCallbackManagerForRetrieverRun: """Run when retriever starts running.""" if run_id is None: - run_id = uuid4() + run_id = uuid.uuid4() await _ahandle_event( self.handlers, diff --git a/libs/langchain/langchain/callbacks/streaming_aiter.py b/libs/langchain/langchain/callbacks/streaming_aiter.py index 6e791a64b52..17e962ac872 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter.py @@ -4,7 +4,7 @@ import asyncio from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast from langchain.callbacks.base import AsyncCallbackHandler -from langchain.schema import LLMResult +from langchain.schema.output import LLMResult # TODO If used by two LLM runs in parallel this won't work as expected diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 7b25c36fb02..7f79c4ffd97 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -22,6 +22,7 @@ from langchain.callbacks.manager import ( from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.schema import RUN_KEY, BaseMemory, RunInfo +from langchain.schema.runnable import Runnable, RunnableConfig logger = logging.getLogger(__name__) @@ -30,7 +31,7 @@ def _get_verbosity() -> bool: return langchain.verbose -class Chain(Serializable, ABC): +class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): """Abstract base class for creating structured sequences of calls to components. Chains should be used to encode a sequence of calls to components like @@ -53,6 +54,20 @@ class Chain(Serializable, ABC): chains and cannot return as rich of an output as `__call__`. """ + def invoke( + self, input: Dict[str, Any], config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + return self(input, **(config or {})) + + async def ainvoke( + self, input: Dict[str, Any], config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + if type(self)._acall == Chain._acall: + # If the chain does not implement async, fall back to default implementation + return await super().ainvoke(input, config) + + return await self.acall(input, **(config or {})) + memory: Optional[BaseMemory] = None """Optional memory object. Defaults to None. 
Memory is a class that gets called at the start diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py index ce770c9494d..cf94096e9b9 100644 --- a/libs/langchain/langchain/chat_models/anthropic.py +++ b/libs/langchain/langchain/chat_models/anthropic.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -12,11 +12,13 @@ from langchain.schema import ( ) from langchain.schema.messages import ( AIMessage, + AIMessageChunk, BaseMessage, ChatMessage, HumanMessage, SystemMessage, ) +from langchain.schema.output import ChatGenerationChunk class ChatAnthropic(BaseChatModel, _AnthropicCommon): @@ -94,6 +96,44 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): text.rstrip() ) # trim off the trailing ' ' that might come from the "Assistant: " + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + prompt = self._convert_messages_to_prompt(messages) + params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} + if stop: + params["stop_sequences"] = stop + + stream_resp = self.client.completions.create(**params, stream=True) + for data in stream_resp: + delta = data.completion + yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) + if run_manager: + run_manager.on_llm_new_token(delta) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + prompt = self._convert_messages_to_prompt(messages) + params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} + if stop: + params["stop_sequences"] = stop + + stream_resp = await self.async_client.completions.create(**params, stream=True) + async for data in stream_resp: + delta = data.completion + yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) + if run_manager: + await run_manager.on_llm_new_token(delta) + def _generate( self, messages: List[BaseMessage], @@ -101,22 +141,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - prompt = self._convert_messages_to_prompt(messages) - params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} - if stop: - params["stop_sequences"] = stop - if self.streaming: completion = "" - stream_resp = self.client.completions.create(**params, stream=True) - for data in stream_resp: - delta = data.completion - completion += delta - if run_manager: - run_manager.on_llm_new_token( - delta, - ) + for chunk in self._stream(messages, stop, run_manager, **kwargs): + completion += chunk.text else: + prompt = self._convert_messages_to_prompt(messages) + params: Dict[str, Any] = { + "prompt": prompt, + **self._default_params, + **kwargs, + } + if stop: + params["stop_sequences"] = stop response = self.client.completions.create(**params) completion = response.completion message = AIMessage(content=completion) @@ -129,24 +166,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - prompt = self._convert_messages_to_prompt(messages) - 
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} - if stop: - params["stop_sequences"] = stop - if self.streaming: completion = "" - stream_resp = await self.async_client.completions.create( - **params, stream=True - ) - async for data in stream_resp: - delta = data.completion - completion += delta - if run_manager: - await run_manager.on_llm_new_token( - delta, - ) + async for chunk in self._astream(messages, stop, run_manager, **kwargs): + completion += chunk.text else: + prompt = self._convert_messages_to_prompt(messages) + params: Dict[str, Any] = { + "prompt": prompt, + **self._default_params, + **kwargs, + } + if stop: + params["stop_sequences"] = stop response = await self.async_client.completions.create(**params) completion = response.completion message = AIMessage(content=completion) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 36023dfd847..fe9ed5c59f7 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -3,7 +3,16 @@ import inspect import warnings from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Sequence +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + cast, +) from pydantic import Field, root_validator @@ -17,6 +26,8 @@ from langchain.callbacks.manager import ( Callbacks, ) from langchain.load.dump import dumpd, dumps +from langchain.prompts.base import StringPromptValue +from langchain.prompts.chat import ChatPromptValue from langchain.schema import ( ChatGeneration, ChatResult, @@ -24,17 +35,22 @@ from langchain.schema import ( PromptValue, RunInfo, ) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage +from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + BaseMessageChunk, + HumanMessage, +) +from langchain.schema.output import ChatGenerationChunk +from langchain.schema.runnable import RunnableConfig def _get_verbosity() -> bool: return langchain.verbose -class BaseChatModel(BaseLanguageModel, ABC): - """Base class for chat models.""" - +class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): cache: Optional[bool] = None """Whether to cache the response.""" verbose: bool = Field(default_factory=_get_verbosity) @@ -64,6 +80,154 @@ class BaseChatModel(BaseLanguageModel, ABC): arbitrary_types_allowed = True + # --- Runnable methods --- + + def _convert_input(self, input: LanguageModelInput) -> PromptValue: + if isinstance(input, PromptValue): + return input + elif isinstance(input, str): + return StringPromptValue(text=input) + elif isinstance(input, list): + return ChatPromptValue(messages=input) + else: + raise ValueError( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." 
+ ) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> BaseMessageChunk: + return cast( + BaseMessageChunk, + cast( + ChatGeneration, + self.generate_prompt( + [self._convert_input(input)], stop=stop, **(config or {}) + ).generations[0][0], + ).message, + ) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> BaseMessageChunk: + if type(self)._agenerate == BaseChatModel._agenerate: + # model doesn't implement async generation, so use default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, stop=stop) + ) + + llm_result = await self.agenerate_prompt( + [self._convert_input(input)], stop=stop, **(config or {}) + ) + return cast( + BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message + ) + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[BaseMessageChunk]: + if type(self)._stream == BaseChatModel._stream: + # model doesn't implement streaming, so use default implementation + yield self.invoke(input, config=config, stop=stop, **kwargs) + else: + config = config or {} + messages = self._convert_input(input).to_messages() + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs} + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + dumpd(self), [messages], invocation_params=params, options=options + ) + try: + message: Optional[BaseMessageChunk] = None + for chunk in self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.message + if message is None: + message = chunk.message + else: + message += chunk.message + assert message is not None + except (KeyboardInterrupt, Exception) as e: + run_manager.on_llm_error(e) + raise e + else: + run_manager.on_llm_end( + LLMResult(generations=[[ChatGeneration(message=message)]]), + ) + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[BaseMessageChunk]: + if type(self)._astream == BaseChatModel._astream: + # model doesn't implement streaming, so use default implementation + yield self.invoke(input, config=config, stop=stop, **kwargs) + else: + config = config or {} + messages = self._convert_input(input).to_messages() + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs} + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + dumpd(self), [messages], invocation_params=params, options=options + ) + try: + message: Optional[BaseMessageChunk] = None + async for chunk in self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.message + if message is None: + message = chunk.message + else: + message += chunk.message + assert message is not None + except (KeyboardInterrupt, Exception) as e: + await 
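The accumulation in `stream()`/`astream()` above (`message += chunk.message`) relies on the new `BaseMessageChunk.__add__`, which per the commit message concatenates the chunks' string content. A tiny illustrative sketch (not part of the diff):

```python
from langchain.schema.messages import AIMessageChunk

# Chunk addition concatenates the str content, so a stream of partial chunks
# can be folded into one final message.
chunk = AIMessageChunk(content="Hello") + AIMessageChunk(content=", world")
assert chunk.content == "Hello, world"
```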
run_manager.on_llm_error(e) + raise e + else: + await run_manager.on_llm_end( + LLMResult(generations=[[ChatGeneration(message=message)]]), + ) + + # --- Custom methods --- + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return {} @@ -334,7 +498,6 @@ class BaseChatModel(BaseLanguageModel, ABC): ) -> ChatResult: """Top Level call""" - @abstractmethod async def _agenerate( self, messages: List[BaseMessage], @@ -343,6 +506,25 @@ class BaseChatModel(BaseLanguageModel, ABC): **kwargs: Any, ) -> ChatResult: """Top Level call""" + raise NotImplementedError() + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + raise NotImplementedError() + + def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + raise NotImplementedError() def __call__( self, diff --git a/libs/langchain/langchain/chat_models/fake.py b/libs/langchain/langchain/chat_models/fake.py index 5f370c91140..97e69c993f4 100644 --- a/libs/langchain/langchain/chat_models/fake.py +++ b/libs/langchain/langchain/chat_models/fake.py @@ -25,7 +25,10 @@ class FakeListChatModel(SimpleChatModel): ) -> str: """First try to lookup in queries, else return 'foo' or 'bar'.""" response = self.responses[self.i] - self.i += 1 + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 return response @property diff --git a/libs/langchain/langchain/chat_models/jinachat.py b/libs/langchain/langchain/chat_models/jinachat.py index 5c7d89b4023..30fee861261 100644 --- a/libs/langchain/langchain/chat_models/jinachat.py +++ b/libs/langchain/langchain/chat_models/jinachat.py @@ -4,8 +4,10 @@ from __future__ import annotations import logging from typing import ( Any, + AsyncIterator, Callable, Dict, + Iterator, List, Mapping, Optional, @@ -36,6 +38,14 @@ from langchain.schema import ( HumanMessage, SystemMessage, ) +from langchain.schema.messages import ( + AIMessageChunk, + BaseMessageChunk, + ChatMessageChunk, + HumanMessageChunk, + SystemMessageChunk, +) +from langchain.schema.output import ChatGenerationChunk from langchain.utils import get_from_dict_or_env, get_pydantic_field_names logger = logging.getLogger(__name__) @@ -75,6 +85,24 @@ async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any: return await _completion_with_retry(**kwargs) +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: role = _dict["role"] if role == "user": @@ -258,6 +286,25 @@ class JinaChat(BaseChatModel): overall_token_usage[k] = v return {"token_usage": overall_token_usage} + def _stream( + self, + messages: List[BaseMessage], + stop: 
Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + for chunk in self.completion_with_retry(messages=message_dicts, **params): + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.content) + def _generate( self, messages: List[BaseMessage], @@ -265,27 +312,20 @@ class JinaChat(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + for chunk in self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} - if self.streaming: - inner_completion = "" - role = "assistant" - params["stream"] = True - for stream_resp in self.completion_with_retry( - messages=message_dicts, **params - ): - role = stream_resp["choices"][0]["delta"].get("role", role) - token = stream_resp["choices"][0]["delta"].get("content") or "" - inner_completion += token - if run_manager: - run_manager.on_llm_new_token(token) - message = _convert_dict_to_message( - { - "content": inner_completion, - "role": role, - } - ) - return ChatResult(generations=[ChatGeneration(message=message)]) response = self.completion_with_retry(messages=message_dicts, **params) return self._create_chat_result(response) @@ -309,6 +349,27 @@ class JinaChat(BaseChatModel): llm_output = {"token_usage": response["usage"]} return ChatResult(generations=generations, llm_output=llm_output) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + async for chunk in await acompletion_with_retry( + self, messages=message_dicts, **params + ): + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + await run_manager.on_llm_new_token(chunk.content) + async def _agenerate( self, messages: List[BaseMessage], @@ -316,32 +377,22 @@ class JinaChat(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + async for chunk in self._astream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} - if self.streaming: - inner_completion = "" - role = "assistant" - params["stream"] = True - 
async for stream_resp in await acompletion_with_retry( - self, messages=message_dicts, **params - ): - role = stream_resp["choices"][0]["delta"].get("role", role) - token = stream_resp["choices"][0]["delta"].get("content", "") - inner_completion += token or "" - if run_manager: - await run_manager.on_llm_new_token(token) - message = _convert_dict_to_message( - { - "content": inner_completion, - "role": role, - } - ) - return ChatResult(generations=[ChatGeneration(message=message)]) - else: - response = await acompletion_with_retry( - self, messages=message_dicts, **params - ) - return self._create_chat_result(response) + response = await acompletion_with_retry(self, messages=message_dicts, **params) + return self._create_chat_result(response) @property def _invocation_params(self) -> Mapping[str, Any]: diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index f1a1efd9773..815e3011bbf 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -6,8 +6,10 @@ import sys from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Dict, + Iterator, List, Mapping, Optional, @@ -35,12 +37,19 @@ from langchain.schema import ( ) from langchain.schema.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + BaseMessageChunk, ChatMessage, + ChatMessageChunk, FunctionMessage, + FunctionMessageChunk, HumanMessage, + HumanMessageChunk, SystemMessage, + SystemMessageChunk, ) +from langchain.schema.output import ChatGenerationChunk from langchain.utils import get_from_dict_or_env, get_pydantic_field_names if TYPE_CHECKING: @@ -95,6 +104,30 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any: return await _completion_with_retry(**kwargs) +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: role = _dict["role"] if role == "user": @@ -313,6 +346,27 @@ class ChatOpenAI(BaseChatModel): overall_token_usage[k] = v return {"token_usage": overall_token_usage, "model_name": self.model_name} + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + for chunk in self.completion_with_retry(messages=message_dicts, **params): + if len(chunk["choices"]) == 0: + continue + delta = 
chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.content) + def _generate( self, messages: List[BaseMessage], @@ -320,40 +374,20 @@ class ChatOpenAI(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + for chunk in self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} - if self.streaming: - inner_completion = "" - role = "assistant" - params["stream"] = True - function_call: Optional[dict] = None - for stream_resp in self.completion_with_retry( - messages=message_dicts, **params - ): - if len(stream_resp["choices"]) > 0: - role = stream_resp["choices"][0]["delta"].get("role", role) - token = stream_resp["choices"][0]["delta"].get("content") or "" - inner_completion += token - _function_call = stream_resp["choices"][0]["delta"].get( - "function_call" - ) - if _function_call: - if function_call is None: - function_call = _function_call - elif "arguments" in function_call: - function_call["arguments"] += _function_call["arguments"] - else: - function_call["arguments"] = _function_call["arguments"] - if run_manager: - run_manager.on_llm_new_token(token) - message = _convert_dict_to_message( - { - "content": inner_completion, - "role": role, - "function_call": function_call, - } - ) - return ChatResult(generations=[ChatGeneration(message=message)]) response = self.completion_with_retry(messages=message_dicts, **params) return self._create_chat_result(response) @@ -381,6 +415,29 @@ class ChatOpenAI(BaseChatModel): llm_output = {"token_usage": token_usage, "model_name": self.model_name} return ChatResult(generations=generations, llm_output=llm_output) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + async for chunk in await acompletion_with_retry( + self, messages=message_dicts, **params + ): + if len(chunk["choices"]) == 0: + continue + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + await run_manager.on_llm_new_token(chunk.content) + async def _agenerate( self, messages: List[BaseMessage], @@ -388,45 +445,22 @@ class ChatOpenAI(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + async for chunk in self._astream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + message_dicts, params = self._create_message_dicts(messages, 
stop) params = {**params, **kwargs} - if self.streaming: - inner_completion = "" - role = "assistant" - params["stream"] = True - function_call: Optional[dict] = None - async for stream_resp in await acompletion_with_retry( - self, messages=message_dicts, **params - ): - if len(stream_resp["choices"]) > 0: - role = stream_resp["choices"][0]["delta"].get("role", role) - token = stream_resp["choices"][0]["delta"].get("content", "") - inner_completion += token or "" - _function_call = stream_resp["choices"][0]["delta"].get( - "function_call" - ) - if _function_call: - if function_call is None: - function_call = _function_call - elif "arguments" in function_call: - function_call["arguments"] += _function_call["arguments"] - else: - function_call["arguments"] = _function_call["arguments"] - if run_manager: - await run_manager.on_llm_new_token(token) - message = _convert_dict_to_message( - { - "content": inner_completion, - "role": role, - "function_call": function_call, - } - ) - return ChatResult(generations=[ChatGeneration(message=message)]) - else: - response = await acompletion_with_retry( - self, messages=message_dicts, **params - ) - return self._create_chat_result(response) + response = await acompletion_with_retry(self, messages=message_dicts, **params) + return self._create_chat_result(response) @property def _identifying_params(self) -> Dict[str, Any]: diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py index 84ea2065701..b50eac264a2 100644 --- a/libs/langchain/langchain/chat_models/vertexai.py +++ b/libs/langchain/langchain/chat_models/vertexai.py @@ -4,10 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from pydantic import root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import BaseChatModel from langchain.llms.vertexai import _VertexAICommon, is_codey_model from langchain.schema import ( @@ -162,14 +159,3 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): response = chat.send_message(question.content) text = self._enforce_stop_words(response.text, stop) return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - raise NotImplementedError( - """Vertex AI doesn't support async requests at the moment.""" - ) diff --git a/libs/langchain/langchain/document_transformers/openai_functions.py b/libs/langchain/langchain/document_transformers/openai_functions.py index 96de42a2b95..8b2d11467f1 100644 --- a/libs/langchain/langchain/document_transformers/openai_functions.py +++ b/libs/langchain/langchain/document_transformers/openai_functions.py @@ -6,7 +6,8 @@ from pydantic import BaseModel from langchain.chains.llm import LLMChain from langchain.chains.openai_functions import create_tagging_chain from langchain.prompts import ChatPromptTemplate -from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document +from langchain.schema import BaseDocumentTransformer, Document +from langchain.schema.language_model import BaseLanguageModel class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel): diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index 
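From the caller's side, the new `_stream` hook surfaces through `BaseChatModel.stream()`, which yields message chunks as tokens arrive. A hedged usage sketch (assumes a valid `OPENAI_API_KEY`; the prompt is arbitrary):

```python
from langchain.chat_models import ChatOpenAI

chat = ChatOpenAI()

# stream() accepts a plain string (via _convert_input) and yields AIMessageChunk objects.
for chunk in chat.stream("Write a haiku about the sea"):
    print(chunk.content, end="", flush=True)
```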
a9d564fee4d..f32f581d1f7 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -1,18 +1,20 @@ import re import warnings -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional +from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional -from pydantic import BaseModel, root_validator +from pydantic import root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.output import GenerationChunk from langchain.utils import check_package_version, get_from_dict_or_env -class _AnthropicCommon(BaseModel): +class _AnthropicCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: model: str = "claude-2" @@ -193,24 +195,16 @@ class Anthropic(LLM, _AnthropicCommon): response = model(prompt) """ + if self.streaming: + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + stop = self._get_anthropic_stop(stop) params = {**self._default_params, **kwargs} - if self.streaming: - stream_resp = self.client.completions.create( - prompt=self._wrap_prompt(prompt), - stop_sequences=stop, - stream=True, - **params, - ) - current_completion = "" - for data in stream_resp: - delta = data.completion - current_completion += delta - if run_manager: - run_manager.on_llm_new_token( - delta, - ) - return current_completion response = self.client.completions.create( prompt=self._wrap_prompt(prompt), stop_sequences=stop, @@ -226,22 +220,17 @@ class Anthropic(LLM, _AnthropicCommon): **kwargs: Any, ) -> str: """Call out to Anthropic's completion endpoint asynchronously.""" + if self.streaming: + completion = "" + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + stop = self._get_anthropic_stop(stop) params = {**self._default_params, **kwargs} - if self.streaming: - stream_resp = await self.async_client.completions.create( - prompt=self._wrap_prompt(prompt), - stop_sequences=stop, - stream=True, - **params, - ) - current_completion = "" - async for data in stream_resp: - delta = data.completion - current_completion += delta - if run_manager: - await run_manager.on_llm_new_token(delta) - return current_completion + response = await self.async_client.completions.create( prompt=self._wrap_prompt(prompt), stop_sequences=stop, @@ -249,23 +238,23 @@ class Anthropic(LLM, _AnthropicCommon): ) return response.completion - def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator: + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: r"""Call Anthropic completion_stream and return the resulting generator. - BETA: this is a beta feature while we figure out the right abstraction. - Once that happens, this interface could change. - Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. - Returns: A generator representing the stream of tokens from Anthropic. - Example: .. code-block:: python - prompt = "Write a poem about a stream." 
prompt = f"\n\nHuman: {prompt}\n\nAssistant:" generator = anthropic.stream(prompt) @@ -273,12 +262,49 @@ class Anthropic(LLM, _AnthropicCommon): yield token """ stop = self._get_anthropic_stop(stop) - return self.client.completions.create( + params = {**self._default_params, **kwargs} + + for token in self.client.completions.create( + prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + ): + yield GenerationChunk(text=token.completion) + if run_manager: + run_manager.on_llm_new_token(token.completion) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + r"""Call Anthropic completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Anthropic. + Example: + .. code-block:: python + prompt = "Write a poem about a stream." + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = anthropic.stream(prompt) + for token in generator: + yield token + """ + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + async for token in await self.async_client.completions.create( prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, - **self._default_params, - ) + **params, + ): + yield GenerationChunk(text=token.completion) + if run_manager: + await run_manager.on_llm_new_token(token.completion) def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 91deb9c7942..10595e83191 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -7,11 +7,14 @@ import json import logging import warnings from abc import ABC, abstractmethod +from functools import partial from pathlib import Path from typing import ( Any, + AsyncIterator, Callable, Dict, + Iterator, List, Mapping, Optional, @@ -19,6 +22,7 @@ from typing import ( Tuple, Type, Union, + cast, ) import yaml @@ -42,14 +46,18 @@ from langchain.callbacks.manager import ( Callbacks, ) from langchain.load.dump import dumpd +from langchain.prompts.base import StringPromptValue +from langchain.prompts.chat import ChatPromptValue from langchain.schema import ( Generation, LLMResult, PromptValue, RunInfo, ) -from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string +from langchain.schema.output import GenerationChunk +from langchain.schema.runnable import RunnableConfig logger = logging.getLogger(__name__) @@ -115,7 +123,7 @@ def update_cache( return llm_output -class BaseLLM(BaseLanguageModel, ABC): +class BaseLLM(BaseLanguageModel[str], ABC): """Base LLM abstract interface. 
It should take in a prompt and return a string.""" @@ -157,6 +165,204 @@ class BaseLLM(BaseLanguageModel, ABC): else: return verbose + # --- Runnable methods --- + + def _convert_input(self, input: LanguageModelInput) -> PromptValue: + if isinstance(input, PromptValue): + return input + elif isinstance(input, str): + return StringPromptValue(text=input) + elif isinstance(input, list): + return ChatPromptValue(messages=input) + else: + raise ValueError( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." + ) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> str: + return ( + self.generate_prompt( + [self._convert_input(input)], stop=stop, **(config or {}) + ) + .generations[0][0] + .text + ) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> str: + if type(self)._agenerate == BaseLLM._agenerate: + # model doesn't implement async invoke, so use default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, stop=stop) + ) + + llm_result = await self.agenerate_prompt( + [self._convert_input(input)], stop=stop, **(config or {}) + ) + return llm_result.generations[0][0].text + + def batch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + max_concurrency: Optional[int] = None, + ) -> List[str]: + config = self._get_config_list(config, len(inputs)) + + if max_concurrency is None: + llm_result = self.generate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + ) + return [g[0].text for g in llm_result.generations] + else: + batches = [ + inputs[i : i + max_concurrency] + for i in range(0, len(inputs), max_concurrency) + ] + return [ + output + for batch in batches + for output in self.batch(batch, config=config) + ] + + async def abatch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + max_concurrency: Optional[int] = None, + ) -> List[str]: + if type(self)._agenerate == BaseLLM._agenerate: + # model doesn't implement async batch, so use default implementation + return await asyncio.get_running_loop().run_in_executor( + None, self.batch, inputs, config, max_concurrency + ) + + config = self._get_config_list(config, len(inputs)) + + if max_concurrency is None: + llm_result = await self.agenerate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + ) + return [g[0].text for g in llm_result.generations] + else: + batches = [ + inputs[i : i + max_concurrency] + for i in range(0, len(inputs), max_concurrency) + ] + return [ + output + for batch in batches + for output in await self.abatch(batch, config=config) + ] + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> Iterator[str]: + if type(self)._stream == BaseLLM._stream: + # model doesn't implement streaming, so use default implementation + yield self.invoke(input, config=config, stop=stop) + else: + prompt = 
self._convert_input(input).to_string() + config = config or {} + params = self.dict() + params["stop"] = stop + options = {"stop": stop} + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = callback_manager.on_llm_start( + dumpd(self), [prompt], invocation_params=params, options=options + ) + try: + generation: Optional[GenerationChunk] = None + for chunk in self._stream(prompt, stop=stop, run_manager=run_manager): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except (KeyboardInterrupt, Exception) as e: + run_manager.on_llm_error(e) + raise e + else: + run_manager.on_llm_end(LLMResult(generations=[[generation]])) + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + ) -> AsyncIterator[str]: + if type(self)._astream == BaseLLM._astream: + # model doesn't implement streaming, so use default implementation + yield await self.ainvoke(input, config=config, stop=stop) + else: + prompt = self._convert_input(input).to_string() + config = config or {} + params = self.dict() + params["stop"] = stop + options = {"stop": stop} + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = await callback_manager.on_llm_start( + dumpd(self), [prompt], invocation_params=params, options=options + ) + try: + generation: Optional[GenerationChunk] = None + async for chunk in self._astream( + prompt, stop=stop, run_manager=run_manager + ): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_llm_error(e) + raise e + else: + await run_manager.on_llm_end(LLMResult(generations=[[generation]])) + + # --- Custom methods --- + @abstractmethod def _generate( self, @@ -167,7 +373,6 @@ class BaseLLM(BaseLanguageModel, ABC): ) -> LLMResult: """Run the LLM on the given prompts.""" - @abstractmethod async def _agenerate( self, prompts: List[str], @@ -176,12 +381,31 @@ class BaseLLM(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts.""" + raise NotImplementedError() + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + raise NotImplementedError() + + def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + raise NotImplementedError() def generate_prompt( self, prompts: List[PromptValue], stop: Optional[List[str]] = None, - callbacks: Callbacks = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ -191,7 +415,7 @@ class BaseLLM(BaseLanguageModel, ABC): self, prompts: List[PromptValue], stop: Optional[List[str]] = None, - callbacks: Callbacks = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ 
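The `type(self)._stream == BaseLLM._stream` guard used in `stream()`/`astream()` above is the "is it overridden in a subclass?" trick referenced in the commit message. Reduced to a self-contained toy example (not part of the diff):

```python
class Base:
    def _stream(self):
        raise NotImplementedError()  # default: no streaming support

    def stream(self):
        # Comparing the function object on the concrete class with the one on Base
        # tells us whether a subclass overrode _stream().
        if type(self)._stream == Base._stream:
            return "fallback path"
        return "streaming path"


class WithStreaming(Base):
    def _stream(self):
        return iter("abc")


assert Base().stream() == "fallback path"
assert WithStreaming().stream() == "streaming path"
```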
-236,10 +460,10 @@ class BaseLLM(BaseLanguageModel, ABC): self, prompts: List[str], stop: Optional[List[str]] = None, - callbacks: Callbacks = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[Union[List[str], List[List[str]]]] = None, + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" @@ -248,6 +472,50 @@ class BaseLLM(BaseLanguageModel, ABC): "Argument 'prompts' is expected to be of type List[str], received" f" argument of type {type(prompts)}." ) + # Create callback managers + if isinstance(callbacks, list) and ( + isinstance(callbacks[0], (list, BaseCallbackManager)) + or callbacks[0] is None + ): + # We've received a list of callbacks args to apply to each input + assert len(callbacks) == len(prompts) + assert tags is None or ( + isinstance(tags, list) and len(tags) == len(prompts) + ) + assert metadata is None or ( + isinstance(metadata, list) and len(metadata) == len(prompts) + ) + callbacks = cast(List[Callbacks], callbacks) + tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + metadata_list = cast( + List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + ) + callback_managers = [ + CallbackManager.configure( + callback, + self.callbacks, + self.verbose, + tag, + self.tags, + meta, + self.metadata, + ) + for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + ] + else: + # We've received a single callbacks arg to apply to all inputs + callback_managers = [ + CallbackManager.configure( + cast(Callbacks, callbacks), + self.callbacks, + self.verbose, + cast(List[str], tags), + self.tags, + cast(Dict[str, Any], metadata), + self.metadata, + ) + ] * len(prompts) + params = self.dict() params["stop"] = stop options = {"stop": stop} @@ -258,15 +526,6 @@ class BaseLLM(BaseLanguageModel, ABC): missing_prompts, ) = get_prompts(params, prompts) disregard_cache = self.cache is not None and not self.cache - callback_manager = CallbackManager.configure( - callbacks, - self.callbacks, - self.verbose, - tags, - self.tags, - metadata, - self.metadata, - ) new_arg_supported = inspect.signature(self._generate).parameters.get( "run_manager" ) @@ -275,17 +534,26 @@ class BaseLLM(BaseLanguageModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - run_managers = callback_manager.on_llm_start( - dumpd(self), prompts, invocation_params=params, options=options - ) + run_managers = [ + callback_manager.on_llm_start( + dumpd(self), [prompt], invocation_params=params, options=options + )[0] + for callback_manager, prompt in zip(callback_managers, prompts) + ] output = self._generate_helper( prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) return output if len(missing_prompts) > 0: - run_managers = callback_manager.on_llm_start( - dumpd(self), missing_prompts, invocation_params=params, options=options - ) + run_managers = [ + callback_managers[idx].on_llm_start( + dumpd(self), + [prompts[idx]], + invocation_params=params, + options=options, + )[0] + for idx in missing_prompt_idxs + ] new_results = self._generate_helper( missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) @@ -346,13 +614,57 @@ class BaseLLM(BaseLanguageModel, ABC): self, prompts: List[str], stop: Optional[List[str]] = None, - callbacks: Callbacks = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[Union[List[str], List[List[str]]]] = None, + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" + # Create callback managers + if isinstance(callbacks, list) and ( + isinstance(callbacks[0], (list, BaseCallbackManager)) + or callbacks[0] is None + ): + # We've received a list of callbacks args to apply to each input + assert len(callbacks) == len(prompts) + assert tags is None or ( + isinstance(tags, list) and len(tags) == len(prompts) + ) + assert metadata is None or ( + isinstance(metadata, list) and len(metadata) == len(prompts) + ) + callbacks = cast(List[Callbacks], callbacks) + tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + metadata_list = cast( + List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + ) + callback_managers = [ + AsyncCallbackManager.configure( + callback, + self.callbacks, + self.verbose, + tag, + self.tags, + meta, + self.metadata, + ) + for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + ] + else: + # We've received a single callbacks arg to apply to all inputs + callback_managers = [ + AsyncCallbackManager.configure( + cast(Callbacks, callbacks), + self.callbacks, + self.verbose, + cast(List[str], tags), + self.tags, + cast(Dict[str, Any], metadata), + self.metadata, + ) + ] * len(prompts) + params = self.dict() params["stop"] = stop options = {"stop": stop} @@ -363,15 +675,6 @@ class BaseLLM(BaseLanguageModel, ABC): missing_prompts, ) = get_prompts(params, prompts) disregard_cache = self.cache is not None and not self.cache - callback_manager = AsyncCallbackManager.configure( - callbacks, - self.callbacks, - self.verbose, - tags, - self.tags, - metadata, - self.metadata, - ) new_arg_supported = inspect.signature(self._agenerate).parameters.get( "run_manager" ) @@ -380,17 +683,32 @@ class BaseLLM(BaseLanguageModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
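The per-prompt callback managers built above mean `generate()` can now accept one callbacks/tags/metadata entry per prompt instead of a single shared one. A hedged sketch, with `FakeListLLM` and `StdOutCallbackHandler` standing in for real components:

```python
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["one", "two"])

# One callbacks list and one tags list per prompt; each prompt gets its own run managers.
result = llm.generate(
    ["first prompt", "second prompt"],
    callbacks=[[StdOutCallbackHandler()], None],
    tags=[["first"], ["second"]],
)
print([generation[0].text for generation in result.generations])
```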
) - run_managers = await callback_manager.on_llm_start( - dumpd(self), prompts, invocation_params=params, options=options + run_managers = await asyncio.gather( + *[ + callback_manager.on_llm_start( + dumpd(self), [prompt], invocation_params=params, options=options + ) + for callback_manager, prompt in zip(callback_managers, prompts) + ] ) + run_managers = [r[0] for r in run_managers] output = await self._agenerate_helper( prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) return output if len(missing_prompts) > 0: - run_managers = await callback_manager.on_llm_start( - dumpd(self), missing_prompts, invocation_params=params, options=options + run_managers = await asyncio.gather( + *[ + callback_managers[idx].on_llm_start( + dumpd(self), + [prompts[idx]], + invocation_params=params, + options=options, + ) + for idx in missing_prompt_idxs + ] ) + run_managers = [r[0] for r in run_managers] new_results = await self._agenerate_helper( missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) @@ -586,7 +904,7 @@ class LLM(BaseLLM): **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" - raise NotImplementedError("Async generation not implemented for this LLM.") + raise NotImplementedError() def _generate( self, @@ -615,6 +933,12 @@ class LLM(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: + if type(self)._acall == LLM._acall: + # model doesn't implement async call, so use default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._generate, prompts, stop, run_manager, **kwargs) + ) + """Run the LLM on the given prompt and input.""" generations = [] new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index 4b139316a4a..c2cf63036d7 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -27,7 +27,10 @@ class FakeListLLM(LLM): ) -> str: """Return next response""" response = self.responses[self.i] - self.i += 1 + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 return response async def _acall( @@ -39,7 +42,10 @@ class FakeListLLM(LLM): ) -> str: """Return next response""" response = self.responses[self.i] - self.i += 1 + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 return response @property diff --git a/libs/langchain/langchain/llms/google_palm.py b/libs/langchain/langchain/llms/google_palm.py index 4b95c06369d..8f45950a29e 100644 --- a/libs/langchain/langchain/llms/google_palm.py +++ b/libs/langchain/langchain/llms/google_palm.py @@ -12,10 +12,7 @@ from tenacity import ( wait_exponential, ) -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import BaseLLM from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env @@ -161,15 +158,6 @@ class GooglePalm(BaseLLM, BaseModel): return LLMResult(generations=generations) - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - raise NotImplementedError() - @property def _llm_type(self) -> str: """Return type of llm.""" diff --git 
a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py index 293d463add3..befe8134808 100644 --- a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py +++ b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py @@ -1,5 +1,4 @@ -from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional from pydantic import Extra, Field, root_validator @@ -8,6 +7,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.llms.base import LLM +from langchain.schema.output import GenerationChunk class HuggingFaceTextGenInference(LLM): @@ -69,7 +69,7 @@ class HuggingFaceTextGenInference(LLM): temperature = 0.01, repetition_penalty = 1.03, callbacks = callbacks, - stream = True + streaming = True ) print(llm("What is Deep Learning?")) @@ -87,7 +87,7 @@ class HuggingFaceTextGenInference(LLM): inference_server_url: str = "" timeout: int = 120 server_kwargs: Dict[str, Any] = Field(default_factory=dict) - stream: bool = False + streaming: bool = False client: Any async_client: Any @@ -154,37 +154,21 @@ class HuggingFaceTextGenInference(LLM): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: + if self.streaming: + completion = "" + for chunk in self._stream(prompt, stop, run_manager, **kwargs): + completion += chunk.text + return completion + invocation_params = self._invocation_params(stop, **kwargs) - if not self.stream: - res = self.client.generate(prompt, **invocation_params) - # remove stop sequences from the end of the generated text - for stop_seq in invocation_params["stop_sequences"]: - if stop_seq in res.generated_text: - res.generated_text = res.generated_text[ - : res.generated_text.index(stop_seq) - ] - text = res.generated_text - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - text = "" - for res in self.client.generate_stream(prompt, **invocation_params): - token = res.token - is_stop = False - for stop_seq in invocation_params["stop_sequences"]: - if stop_seq in token.text: - is_stop = True - break - if is_stop: - break - if not token.special: - if text_callback: - text_callback(token.text) - text += token.text - return text + res = self.client.generate(prompt, **invocation_params) + # remove stop sequences from the end of the generated text + for stop_seq in invocation_params["stop_sequences"]: + if stop_seq in res.generated_text: + res.generated_text = res.generated_text[ + : res.generated_text.index(stop_seq) + ] + return res.generated_text async def _acall( self, @@ -193,39 +177,90 @@ class HuggingFaceTextGenInference(LLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: + if self.streaming: + completion = "" + async for chunk in self._astream(prompt, stop, run_manager, **kwargs): + completion += chunk.text + return completion + invocation_params = self._invocation_params(stop, **kwargs) - if not self.stream: - res = await self.async_client.generate( - prompt, - **invocation_params, - ) - # remove stop sequences from the end of the generated text + res = await self.async_client.generate(prompt, **invocation_params) + # remove stop sequences from the end of the generated text + for stop_seq in invocation_params["stop_sequences"]: + if stop_seq in res.generated_text: + res.generated_text = res.generated_text[ + : 
res.generated_text.index(stop_seq) + ] + return res.generated_text + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + invocation_params = self._invocation_params(stop, **kwargs) + + for res in self.client.generate_stream(prompt, **invocation_params): + # identify stop sequence in generated text, if any + stop_seq_found: Optional[str] = None for stop_seq in invocation_params["stop_sequences"]: - if stop_seq in res.generated_text: - res.generated_text = res.generated_text[ - : res.generated_text.index(stop_seq) - ] - text: str = res.generated_text - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - text = "" - async for res in self.async_client.generate_stream( - prompt, **invocation_params - ): - token = res.token - is_stop = False - for stop_seq in invocation_params["stop_sequences"]: - if stop_seq in token.text: - is_stop = True - break - if is_stop: - break - if not token.special: - if text_callback: - await text_callback(token.text) - text += token.text - return text + if stop_seq in res.token.text: + stop_seq_found = stop_seq + + # identify text to yield + text: Optional[str] = None + if res.token.special: + text = None + elif stop_seq_found: + text = res.token.text[: res.token.text.index(stop_seq_found)] + else: + text = res.token.text + + # yield text, if any + if text: + chunk = GenerationChunk(text=text) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text) + + # break if stop sequence found + if stop_seq_found: + break + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + invocation_params = self._invocation_params(stop, **kwargs) + + async for res in self.async_client.generate_stream(prompt, **invocation_params): + # identify stop sequence in generated text, if any + stop_seq_found: Optional[str] = None + for stop_seq in invocation_params["stop_sequences"]: + if stop_seq in res.token.text: + stop_seq_found = stop_seq + + # identify text to yield + text: Optional[str] = None + if res.token.special: + text = None + elif stop_seq_found: + text = res.token.text[: res.token.text.index(stop_seq_found)] + else: + text = res.token.text + + # yield text, if any + if text: + chunk = GenerationChunk(text=text) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(chunk.text) + + # break if stop sequence found + if stop_seq_found: + break diff --git a/libs/langchain/langchain/llms/llamacpp.py b/libs/langchain/langchain/llms/llamacpp.py index 9f849e98c57..be79076f162 100644 --- a/libs/langchain/langchain/llms/llamacpp.py +++ b/libs/langchain/langchain/llms/llamacpp.py @@ -1,10 +1,11 @@ import logging -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Dict, Iterator, List, Optional from pydantic import Field, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM +from langchain.schema.output import GenerationChunk logger = logging.getLogger(__name__) @@ -226,8 +227,10 @@ class LlamaCpp(LLM): # method that yields as they are generated # and return the combined strings from the first choices's text: combined_text_output = "" - for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager): - 
combined_text_output += token["choices"][0]["text"] + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + combined_text_output += chunk.text return combined_text_output else: params = self._get_parameters(stop) @@ -235,17 +238,15 @@ class LlamaCpp(LLM): result = self.client(prompt=prompt, **params) return result["choices"][0]["text"] - def stream( + def _stream( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - ) -> Generator[Dict, None, None]: + **kwargs: Any, + ) -> Iterator[GenerationChunk]: """Yields results objects as they are generated in real time. - BETA: this is a beta feature while we figure out the right abstraction. - Once that happens, this interface could change. - It also calls the callback manager's on_llm_new_token event with similar parameters to the OpenAI LLM class method of the same name. @@ -274,16 +275,19 @@ class LlamaCpp(LLM): print(result["text"], end='', flush=True) """ - params = self._get_parameters(stop) + params = {**self._get_parameters(stop), **kwargs} result = self.client(prompt=prompt, stream=True, **params) - for chunk in result: - token = chunk["choices"][0]["text"] - log_probs = chunk["choices"][0].get("logprobs", None) + for part in result: + logprobs = part["choices"][0].get("logprobs", None) + chunk = GenerationChunk( + text=part["choices"][0]["text"], + generation_info={"logprobs": logprobs}, + ) + yield chunk if run_manager: run_manager.on_llm_new_token( - token=token, verbose=self.verbose, log_probs=log_probs + token=chunk.text, verbose=self.verbose, log_probs=logprobs ) - yield chunk def get_num_tokens(self, text: str) -> int: tokenized_text = self.client.tokenize(text.encode("utf-8")) diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py index d86bab21411..2c664a1c874 100644 --- a/libs/langchain/langchain/llms/openai.py +++ b/libs/langchain/langchain/llms/openai.py @@ -6,10 +6,11 @@ import warnings from typing import ( AbstractSet, Any, + AsyncIterator, Callable, Collection, Dict, - Generator, + Iterator, List, Literal, Mapping, @@ -27,6 +28,7 @@ from langchain.callbacks.manager import ( ) from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.schema import Generation, LLMResult +from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env, get_pydantic_field_names logger = logging.getLogger(__name__) @@ -44,6 +46,19 @@ def update_token_usage( token_usage[_key] += response["usage"][_key] +def _stream_response_to_generation_chunk( + stream_response: Dict[str, Any], +) -> GenerationChunk: + """Convert a stream response to a generation chunk.""" + return GenerationChunk( + text=stream_response["choices"][0]["text"], + generation_info=dict( + finish_reason=stream_response["choices"][0].get("finish_reason", None), + logprobs=stream_response["choices"][0].get("logprobs", None), + ), + ) + + def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: """Update response from the stream response.""" response["choices"][0]["text"] += stream_response["choices"][0]["text"] @@ -268,6 +283,50 @@ class BaseOpenAI(BaseLLM): return {**normal_params, **self.model_kwargs} + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params = {**self._invocation_params, **kwargs, "stream": True} + 
self.get_sub_prompts(params, [prompt], stop) # this mutate params + for stream_resp in completion_with_retry(self, prompt=prompt, **params): + chunk = _stream_response_to_generation_chunk(stream_resp) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + verbose=self.verbose, + logprobs=chunk.generation_info["logprobs"] + if chunk.generation_info + else None, + ) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + params = {**self._invocation_params, **kwargs, "stream": True} + self.get_sub_prompts(params, [prompt], stop) # this mutate params + async for stream_resp in await acompletion_with_retry( + self, prompt=prompt, **params + ): + chunk = _stream_response_to_generation_chunk(stream_resp) + yield chunk + if run_manager: + await run_manager.on_llm_new_token( + chunk.text, + verbose=self.verbose, + logprobs=chunk.generation_info["logprobs"] + if chunk.generation_info + else None, + ) + def _generate( self, prompts: List[str], @@ -302,24 +361,28 @@ class BaseOpenAI(BaseLLM): if self.streaming: if len(_prompts) > 1: raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - response = _streaming_response_template() - for stream_resp in completion_with_retry( - self, prompt=_prompts, **params - ): - if run_manager: - run_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - _update_response(response, stream_resp) - choices.extend(response["choices"]) + + generation: Optional[GenerationChunk] = None + for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + choices.append( + { + "text": generation.text, + "finish_reason": generation.generation_info.get("finish_reason") + if generation.generation_info + else None, + "logprobs": generation.generation_info.get("logprobs") + if generation.generation_info + else None, + } + ) else: response = completion_with_retry(self, prompt=_prompts, **params) choices.extend(response["choices"]) - if not self.streaming: - # Can't update token usage if streaming update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -343,24 +406,30 @@ class BaseOpenAI(BaseLLM): if self.streaming: if len(_prompts) > 1: raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - response = _streaming_response_template() - async for stream_resp in await acompletion_with_retry( - self, prompt=_prompts, **params + + generation: Optional[GenerationChunk] = None + async for chunk in self._astream( + _prompts[0], stop, run_manager, **kwargs ): - if run_manager: - await run_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - _update_response(response, stream_resp) - choices.extend(response["choices"]) + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + choices.append( + { + "text": generation.text, + "finish_reason": generation.generation_info.get("finish_reason") + if generation.generation_info + else None, + "logprobs": generation.generation_info.get("logprobs") + if generation.generation_info + else None, + } + ) else: response = 
await acompletion_with_retry(self, prompt=_prompts, **params) choices.extend(response["choices"]) - if not self.streaming: - # Can't update token usage if streaming update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -409,43 +478,6 @@ class BaseOpenAI(BaseLLM): llm_output = {"token_usage": token_usage, "model_name": self.model_name} return LLMResult(generations=generations, llm_output=llm_output) - def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator: - """Call OpenAI with streaming flag and return the resulting generator. - - BETA: this is a beta feature while we figure out the right abstraction. - Once that happens, this interface could change. - - Args: - prompt: The prompts to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - A generator representing the stream of tokens from OpenAI. - - Example: - .. code-block:: python - - generator = openai.stream("Tell me a joke.") - for token in generator: - yield token - """ - params = self.prep_streaming_params(stop) - generator = self.client.create(prompt=prompt, **params) - - return generator - - def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]: - """Prepare the params for streaming.""" - params = self._invocation_params - if "best_of" in params and params["best_of"] != 1: - raise ValueError("OpenAI only supports best_of == 1 for streaming") - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - params["stop"] = stop - params["stream"] = True - return params - @property def _invocation_params(self) -> Dict[str, Any]: """Get the parameters used to invoke the model.""" @@ -777,6 +809,38 @@ class OpenAIChat(BaseLLM): del params["max_tokens"] return messages, params + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + messages, params = self._get_chat_params([prompt], stop) + params = {**params, **kwargs, "stream": True} + for stream_resp in completion_with_retry(self, messages=messages, **params): + token = stream_resp["choices"][0]["delta"].get("content", "") + yield GenerationChunk(text=token) + if run_manager: + run_manager.on_llm_new_token(token) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + messages, params = self._get_chat_params([prompt], stop) + params = {**params, **kwargs, "stream": True} + async for stream_resp in await acompletion_with_retry( + self, messages=messages, **params + ): + token = stream_resp["choices"][0]["delta"].get("content", "") + yield GenerationChunk(text=token) + if run_manager: + await run_manager.on_llm_new_token(token) + def _generate( self, prompts: List[str], @@ -784,33 +848,29 @@ class OpenAIChat(BaseLLM): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: + if self.streaming: + generation: Optional[GenerationChunk] = None + for chunk in self._stream(prompts[0], stop, run_manager, **kwargs): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return LLMResult(generations=[[generation]]) + messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - if self.streaming: - response 
= "" - params["stream"] = True - for stream_resp in completion_with_retry(self, messages=messages, **params): - token = stream_resp["choices"][0]["delta"].get("content", "") - response += token - if run_manager: - run_manager.on_llm_new_token( - token, - ) - return LLMResult( - generations=[[Generation(text=response)]], - ) - else: - full_response = completion_with_retry(self, messages=messages, **params) - llm_output = { - "token_usage": full_response["usage"], - "model_name": self.model_name, - } - return LLMResult( - generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] - ], - llm_output=llm_output, - ) + full_response = completion_with_retry(self, messages=messages, **params) + llm_output = { + "token_usage": full_response["usage"], + "model_name": self.model_name, + } + return LLMResult( + generations=[ + [Generation(text=full_response["choices"][0]["message"]["content"])] + ], + llm_output=llm_output, + ) async def _agenerate( self, @@ -819,37 +879,29 @@ class OpenAIChat(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: + if self.streaming: + generation: Optional[GenerationChunk] = None + async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return LLMResult(generations=[[generation]]) + messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - if self.streaming: - response = "" - params["stream"] = True - async for stream_resp in await acompletion_with_retry( - self, messages=messages, **params - ): - token = stream_resp["choices"][0]["delta"].get("content", "") - response += token - if run_manager: - await run_manager.on_llm_new_token( - token, - ) - return LLMResult( - generations=[[Generation(text=response)]], - ) - else: - full_response = await acompletion_with_retry( - self, messages=messages, **params - ) - llm_output = { - "token_usage": full_response["usage"], - "model_name": self.model_name, - } - return LLMResult( - generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] - ], - llm_output=llm_output, - ) + full_response = await acompletion_with_retry(self, messages=messages, **params) + llm_output = { + "token_usage": full_response["usage"], + "model_name": self.model_name, + } + return LLMResult( + generations=[ + [Generation(text=full_response["choices"][0]["message"]["content"])] + ], + llm_output=llm_output, + ) @property def _identifying_params(self) -> Mapping[str, Any]: diff --git a/libs/langchain/langchain/llms/tongyi.py b/libs/langchain/langchain/llms/tongyi.py index 7753213aea1..57973861b98 100644 --- a/libs/langchain/langchain/llms/tongyi.py +++ b/libs/langchain/langchain/llms/tongyi.py @@ -13,10 +13,7 @@ from tenacity import ( wait_exponential, ) -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env @@ -250,12 +247,3 @@ class Tongyi(LLM): ] ) return LLMResult(generations=generations) - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - raise NotImplementedError() diff --git 
a/libs/langchain/langchain/output_parsers/list.py b/libs/langchain/langchain/output_parsers/list.py index c7f94648bf2..92850fc31c2 100644 --- a/libs/langchain/langchain/output_parsers/list.py +++ b/libs/langchain/langchain/output_parsers/list.py @@ -6,7 +6,7 @@ from typing import List from langchain.schema import BaseOutputParser -class ListOutputParser(BaseOutputParser): +class ListOutputParser(BaseOutputParser[List[str]]): """Parse the output of an LLM call to a list.""" @property diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index a6d455825be..c55801c9bda 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -4,14 +4,14 @@ from typing import Any, Dict, List, Type, Union from pydantic import BaseModel, root_validator from langchain.schema import ( - BaseLLMOutputParser, ChatGeneration, Generation, OutputParserException, ) +from langchain.schema.output_parser import BaseGenerationOutputParser -class OutputFunctionsParser(BaseLLMOutputParser[Any]): +class OutputFunctionsParser(BaseGenerationOutputParser[Any]): """Parse an output that is one of sets of values.""" args_only: bool = True diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index d5426cca6cd..698c552f5c8 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -5,10 +5,10 @@ import warnings from abc import ABC from typing import Any, Callable, Dict, List, Set -from langchain.schema import BasePromptTemplate +from langchain.formatting import formatter from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.prompt import PromptValue -from langchain.utils import formatter +from langchain.schema.prompt_template import BasePromptTemplate def jinja2_formatter(template: str, **kwargs: Any) -> str: diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index fe2c768c479..45e7e50a230 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -446,7 +446,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): for message in messages: if isinstance(message, BaseMessagePromptTemplate): input_vars.update(message.input_variables) - return cls(input_variables=list(input_vars), messages=messages) + return cls(input_variables=sorted(input_vars), messages=messages) def format(self, **kwargs: Any) -> str: """Format the chat template into a string. 
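[Editor's illustration] The streaming changes earlier in this patch have `_stream`/`_astream` yield `GenerationChunk` objects, and the non-streaming code paths fold them back together with `+` (the `generation += chunk` loop in the OpenAI `_generate`). Below is a minimal sketch of that accumulation pattern; it assumes a langchain install that already includes this patch (so `GenerationChunk` exists in `langchain.schema.output`), and the `accumulate` helper is illustrative, not part of the patch.

    from langchain.schema.output import GenerationChunk

    def accumulate(chunks):
        # Fold streamed chunks into one GenerationChunk, mirroring the
        # `generation = chunk if generation is None else generation + chunk`
        # pattern used in the streaming _generate paths.
        generation = None
        for chunk in chunks:
            generation = chunk if generation is None else generation + chunk
        return generation

    parts = [
        GenerationChunk(text="Hello", generation_info={"finish_reason": None}),
        GenerationChunk(text=", world", generation_info={"finish_reason": "stop"}),
    ]
    final = accumulate(parts)
    # final.text == "Hello, world"; on merge, keys from the later chunk win,
    # so finish_reason ends up as "stop".
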
diff --git a/libs/langchain/langchain/retrievers/arxiv.py b/libs/langchain/langchain/retrievers/arxiv.py index 2510aadd79a..56019273b9d 100644 --- a/libs/langchain/langchain/retrievers/arxiv.py +++ b/libs/langchain/langchain/retrievers/arxiv.py @@ -1,9 +1,6 @@ from typing import List -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document from langchain.utilities.arxiv import ArxivAPIWrapper @@ -20,8 +17,3 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper): self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load(query=query) - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/bm25.py b/libs/langchain/langchain/retrievers/bm25.py index 735ca9913a4..a5ef4f28496 100644 --- a/libs/langchain/langchain/retrievers/bm25.py +++ b/libs/langchain/langchain/retrievers/bm25.py @@ -7,10 +7,7 @@ from __future__ import annotations from typing import Any, Callable, Dict, Iterable, List, Optional -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document @@ -108,8 +105,3 @@ class BM25Retriever(BaseRetriever): processed_query = self.preprocess_func(query) return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k) return return_docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/docarray.py b/libs/langchain/langchain/retrievers/docarray.py index 0ba247bf0c5..edd4e81f088 100644 --- a/libs/langchain/langchain/retrievers/docarray.py +++ b/libs/langchain/langchain/retrievers/docarray.py @@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document from langchain.vectorstores.utils import maximal_marginal_relevance @@ -208,11 +205,3 @@ class DocArrayRetriever(BaseRetriever): lc_doc.metadata[name] = value return lc_doc - - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/elastic_search_bm25.py b/libs/langchain/langchain/retrievers/elastic_search_bm25.py index 3a76b36ef0f..52c4c97bc13 100644 --- a/libs/langchain/langchain/retrievers/elastic_search_bm25.py +++ b/libs/langchain/langchain/retrievers/elastic_search_bm25.py @@ -5,10 +5,7 @@ from __future__ import annotations import uuid from typing import Any, Iterable, List -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document from langchain.schema import 
BaseRetriever @@ -138,8 +135,3 @@ class ElasticSearchBM25Retriever(BaseRetriever): for r in res["hits"]["hits"]: docs.append(Document(page_content=r["_source"]["content"])) return docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py index 47adae5c621..c170d42086b 100644 --- a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py +++ b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py @@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from pydantic import Extra, Field, root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document from langchain.utils import get_from_dict_or_env @@ -184,8 +181,3 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): documents = self._convert_search_response(response.results) return documents - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/kendra.py b/libs/langchain/langchain/retrievers/kendra.py index 8750d26f205..71e0ea1cf3b 100644 --- a/libs/langchain/langchain/retrievers/kendra.py +++ b/libs/langchain/langchain/retrievers/kendra.py @@ -4,10 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Extra, root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document from langchain.schema import BaseRetriever @@ -411,11 +408,3 @@ class AmazonKendraRetriever(BaseRetriever): """ docs = self._kendra_query(query, self.top_k, self.attribute_filter) return docs - - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - ) -> List[Document]: - raise NotImplementedError("Async version is not implemented for Kendra yet.") diff --git a/libs/langchain/langchain/retrievers/knn.py b/libs/langchain/langchain/retrievers/knn.py index 51a3effb221..d28408347f6 100644 --- a/libs/langchain/langchain/retrievers/knn.py +++ b/libs/langchain/langchain/retrievers/knn.py @@ -9,10 +9,7 @@ from typing import Any, List, Optional import numpy as np -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document @@ -82,8 +79,3 @@ class KNNRetriever(BaseRetriever): ) ] return top_k_results - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/llama_index.py b/libs/langchain/langchain/retrievers/llama_index.py index e393d121aa7..fe8e4f8b0fc 100644 --- a/libs/langchain/langchain/retrievers/llama_index.py +++ 
b/libs/langchain/langchain/retrievers/llama_index.py @@ -2,10 +2,7 @@ from typing import Any, Dict, List, cast from pydantic import Field -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document @@ -42,11 +39,6 @@ class LlamaIndexRetriever(BaseRetriever): ) return docs - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError("LlamaIndexRetriever does not support async") - class LlamaIndexGraphRetriever(BaseRetriever): """Retriever for question-answering with sources over an LlamaIndex @@ -88,8 +80,3 @@ class LlamaIndexGraphRetriever(BaseRetriever): Document(page_content=source_node.source_text, metadata=metadata) ) return docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError("LlamaIndexGraphRetriever does not support async") diff --git a/libs/langchain/langchain/retrievers/metal.py b/libs/langchain/langchain/retrievers/metal.py index 5271a897b54..b4faaeab3df 100644 --- a/libs/langchain/langchain/retrievers/metal.py +++ b/libs/langchain/langchain/retrievers/metal.py @@ -2,10 +2,7 @@ from typing import Any, List, Optional from pydantic import root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document @@ -43,8 +40,3 @@ class MetalRetriever(BaseRetriever): metadata = {k: v for k, v in r.items() if k != "text"} final_results.append(Document(page_content=r["text"], metadata=metadata)) return final_results - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/milvus.py b/libs/langchain/langchain/retrievers/milvus.py index 6541ce441f7..bc35e731842 100644 --- a/libs/langchain/langchain/retrievers/milvus.py +++ b/libs/langchain/langchain/retrievers/milvus.py @@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document from langchain.vectorstores.milvus import Milvus @@ -63,15 +60,6 @@ class MilvusRetriever(BaseRetriever): query, run_manager=run_manager.get_child(), **kwargs ) - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, - ) -> List[Document]: - raise NotImplementedError - def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever: """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead. 
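[Editor's illustration] The retriever hunks above delete the per-class `_aget_relevant_documents` stubs because the base class now supplies a default and the async entry points detect a missing override at runtime, via the `type(self).method == BaseClass.method` comparison used by `LLM._agenerate` earlier in this patch and by `BaseRetriever.ainvoke` later in it. A self-contained sketch of that dispatch pattern follows; the class and method names below are illustrative only.

    import asyncio
    from functools import partial

    class BaseModel:
        def _call(self, prompt: str) -> str:
            raise NotImplementedError

        async def _acall(self, prompt: str) -> str:
            raise NotImplementedError

        async def agenerate(self, prompt: str) -> str:
            # Compare the class attribute against the base class attribute to
            # detect whether the subclass provided a native async implementation.
            if type(self)._acall == BaseModel._acall:
                # No async override: run the sync path in the default executor.
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, partial(self._call, prompt))
            return await self._acall(prompt)

    class SyncOnlyModel(BaseModel):
        def _call(self, prompt: str) -> str:
            return prompt.upper()

    print(asyncio.run(SyncOnlyModel().agenerate("hello")))  # prints "HELLO"
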
diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index ca7d6eb3eb5..4d9520c0262 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -3,10 +3,7 @@ from typing import List from pydantic import BaseModel, Field -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chains.llm import LLMChain from langchain.llms.base import BaseLLM from langchain.output_parsers.pydantic import PydanticOutputParser @@ -101,14 +98,6 @@ class MultiQueryRetriever(BaseRetriever): unique_documents = self.unique_union(documents) return unique_documents - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - ) -> List[Document]: - raise NotImplementedError - def generate_queries( self, question: str, run_manager: CallbackManagerForRetrieverRun ) -> List[str]: diff --git a/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py b/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py index a6c998719b9..97d562421d3 100644 --- a/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py +++ b/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py @@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document @@ -175,8 +172,3 @@ class PineconeHybridSearchRetriever(BaseRetriever): ) # return search results as json return final_result - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/pubmed.py b/libs/langchain/langchain/retrievers/pubmed.py index d49d581800a..d093bf23cb8 100644 --- a/libs/langchain/langchain/retrievers/pubmed.py +++ b/libs/langchain/langchain/retrievers/pubmed.py @@ -1,9 +1,6 @@ from typing import List -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document from langchain.utilities.pupmed import PubMedAPIWrapper @@ -19,8 +16,3 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper): self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load_docs(query=query) - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 14f9d6f4e0b..15463751b4f 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional, Type, cast from pydantic import BaseModel, Field, root_validator from langchain import LLMChain -from langchain.callbacks.manager import ( - 
AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chains.query_constructor.base import load_query_constructor_chain from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import AttributeInfo @@ -119,11 +116,6 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs) return docs - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError - @classmethod def from_llm( cls, diff --git a/libs/langchain/langchain/retrievers/svm.py b/libs/langchain/langchain/retrievers/svm.py index 96e34f160f5..3c65e974ebf 100644 --- a/libs/langchain/langchain/retrievers/svm.py +++ b/libs/langchain/langchain/retrievers/svm.py @@ -5,10 +5,7 @@ from typing import Any, Iterable, List, Optional import numpy as np -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document @@ -113,8 +110,3 @@ class SVMRetriever(BaseRetriever): ): top_k_results.append(Document(page_content=self.texts[row - 1])) return top_k_results - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/tfidf.py b/libs/langchain/langchain/retrievers/tfidf.py index 6be34df2606..1d910f18ecf 100644 --- a/libs/langchain/langchain/retrievers/tfidf.py +++ b/libs/langchain/langchain/retrievers/tfidf.py @@ -2,10 +2,7 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document @@ -79,8 +76,3 @@ class TFIDFRetriever(BaseRetriever): ) # Op -- (n_docs,1) -- Cosine Sim with each doc return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]] return return_docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index dd767e53b7e..b2aebfa913b 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple from pydantic import Field -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document from langchain.vectorstores.base import VectorStore @@ -109,12 +106,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): result.append(buffered_doc) return result - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> 
List[Document]: - """Return documents that are relevant to the query.""" - raise NotImplementedError - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Add documents to vectorstore.""" current_time = kwargs.get("current_time") diff --git a/libs/langchain/langchain/retrievers/vespa_retriever.py b/libs/langchain/langchain/retrievers/vespa_retriever.py index 50e0396a5eb..5580172efbc 100644 --- a/libs/langchain/langchain/retrievers/vespa_retriever.py +++ b/libs/langchain/langchain/retrievers/vespa_retriever.py @@ -3,10 +3,7 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document if TYPE_CHECKING: @@ -57,11 +54,6 @@ class VespaRetriever(BaseRetriever): body["query"] = query return self._query(body) - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError - def get_relevant_documents_with_filter( self, query: str, *, _filter: Optional[str] = None ) -> List[Document]: diff --git a/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py b/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py index 1df9e153473..2bd64ed3f17 100644 --- a/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py +++ b/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py @@ -5,10 +5,7 @@ from uuid import uuid4 from pydantic import root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document from langchain.schema import BaseRetriever @@ -118,8 +115,3 @@ class WeaviateHybridSearchRetriever(BaseRetriever): text = res.pop(self.text_key) docs.append(Document(page_content=text, metadata=res)) return docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/wikipedia.py b/libs/langchain/langchain/retrievers/wikipedia.py index d47775878ac..6b13f0d2ddc 100644 --- a/libs/langchain/langchain/retrievers/wikipedia.py +++ b/libs/langchain/langchain/retrievers/wikipedia.py @@ -1,9 +1,6 @@ from typing import List -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.schema import BaseRetriever, Document from langchain.utilities.wikipedia import WikipediaAPIWrapper @@ -19,8 +16,3 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper): self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load(query=query) - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - raise NotImplementedError diff --git a/libs/langchain/langchain/retrievers/zilliz.py b/libs/langchain/langchain/retrievers/zilliz.py index e40366d161e..e023bac7714 100644 --- a/libs/langchain/langchain/retrievers/zilliz.py +++ b/libs/langchain/langchain/retrievers/zilliz.py 
@@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import root_validator -from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document from langchain.vectorstores.zilliz import Zilliz @@ -67,15 +64,6 @@ class ZillizRetriever(BaseRetriever): query, run_manager=run_manager.get_child(), **kwargs ) - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, - ) -> List[Document]: - raise NotImplementedError - def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever: """Deprecated ZillizRetreiver. diff --git a/libs/langchain/langchain/schema/__init__.py b/libs/langchain/langchain/schema/__init__.py index 1e660248211..818ff4113d7 100644 --- a/libs/langchain/langchain/schema/__init__.py +++ b/libs/langchain/langchain/schema/__init__.py @@ -1,6 +1,5 @@ from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.document import BaseDocumentTransformer, Document -from langchain.schema.language_model import BaseLanguageModel from langchain.schema.memory import BaseChatMessageHistory, BaseMemory from langchain.schema.messages import ( AIMessage, @@ -67,6 +66,5 @@ __all__ = [ "BaseOutputParser", "BaseLLMOutputParser", "BasePromptTemplate", - "BaseLanguageModel", "format_document", ] diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py index 19b4de1ef73..6a46165e43f 100644 --- a/libs/langchain/langchain/schema/language_model.py +++ b/libs/langchain/langchain/schema/language_model.py @@ -1,12 +1,22 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set +from typing import ( + TYPE_CHECKING, + Any, + List, + Optional, + Sequence, + Set, + TypeVar, + Union, +) from langchain.load.serializable import Serializable from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.output import LLMResult from langchain.schema.prompt import PromptValue +from langchain.schema.runnable import Runnable from langchain.utils import get_pydantic_field_names if TYPE_CHECKING: @@ -32,7 +42,13 @@ def _get_token_ids_default_method(text: str) -> List[int]: return tokenizer.encode(text) -class BaseLanguageModel(Serializable, ABC): +LanguageModelInput = Union[PromptValue, str, List[BaseMessage]] +LanguageModelOutput = TypeVar("LanguageModelOutput") + + +class BaseLanguageModel( + Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC +): """Abstract base class for interfacing with language models. All language model wrappers inherit from BaseLanguageModel. 
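[Editor's illustration] The `language_model.py` hunk above fixes the input side of `BaseLanguageModel` to a union (`PromptValue`, `str`, or a list of `BaseMessage`) while leaving the output side as a `TypeVar` for subclasses to bind. A toy sketch of that parametrization is below; the stand-in classes and the concrete output bindings (str for text-completion models, a message for chat models) are assumptions for illustration and are not shown in this hunk.

    from typing import Generic, List, TypeVar, Union

    # Stand-ins for the real PromptValue / BaseMessage in langchain.schema.
    class PromptValue: ...
    class BaseMessage: ...

    LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
    LanguageModelOutput = TypeVar("LanguageModelOutput")

    class ToyLanguageModel(Generic[LanguageModelOutput]):
        # Accepts any of the three input shapes; returns whatever the subclass binds.
        def invoke(self, input: LanguageModelInput) -> LanguageModelOutput:
            raise NotImplementedError

    class ToyLLM(ToyLanguageModel[str]):
        def invoke(self, input: LanguageModelInput) -> str:
            return "text completion"

    class ToyChatModel(ToyLanguageModel[BaseMessage]):
        def invoke(self, input: LanguageModelInput) -> BaseMessage:
            return BaseMessage()
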
diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index a0cbf978d24..8895ed7e25e 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import List, Sequence +from typing import Any, Dict, List, Sequence from pydantic import Field @@ -78,6 +78,49 @@ class BaseMessage(Serializable): return True +class BaseMessageChunk(BaseMessage): + def _merge_kwargs_dict( + self, left: Dict[str, Any], right: Dict[str, Any] + ) -> Dict[str, Any]: + """Merge additional_kwargs from another BaseMessageChunk into this one.""" + merged = left.copy() + for k, v in right.items(): + if k not in merged: + merged[k] = v + elif type(merged[k]) != type(v): + raise ValueError( + f'additional_kwargs["{k}"] already exists in this message,' + " but with a different type." + ) + elif isinstance(merged[k], str): + merged[k] += v + elif isinstance(merged[k], dict): + merged[k] = self._merge_kwargs_dict(merged[k], v) + else: + raise ValueError( + f"Additional kwargs key {k} already exists in this message." + ) + return merged + + def __add__(self, other: Any) -> BaseMessageChunk: + if isinstance(other, BaseMessageChunk): + # If both are (subclasses of) BaseMessageChunk, + # concat into a single BaseMessageChunk + + return self.__class__( + content=self.content + other.content, + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + else: + raise TypeError( + 'unsupported operand type(s) for +: "' + f"{self.__class__.__name__}" + f'" and "{other.__class__.__name__}"' + ) + + class HumanMessage(BaseMessage): """A Message from a human.""" @@ -92,6 +135,10 @@ class HumanMessage(BaseMessage): return "human" +class HumanMessageChunk(HumanMessage, BaseMessageChunk): + pass + + class AIMessage(BaseMessage): """A Message from an AI.""" @@ -106,6 +153,10 @@ class AIMessage(BaseMessage): return "ai" +class AIMessageChunk(AIMessage, BaseMessageChunk): + pass + + class SystemMessage(BaseMessage): """A Message for priming AI behavior, usually passed in as the first of a sequence of input messages. @@ -117,6 +168,10 @@ class SystemMessage(BaseMessage): return "system" +class SystemMessageChunk(SystemMessage, BaseMessageChunk): + pass + + class FunctionMessage(BaseMessage): """A Message for passing the result of executing a function back to a model.""" @@ -129,6 +184,10 @@ class FunctionMessage(BaseMessage): return "function" +class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): + pass + + class ChatMessage(BaseMessage): """A Message that can be assigned an arbitrary speaker (i.e. 
role).""" @@ -141,6 +200,10 @@ class ChatMessage(BaseMessage): return "chat" +class ChatMessageChunk(ChatMessage, BaseMessageChunk): + pass + + def _message_to_dict(message: BaseMessage) -> dict: return {"type": message.type, "data": message.dict()} diff --git a/libs/langchain/langchain/schema/output.py b/libs/langchain/langchain/schema/output.py index c085a495249..06d222ce889 100644 --- a/libs/langchain/langchain/schema/output.py +++ b/libs/langchain/langchain/schema/output.py @@ -7,7 +7,7 @@ from uuid import UUID from pydantic import BaseModel, root_validator from langchain.load.serializable import Serializable -from langchain.schema.messages import BaseMessage +from langchain.schema.messages import BaseMessage, BaseMessageChunk class Generation(Serializable): @@ -28,6 +28,24 @@ class Generation(Serializable): return True +class GenerationChunk(Generation): + def __add__(self, other: GenerationChunk) -> GenerationChunk: + if isinstance(other, GenerationChunk): + generation_info = ( + {**(self.generation_info or {}), **(other.generation_info or {})} + if self.generation_info is not None or other.generation_info is not None + else None + ) + return GenerationChunk( + text=self.text + other.text, + generation_info=generation_info, + ) + else: + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + class ChatGeneration(Generation): """A single chat generation output.""" @@ -43,6 +61,26 @@ class ChatGeneration(Generation): return values +class ChatGenerationChunk(ChatGeneration): + message: BaseMessageChunk + + def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: + if isinstance(other, ChatGenerationChunk): + generation_info = ( + {**(self.generation_info or {}), **(other.generation_info or {})} + if self.generation_info is not None or other.generation_info is not None + else None + ) + return ChatGenerationChunk( + message=self.message + other.message, + generation_info=generation_info, + ) + else: + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + class RunInfo(BaseModel): """Class that contains metadata for a single execution of a Chain or model.""" diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index be76baf86f6..2bc1cd5d72d 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -1,16 +1,18 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from langchain.load.serializable import Serializable -from langchain.schema.output import Generation +from langchain.schema.messages import BaseMessage +from langchain.schema.output import ChatGeneration, Generation from langchain.schema.prompt import PromptValue +from langchain.schema.runnable import Runnable, RunnableConfig T = TypeVar("T") -class BaseLLMOutputParser(Serializable, ABC, Generic[T]): +class BaseLLMOutputParser(Serializable, Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod @@ -26,7 +28,19 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]): """ -class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): +class BaseGenerationOutputParser( + BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] +): + def invoke( + self, input: str | BaseMessage, config: RunnableConfig | None = None 
+ ) -> T: + if isinstance(input, BaseMessage): + return self.parse_result([ChatGeneration(message=input)]) + else: + return self.parse_result([Generation(text=input)]) + + +class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]): """Base class to parse the output of an LLM call. Output parsers help structure language model responses. @@ -53,6 +67,14 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): return "boolean_output_parser" """ # noqa: E501 + def invoke( + self, input: str | BaseMessage, config: RunnableConfig | None = None + ) -> T: + if isinstance(input, BaseMessage): + return self.parse_result([ChatGeneration(message=input)]) + else: + return self.parse_result([Generation(text=input)]) + def parse_result(self, result: List[Generation]) -> T: """Parse a list of candidate model Generations into a specific format. diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 6f11d39b86d..b480ecc9464 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -12,9 +12,10 @@ from langchain.load.serializable import Serializable from langchain.schema.document import Document from langchain.schema.output_parser import BaseOutputParser from langchain.schema.prompt import PromptValue +from langchain.schema.runnable import Runnable, RunnableConfig -class BasePromptTemplate(Serializable, ABC): +class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC): """Base class for all prompt templates, returning a prompt.""" input_variables: List[str] @@ -34,6 +35,11 @@ class BasePromptTemplate(Serializable, ABC): arbitrary_types_allowed = True + def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue: + return self._call_with_config( + lambda inner_input: self.format_prompt(**inner_input), input, config + ) + @abstractmethod def format_prompt(self, **kwargs: Any) -> PromptValue: """Create Chat Messages.""" diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index b25ef0e692b..9df3e7a1389 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.schema.document import Document +from langchain.schema.runnable import Runnable, RunnableConfig if TYPE_CHECKING: from langchain.callbacks.manager import ( @@ -17,7 +18,7 @@ if TYPE_CHECKING: ) -class BaseRetriever(Serializable, ABC): +class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): """Abstract base class for a Document retrieval system. 
A retrieval system is defined as something that can take string queries and return @@ -43,9 +44,6 @@ class BaseRetriever(Serializable, ABC): # Op -- (n_docs,1) -- Cosine Sim with each doc results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] - - async def aget_relevant_documents(self, query: str) -> List[Document]: - raise NotImplementedError """ # noqa: E501 class Config: @@ -106,6 +104,20 @@ class BaseRetriever(Serializable, ABC): len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 ) + def invoke( + self, input: str, config: Optional[RunnableConfig] = None + ) -> List[Document]: + return self.get_relevant_documents(input, **(config or {})) + + async def ainvoke( + self, input: str, config: Optional[RunnableConfig] = None + ) -> List[Document]: + if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents: + # If the retriever doesn't implement async, use default implementation + return await super().ainvoke(input, config) + + return await self.aget_relevant_documents(input, **(config or {})) + @abstractmethod def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -118,7 +130,6 @@ class BaseRetriever(Serializable, ABC): List of relevant documents """ - @abstractmethod async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: @@ -129,6 +140,7 @@ class BaseRetriever(Serializable, ABC): Returns: List of relevant documents """ + raise NotImplementedError() def get_relevant_documents( self, diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py new file mode 100644 index 00000000000..f0483529706 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable.py @@ -0,0 +1,705 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + Generic, + Iterator, + List, + Optional, + TypedDict, + TypeVar, + Union, + cast, +) + +from pydantic import Field + +from langchain.callbacks.base import BaseCallbackManager, Callbacks +from langchain.load.dump import dumpd +from langchain.load.serializable import Serializable + + +async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: + async with semaphore: + return await coro + + +async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list: + if n is None: + return await asyncio.gather(*coros) + + semaphore = asyncio.Semaphore(n) + + return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros)) + + +class RunnableConfig(TypedDict, total=False): + tags: List[str] + """ + Tags for this call and any sub-calls (eg. a Chain calling an LLM). + You can use these to filter calls. + """ + + metadata: Dict[str, Any] + """ + Metadata for this call and any sub-calls (eg. a Chain calling an LLM). + Keys should be strings, values should be JSON-serializable. + """ + + callbacks: Callbacks + """ + Callbacks for this call and any sub-calls (eg. a Chain calling an LLM). + Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. 
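+
+    For example (an illustrative sketch: ``my_runnable`` stands for any object
+    implementing ``Runnable``, and ``StdOutCallbackHandler`` is one of the
+    callback handlers exported from ``langchain.callbacks``)::
+
+        from langchain.callbacks import StdOutCallbackHandler
+
+        # my_runnable: placeholder for any Runnable, e.g. a prompt | model sequence
+        config: RunnableConfig = {
+            "tags": ["my-chain"],
+            "metadata": {"user_id": "abc123"},
+            "callbacks": [StdOutCallbackHandler()],
+        }
+        my_runnable.invoke("hello", config=config)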
+ """ + + +Input = TypeVar("Input") +# Output type should implement __concat__, as eg str, list, dict do +Output = TypeVar("Output") +Other = TypeVar("Other") + + +class Runnable(Generic[Input, Output], ABC): + def __or__( + self, + other: Union[ + Runnable[Any, Other], + Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], + ], + ) -> RunnableSequence[Input, Other]: + return RunnableSequence(first=self, last=_coerce_to_runnable(other)) + + def __ror__( + self, + other: Union[ + Runnable[Other, Any], + Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], + ], + ) -> RunnableSequence[Other, Output]: + return RunnableSequence(first=_coerce_to_runnable(other), last=self) + + @abstractmethod + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + ... + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + return await asyncio.get_running_loop().run_in_executor( + None, self.invoke, input, config + ) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + configs = self._get_config_list(config, len(inputs)) + + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: + return list(executor.map(self.invoke, inputs, configs)) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + configs = self._get_config_list(config, len(inputs)) + coros = map(self.ainvoke, inputs, configs) + + return await _gather_with_concurrency(max_concurrency, *coros) + + def stream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Iterator[Output]: + yield self.invoke(input, config) + + async def astream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> AsyncIterator[Output]: + yield await self.ainvoke(input, config) + + def _get_config_list( + self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int + ) -> List[RunnableConfig]: + if isinstance(config, list) and len(config) != length: + raise ValueError( + f"config must be a list of the same length as inputs, " + f"but got {len(config)} configs for {length} inputs" + ) + + return ( + config + if isinstance(config, list) + else [config.copy() if config is not None else {} for _ in range(length)] + ) + + def _call_with_config( + self, + func: Callable[[Input], Output], + input: Input, + config: Optional[RunnableConfig], + ) -> Output: + from langchain.callbacks.manager import CallbackManager + + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + run_manager = callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + try: + output = func(input) + except Exception as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + return output + + +class RunnableSequence(Serializable, Runnable[Input, Output]): + first: Runnable[Input, Any] + middle: List[Runnable[Any, Any]] = Field(default_factory=list) + last: Runnable[Any, Output] + + @property + def steps(self) -> List[Runnable[Any, Any]]: + return [self.first] + self.middle + [self.last] + + @property 
+ def lc_serializable(self) -> bool: + return True + + class Config: + arbitrary_types_allowed = True + + def __or__( + self, + other: Union[ + Runnable[Any, Other], + Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], + ], + ) -> RunnableSequence[Input, Other]: + if isinstance(other, RunnableSequence): + return RunnableSequence( + first=self.first, + middle=self.middle + [self.last] + other.middle, + last=other.last, + ) + else: + return RunnableSequence( + first=self.first, + middle=self.middle + [self.last], + last=_coerce_to_runnable(other), + ) + + def __ror__( + self, + other: Union[ + Runnable[Other, Any], + Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], + ], + ) -> RunnableSequence[Other, Output]: + if isinstance(other, RunnableSequence): + return RunnableSequence( + first=other.first, + middle=other.middle + [other.last] + self.middle, + last=self.last, + ) + else: + return RunnableSequence( + first=_coerce_to_runnable(other), + middle=[self.first] + self.middle, + last=self.last, + ) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + + # invoke all steps in sequence + try: + for step in self.steps: + input = step.invoke( + input, + # mark each step as a child run + _patch_config(config, run_manager.get_child()), + ) + # finish the root run + except (KeyboardInterrupt, Exception) as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end( + input if isinstance(input, dict) else {"output": input} + ) + return cast(Output, input) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + + # setup callbacks + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + + # invoke all steps in sequence + try: + for step in self.steps: + input = await step.ainvoke( + input, + # mark each step as a child run + _patch_config(config, run_manager.get_child()), + ) + # finish the root run + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end( + input if isinstance(input, dict) else {"output": input} + ) + return cast(Output, input) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + configs = self._get_config_list(config, len(inputs)) + callback_managers = [ + CallbackManager.configure( + 
inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers = [ + cm.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + for cm, input in zip(callback_managers, inputs) + ] + + # invoke + try: + for step in self.steps: + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + _patch_config(config, rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + max_concurrency=max_concurrency, + ) + # finish the root runs + except (KeyboardInterrupt, Exception) as e: + for rm in run_managers: + rm.on_chain_error(e) + raise + else: + for rm, input in zip(run_managers, inputs): + rm.on_chain_end(input if isinstance(input, dict) else {"output": input}) + return cast(List[Output], inputs) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + ) + + # setup callbacks + configs = self._get_config_list(config, len(inputs)) + callback_managers = [ + AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + cm.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + for cm, input in zip(callback_managers, inputs) + ) + ) + + # invoke .batch() on each step + # this uses batching optimizations in Runnable subclasses, like LLM + try: + for step in self.steps: + inputs = await step.abatch( + inputs, + [ + # each step a child run of the corresponding root run + _patch_config(config, rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + max_concurrency=max_concurrency, + ) + # finish the root runs + except (KeyboardInterrupt, Exception) as e: + await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) + raise + else: + await asyncio.gather( + *( + rm.on_chain_end( + input if isinstance(input, dict) else {"output": input} + ) + for rm, input in zip(run_managers, inputs) + ) + ) + return cast(List[Output], inputs) + + def stream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Iterator[Output]: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + + # invoke the first steps + try: + for step in [self.first] + self.middle: + input = step.invoke( + input, + # mark each step as a child run + _patch_config(config, run_manager.get_child()), + ) + except (KeyboardInterrupt, 
Exception) as e: + run_manager.on_chain_error(e) + raise + + # stream the last step + final: Union[Output, None] = None + final_supported = True + try: + for output in self.last.stream( + input, + # mark the last step as a child run + _patch_config(config, run_manager.get_child()), + ): + yield output + # Accumulate output if possible, otherwise disable accumulation + if final_supported: + if final is None: + final = output + else: + try: + final += output # type: ignore[operator] + except TypeError: + final = None + final_supported = False + pass + # finish the root run + except (KeyboardInterrupt, Exception) as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end( + final if isinstance(final, dict) else {"output": final} + ) + + async def astream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> AsyncIterator[Output]: + from langchain.callbacks.manager import AsyncCallbackManager + + # setup callbacks + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + + # invoke the first steps + try: + for step in [self.first] + self.middle: + input = await step.ainvoke( + input, + # mark each step as a child run + _patch_config(config, run_manager.get_child()), + ) + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise + + # stream the last step + final: Union[Output, None] = None + final_supported = True + try: + async for output in self.last.astream( + input, + # mark the last step as a child run + _patch_config(config, run_manager.get_child()), + ): + yield output + # Accumulate output if possible, otherwise disable accumulation + if final_supported: + if final is None: + final = output + else: + try: + final += output # type: ignore[operator] + except TypeError: + final = None + final_supported = False + pass + # finish the root run + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end( + final if isinstance(final, dict) else {"output": final} + ) + + +class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): + steps: Dict[str, Runnable[Input, Any]] + + @property + def lc_serializable(self) -> bool: + return True + + class Config: + arbitrary_types_allowed = True + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input}) + + # gather results from all steps + try: + # copy to avoid issues from the caller mutating the steps during invoke() + steps = self.steps.copy() + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + step.invoke, + input, + # mark each step as a child run + _patch_config(config, 
run_manager.get_child()), + ) + for step in steps.values() + ] + output = {key: future.result() for key, future in zip(steps, futures)} + # finish the root run + except (KeyboardInterrupt, Exception) as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end(output) + return output + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + from langchain.callbacks.manager import AsyncCallbackManager + + # setup callbacks + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), {"input": input} + ) + + # gather results from all steps + try: + # copy to avoid issues from the caller mutating the steps during invoke() + steps = self.steps.copy() + results = await asyncio.gather( + *( + step.ainvoke( + input, + # mark each step as a child run + _patch_config(config, run_manager.get_child()), + ) + for step in steps.values() + ) + ) + output = {key: value for key, value in zip(steps, results)} + # finish the root run + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end(output) + return output + + +class RunnableLambda(Runnable[Input, Output]): + def __init__(self, func: Callable[[Input], Output]) -> None: + if callable(func): + self.func = func + else: + raise TypeError( + "Expected a callable type for `func`." + f"Instead got an unsupported type: {type(func)}" + ) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableLambda): + return self.func == other.func + else: + return False + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + return self._call_with_config(self.func, input, config) + + +class RunnablePassthrough(Serializable, Runnable[Input, Input]): + @property + def lc_serializable(self) -> bool: + return True + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + return self._call_with_config(lambda x: x, input, config) + + +def _patch_config( + config: RunnableConfig, callback_manager: BaseCallbackManager +) -> RunnableConfig: + config = config.copy() + config["callbacks"] = callback_manager + return config + + +def _coerce_to_runnable( + thing: Union[ + Runnable[Input, Output], + Callable[[Input], Output], + Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]], + ] +) -> Runnable[Input, Output]: + if isinstance(thing, Runnable): + return thing + elif callable(thing): + return RunnableLambda(thing) + elif isinstance(thing, dict): + runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()} + return cast(Runnable[Input, Output], RunnableMap(steps=runnables)) + else: + raise TypeError( + f"Expected a Runnable, callable or dict." 
+ f"Instead got an unsupported type: {type(thing)}" + ) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py index af3157ed6d4..ff7332c88cb 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py @@ -177,3 +177,68 @@ def test_chat_openai_extra_kwargs() -> None: # Test that "model" cannot be specified in kwargs with pytest.raises(ValueError): ChatOpenAI(model_kwargs={"model": "text-davinci-003"}) + + +def test_openai_streaming() -> None: + """Test streaming tokens from OpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + for token in llm.stream("I'm Pickle Rick"): + assert isinstance(token.content, str) + + +@pytest.mark.asyncio +async def test_openai_astream() -> None: + """Test streaming tokens from OpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + async for token in llm.astream("I'm Pickle Rick"): + assert isinstance(token.content, str) + + +@pytest.mark.asyncio +async def test_openai_abatch() -> None: + """Test streaming tokens from ChatOpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token.content, str) + + +@pytest.mark.asyncio +async def test_openai_abatch_tags() -> None: + """Test batch tokens from ChatOpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + result = await llm.abatch( + ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} + ) + for token in result: + assert isinstance(token.content, str) + + +def test_openai_batch() -> None: + """Test batch tokens from ChatOpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token.content, str) + + +@pytest.mark.asyncio +async def test_openai_ainvoke() -> None: + """Test invoke tokens from ChatOpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) + assert isinstance(result.content, str) + + +def test_openai_invoke() -> None: + """Test invoke tokens from ChatOpenAI.""" + llm = ChatOpenAI(max_tokens=10) + + result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + assert isinstance(result.content, str) diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py index 5281b7563a4..0844faa6aa3 100644 --- a/libs/langchain/tests/integration_tests/llms/test_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_openai.py @@ -93,7 +93,64 @@ def test_openai_streaming() -> None: assert isinstance(generator, Generator) for token in generator: - assert isinstance(token["choices"][0]["text"], str) + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_openai_astream() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + async for token in llm.astream("I'm Pickle Rick"): + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_openai_abatch() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_openai_abatch_tags() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + result = await llm.abatch( 
+ ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} + ) + for token in result: + assert isinstance(token, str) + + +def test_openai_batch() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_openai_ainvoke() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) + assert isinstance(result, str) + + +def test_openai_invoke() -> None: + """Test streaming tokens from OpenAI.""" + llm = OpenAI(max_tokens=10) + + result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + assert isinstance(result, str) def test_openai_multiple_prompts() -> None: @@ -105,13 +162,6 @@ def test_openai_multiple_prompts() -> None: assert len(output.generations) == 2 -def test_openai_streaming_error() -> None: - """Test error handling in stream.""" - llm = OpenAI(best_of=2) - with pytest.raises(ValueError): - llm.stream("I'm Pickle Rick") - - def test_openai_streaming_best_of_error() -> None: """Test validation for streaming fails if best_of is not 1.""" with pytest.raises(ValueError): diff --git a/libs/langchain/tests/integration_tests/llms/test_promptlayer_openai.py b/libs/langchain/tests/integration_tests/llms/test_promptlayer_openai.py index b054e321028..643d9952dde 100644 --- a/libs/langchain/tests/integration_tests/llms/test_promptlayer_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_promptlayer_openai.py @@ -67,10 +67,3 @@ def test_promptlayer_openai_streaming() -> None: for token in generator: assert isinstance(token["choices"][0]["text"], str) - - -def test_promptlayer_openai_streaming_error() -> None: - """Test error handling in stream.""" - llm = PromptLayerOpenAI(best_of=2) - with pytest.raises(ValueError): - llm.stream("I'm Pickle Rick") diff --git a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr new file mode 100644 index 00000000000..fc93e69297f --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr @@ -0,0 +1,668 @@ +# serializer version: 1 +# name: test_prompt_with_chat_model + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string" + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string" + } + } + } + } + ] + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + 
"chat_models", + "fake", + "FakeListChatModel" + ] + } + } + } + ''' +# --- +# name: test_prompt_with_chat_model.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo', additional_kwargs={}, example=False)}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 
'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + ]) +# --- +# name: test_prompt_with_chat_model.2 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='bar', additional_kwargs={}, example=False)}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice 
assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'bar', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': AIMessage(content='foo', additional_kwargs={}, example=False)}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, 
{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your favorite color?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000004'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000008'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000004'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + ]) +# --- +# name: test_prompt_with_chat_model_and_parser + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string" + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string" + } + } + } + } + ] + } + }, + "middle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "chat_models", + "fake", + "FakeListChatModel" + ] + } + ], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "output_parsers", + "list", + "CommaSeparatedListOutputParser" + ], + "kwargs": {} + } + } + } + ''' +# --- +# name: test_prompt_with_chat_model_and_parser.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', 
start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 
'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + ]) +# --- +# name: test_prompt_with_llm + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string" + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string" + } + } + } + } + ] + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + } + } + ''' +# --- +# name: test_prompt_with_llm.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, 
end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + ]) +# --- +# name: test_prompt_with_llm.2 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 
1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'bar'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'bar', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 
'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your favorite color?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000004'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000008'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000004'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), + ]) +# --- +# name: test_seq_dict_prompt_llm + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "question": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + 
"langchain", + "schema", + "runnable", + "RunnablePassthrough" + ], + "kwargs": {} + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableLambda" + ] + } + } + }, + "documents": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableLambda" + ] + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "test_runnable", + "FakeRetriever" + ] + } + } + }, + "just_to_test_lambda": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableLambda" + ] + } + } + } + }, + "middle": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string" + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "documents", + "question" + ], + "template": "Context:\n{documents}\n\nQuestion:\n{question}", + "template_format": "f-string" + } + } + } + } + ] + } + }, + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "chat_models", + "fake", + "FakeListChatModel" + ] + } + ], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "output_parsers", + "list", + "CommaSeparatedListOutputParser" + ], + "kwargs": {} + } + } + } + ''' +# --- +# name: test_seq_dict_prompt_llm.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'question': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnablePassthrough'], 'kwargs': {}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}, 'documents': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['test_runnable', 'FakeRetriever']}}}, 'just_to_test_lambda': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}}, 'middle': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': 
['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['documents', 'question'], 'template': 'Context:\n{documents}\n\nQuestion:\n{question}', 'template_format': 'f-string'}}}}]}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=11, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='RunnableMap', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'question': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnablePassthrough'], 'kwargs': {}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}, 'documents': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['test_runnable', 'FakeRetriever']}}}, 'just_to_test_lambda': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'question': 'What is your name?', 'documents': [Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})], 'just_to_test_lambda': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=9, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnablePassthrough'], 'kwargs': {}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, 
outputs={'output': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000001'), tags=[], execution_order=3, child_execution_order=5, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnablePassthrough', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnablePassthrough'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000002'), tags=[], execution_order=4, child_execution_order=4, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000002'), tags=[], execution_order=5, child_execution_order=5, child_runs=[])]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['test_runnable', 'FakeRetriever']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': [Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000001'), tags=[], execution_order=6, child_execution_order=8, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000005'), tags=[], execution_order=7, child_execution_order=7, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='Retriever', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['test_runnable', 'FakeRetriever']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], 
inputs={'query': 'What is your name?'}, outputs={'documents': [Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000005'), tags=[], execution_order=8, child_execution_order=8, child_runs=[])]), Run(id=UUID('00000000-0000-4000-8000-000000000008'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'What is your name?'}, outputs={'output': 'What is your name?'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000001'), tags=[], execution_order=9, child_execution_order=9, child_runs=[])]), Run(id=UUID('00000000-0000-4000-8000-000000000009'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['documents', 'question'], 'template': 'Context:\n{documents}\n\nQuestion:\n{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?', 'documents': [Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})], 'just_to_test_lambda': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content="Context:\n[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]\n\nQuestion:\nWhat is your name?", additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=10, child_execution_order=10, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000010'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ["System: You are a nice assistant.\nHuman: Context:\n[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]\n\nQuestion:\nWhat is your name?"]}, 
outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=11, child_execution_order=11, child_runs=[])]), + ]) +# --- +# name: test_seq_prompt_dict + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string" + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string" + } + } + } + } + ] + } + }, + "middle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableLambda" + ] + } + ], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "chat": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "chat_models", + "fake", + "FakeListChatModel" + ] + }, + "llm": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + } + } + } + } + } + ''' +# --- +# name: test_seq_prompt_dict.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'chat': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 
'chat_models', 'fake', 'FakeListChatModel']}, 'llm': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'chat': AIMessage(content="i'm a chatbot", additional_kwargs={}, example=False), 'llm': "i'm a textbot"}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=6, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string'}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string'}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=3, child_execution_order=3, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableMap', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'chat': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, 'llm': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 
'FakeListLLM']}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your name?', additional_kwargs={}, example=False)])}, outputs={'chat': AIMessage(content="i'm a chatbot", additional_kwargs={}, example=False), 'llm': "i'm a textbot"}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=4, child_execution_order=6, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ["i'm a chatbot"], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': "i'm a chatbot", 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': "i'm a chatbot"}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=5, child_execution_order=5, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ["i'm a textbot"], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': "i'm a textbot", 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=6, child_execution_order=6, child_runs=[])])]), + ]) +# --- diff --git a/libs/langchain/tests/unit_tests/schema/test_messages.py b/libs/langchain/tests/unit_tests/schema/test_messages.py new file mode 100644 index 00000000000..25c1a2b072e --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/test_messages.py @@ -0,0 +1,38 @@ +from langchain.schema.messages import AIMessageChunk, HumanMessageChunk + + +def test_message_chunks() -> None: + assert AIMessageChunk(content="I am") + AIMessageChunk( + content=" indeed." + ) == AIMessageChunk( + content="I am indeed." + ), "MessageChunk + MessageChunk should be a MessageChunk" + + assert AIMessageChunk(content="I am") + HumanMessageChunk( + content=" indeed." + ) == AIMessageChunk( + content="I am indeed." 
+ ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501 + + assert AIMessageChunk( + content="", additional_kwargs={"foo": "bar"} + ) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk( + content="", additional_kwargs={"foo": "bar", "baz": "foo"} + ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 + + assert AIMessageChunk( + content="", additional_kwargs={"function_call": {"name": "web_search"}} + ) + AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": "{\n"}} + ) + AIMessageChunk( + content="", + additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}}, + ) == AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "web_search", + "arguments": '{\n "query": "turtles"\n}', + } + }, + ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 diff --git a/libs/langchain/tests/unit_tests/schema/test_output.py b/libs/langchain/tests/unit_tests/schema/test_output.py new file mode 100644 index 00000000000..2d19f9370f9 --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/test_output.py @@ -0,0 +1,52 @@ +from langchain.schema.messages import HumanMessageChunk +from langchain.schema.output import ChatGenerationChunk, GenerationChunk + + +def test_generation_chunk() -> None: + assert GenerationChunk(text="Hello, ") + GenerationChunk( + text="world!" + ) == GenerationChunk( + text="Hello, world!" + ), "GenerationChunk + GenerationChunk should be a GenerationChunk" + + assert GenerationChunk(text="Hello, ") + GenerationChunk( + text="world!", generation_info={"foo": "bar"} + ) == GenerationChunk( + text="Hello, world!", generation_info={"foo": "bar"} + ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 + + assert GenerationChunk(text="Hello, ") + GenerationChunk( + text="world!", generation_info={"foo": "bar"} + ) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk( + text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"} + ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 + + +def test_chat_generation_chunk() -> None: + assert ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, ") + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="world!") + ) == ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, world!") + ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk" + + assert ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, ") + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} + ) == ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, world!"), + generation_info={"foo": "bar"}, + ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 + + assert ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, ") + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"} + ) == ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, world!!"), + generation_info={"foo": "bar", "baz": "foo"}, + ), "GenerationChunk + GenerationChunk should be a GenerationChunk 
with merged generation_info" # noqa: E501 diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py new file mode 100644 index 00000000000..83881a6bdf7 --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -0,0 +1,547 @@ +from typing import Any, Dict, List, Optional +from uuid import UUID + +import pytest +from freezegun import freeze_time +from pytest_mock import MockerFixture +from syrupy import SnapshotAssertion + +from langchain.callbacks.manager import Callbacks +from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.schemas import Run +from langchain.chat_models.fake import FakeListChatModel +from langchain.llms.fake import FakeListLLM +from langchain.load.dump import dumps +from langchain.output_parsers.list import CommaSeparatedListOutputParser +from langchain.prompts.chat import ( + ChatPromptTemplate, + ChatPromptValue, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema.document import Document +from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage +from langchain.schema.retriever import BaseRetriever +from langchain.schema.runnable import ( + Runnable, + RunnableConfig, + RunnableLambda, + RunnableMap, + RunnablePassthrough, + RunnableSequence, +) + + +class FakeTracer(BaseTracer): + """Fake tracer that records LangChain execution.""" + + def __init__(self) -> None: + """Initialize the tracer.""" + super().__init__() + self.runs: List[Run] = [] + + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + self.runs.append(run) + + +class FakeRunnable(Runnable[str, int]): + def invoke( + self, + input: str, + config: Optional[RunnableConfig] = None, + ) -> int: + return len(input) + + +class FakeRetriever(BaseRetriever): + def _get_relevant_documents( + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + return [Document(page_content="foo"), Document(page_content="bar")] + + async def _aget_relevant_documents( + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + return [Document(page_content="foo"), Document(page_content="bar")] + + +@pytest.fixture() +def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher: + """Note this mock only works with `import uuid; uuid.uuid4()`, + it does not work with `from uuid import uuid4; uuid4()`.""" + + # Disable tracing to avoid fixed UUIDs causing tracing errors. 
+ mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"}) + + side_effect = ( + UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000) + ) + return mocker.patch("uuid.uuid4", side_effect=side_effect) + + +@pytest.mark.asyncio +async def test_default_method_implementations(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") + + assert fake.invoke("hello", dict(tags=["a-tag"])) == 5 + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + ] + spy.reset_mock() + + assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(metadata={"key": "value"})), + ] + spy.reset_mock() + + assert fake.batch( + ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ) == [5, 7] + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + mocker.call("wooorld", dict(metadata={"key": "value"})), + ] + spy.reset_mock() + + assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + mocker.call("wooorld", dict(tags=["a-tag"])), + ] + spy.reset_mock() + + assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 + assert spy.call_args_list == [ + mocker.call("hello", dict(callbacks=[])), + ] + spy.reset_mock() + + assert [part async for part in fake.astream("hello")] == [5] + assert spy.call_args_list == [ + mocker.call("hello", None), + ] + spy.reset_mock() + + assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [ + 5, + 7, + ] + assert spy.call_args_list == [ + mocker.call("hello", dict(metadata={"key": "value"})), + mocker.call("wooorld", dict(metadata={"key": "value"})), + ] + + +@pytest.mark.asyncio +async def test_prompt() -> None: + prompt = ChatPromptTemplate.from_messages( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessagePromptTemplate.from_template("{question}"), + ] + ) + expected = ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + + assert prompt.invoke({"question": "What is your name?"}) == expected + + assert prompt.batch( + [ + {"question": "What is your name?"}, + {"question": "What is your favorite color?"}, + ] + ) == [ + expected, + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your favorite color?"), + ] + ), + ] + + assert [*prompt.stream({"question": "What is your name?"})] == [expected] + + assert await prompt.ainvoke({"question": "What is your name?"}) == expected + + assert await prompt.abatch( + [ + {"question": "What is your name?"}, + {"question": "What is your favorite color?"}, + ] + ) == [ + expected, + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your favorite color?"), + ] + ), + ] + + assert [ + part async for part in prompt.astream({"question": "What is your name?"}) + ] == [expected] + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_prompt_with_chat_model( + mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None +) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + chat = FakeListChatModel(responses=["foo", "bar"]) + + chain = prompt | chat + + assert isinstance(chain, 
RunnableSequence) + assert chain.first == prompt + assert chain.middle == [] + assert chain.last == chat + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "invoke") + chat_spy = mocker.spy(chat.__class__, "invoke") + tracer = FakeTracer() + assert chain.invoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) == AIMessage(content="foo") + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert chat_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(chat_spy) + + # Test batch + prompt_spy = mocker.spy(prompt.__class__, "batch") + chat_spy = mocker.spy(chat.__class__, "batch") + tracer = FakeTracer() + assert chain.batch( + [ + {"question": "What is your name?"}, + {"question": "What is your favorite color?"}, + ], + dict(callbacks=[tracer]), + ) == [ + AIMessage(content="bar"), + AIMessage(content="foo"), + ] + assert prompt_spy.call_args.args[1] == [ + {"question": "What is your name?"}, + {"question": "What is your favorite color?"}, + ] + assert chat_spy.call_args.args[1] == [ + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ), + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your favorite color?"), + ] + ), + ] + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(chat_spy) + + # Test stream + prompt_spy = mocker.spy(prompt.__class__, "invoke") + chat_spy = mocker.spy(chat.__class__, "stream") + tracer = FakeTracer() + assert [ + *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) + ] == [AIMessage(content="bar")] + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert chat_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_prompt_with_llm( + mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None +) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeListLLM(responses=["foo", "bar"]) + + chain = prompt | llm + + assert isinstance(chain, RunnableSequence) + assert chain.first == prompt + assert chain.middle == [] + assert chain.last == llm + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "ainvoke") + llm_spy = mocker.spy(llm.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await chain.ainvoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) + == "foo" + ) + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert llm_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(llm_spy) + + # Test batch + prompt_spy = mocker.spy(prompt.__class__, "abatch") + llm_spy = mocker.spy(llm.__class__, "abatch") + tracer = FakeTracer() + assert await chain.abatch( + [ + {"question": "What is your name?"}, + {"question": "What is your 
favorite color?"}, + ], + dict(callbacks=[tracer]), + ) == ["bar", "foo"] + assert prompt_spy.call_args.args[1] == [ + {"question": "What is your name?"}, + {"question": "What is your favorite color?"}, + ] + assert llm_spy.call_args.args[1] == [ + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ), + ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your favorite color?"), + ] + ), + ] + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(llm_spy) + + # Test stream + prompt_spy = mocker.spy(prompt.__class__, "ainvoke") + llm_spy = mocker.spy(llm.__class__, "astream") + tracer = FakeTracer() + assert [ + token + async for token in chain.astream( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) + ] == ["bar"] + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert llm_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + + +@freeze_time("2023-01-01") +def test_prompt_with_chat_model_and_parser( + mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None +) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + chat = FakeListChatModel(responses=["foo, bar"]) + parser = CommaSeparatedListOutputParser() + + chain = prompt | chat | parser + + assert isinstance(chain, RunnableSequence) + assert chain.first == prompt + assert chain.middle == [chat] + assert chain.last == parser + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "invoke") + chat_spy = mocker.spy(chat.__class__, "invoke") + parser_spy = mocker.spy(parser.__class__, "invoke") + tracer = FakeTracer() + assert chain.invoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) == ["foo", "bar"] + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert chat_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar") + assert tracer.runs == snapshot + + +@freeze_time("2023-01-01") +def test_seq_dict_prompt_llm( + mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None +) -> None: + passthrough = mocker.Mock(side_effect=lambda x: x) + + retriever = FakeRetriever() + + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + """Context: +{documents} + +Question: +{question}""" + ) + + chat = FakeListChatModel(responses=["foo, bar"]) + + parser = CommaSeparatedListOutputParser() + + chain = ( + { + "question": RunnablePassthrough[str]() | passthrough, + "documents": passthrough | retriever, + "just_to_test_lambda": passthrough, + } + | prompt + | chat + | parser + ) + + assert isinstance(chain, RunnableSequence) + assert isinstance(chain.first, RunnableMap) + assert chain.middle == [prompt, chat] + assert chain.last == parser + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "invoke") + chat_spy = mocker.spy(chat.__class__, "invoke") + parser_spy = mocker.spy(parser.__class__, "invoke") + tracer = FakeTracer() + assert chain.invoke("What is your name?", 
dict(callbacks=[tracer])) == [ + "foo", + "bar", + ] + assert prompt_spy.call_args.args[1] == { + "documents": [Document(page_content="foo"), Document(page_content="bar")], + "question": "What is your name?", + "just_to_test_lambda": "What is your name?", + } + assert chat_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage( + content="""Context: +[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})] + +Question: +What is your name?""" + ), + ] + ) + assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar") + assert tracer.runs == snapshot + + +@freeze_time("2023-01-01") +def test_seq_prompt_dict( + mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None +) -> None: + passthrough = mocker.Mock(side_effect=lambda x: x) + + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + + chat = FakeListChatModel(responses=["i'm a chatbot"]) + + llm = FakeListLLM(responses=["i'm a textbot"]) + + chain = ( + prompt + | passthrough + | { # type: ignore + "chat": chat, + "llm": llm, + } + ) + + assert isinstance(chain, RunnableSequence) + assert chain.first == prompt + assert chain.middle == [RunnableLambda(passthrough)] + assert isinstance(chain.last, RunnableMap) + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "invoke") + chat_spy = mocker.spy(chat.__class__, "invoke") + llm_spy = mocker.spy(llm.__class__, "invoke") + tracer = FakeTracer() + assert chain.invoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) == { + "chat": AIMessage(content="i'm a chatbot"), + "llm": "i'm a textbot", + } + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert chat_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert llm_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert tracer.runs == snapshot
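
For reference, a minimal sketch of the composition pattern the tests above exercise. It uses only classes already imported in test_runnable.py, and the expected result mirrors the assertion in test_prompt_with_chat_model_and_parser; treat it as an illustration of the new Runnable protocol rather than part of the patch itself.

# Sketch: composing Runnables with `|`, assuming the patched
# langchain.schema.runnable package from this PR is importable.
from langchain.chat_models.fake import FakeListChatModel
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import SystemMessagePromptTemplate

# `+` on prompt templates merges them into a single ChatPromptTemplate.
prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()

# `|` chains Runnables into a RunnableSequence: prompt -> chat model -> parser.
chain = prompt | chat | parser

# Mirrors test_prompt_with_chat_model_and_parser: the parser splits the
# fake model's "foo, bar" response into a list.
assert chain.invoke({"question": "What is your name?"}) == ["foo", "bar"]

The same tests show that a plain dict placed in a `|` composition is coerced into a RunnableMap, whose entries are invoked in parallel over the same input and returned as a dict of results (see test_seq_dict_prompt_llm and test_seq_prompt_dict above).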