Runnable single protocol (#7800)

Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel,
Chain, Retriever, OutputParser

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported operand types
- [x] Implement support for dicts of Runnables: each value is called in
parallel on the same input and the results are returned as a dict under the
same keys (see the composition sketch after this list)
- [x] Implement `+` for merging prompts
- [x] Confirm operator precedence: ideally `+` binds tighter than `|` (it does
in Python, see
https://docs.python.org/3/reference/expressions.html#operator-precedence)
- [x] Add support for OpenAI functions, i.e. chat models must return
messages
- [x] Implement a BaseMessageChunk return type for BaseChatModel: a
subclass of BaseMessage whose `__add__` returns a BaseMessageChunk with the
string content concatenated (see the chunk sketch after this list)
- [x] Update the stream/astream implementations for LLMs and chat models to
use the new optional `_stream`/`_astream` methods, whose default base-class
implementations `raise NotImplementedError`; use
https://stackoverflow.com/a/59762827 to detect whether a subclass has
overridden them (see the override-detection sketch after this list)
- [x] Delete the IteratorCallbackHandler (keep the async one, since people
are using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str
or BaseMessage
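
As a rough illustration of the composition semantics in the checklist above (piping with `|`, fanning a value out over a dict of runnables, and `+` binding tighter than `|`), here is a minimal, self-contained sketch. `ToyRunnable` and everything in it are invented for this example; this is not the langchain `Runnable` implementation.

```python
# Minimal sketch of the composition semantics described above; ToyRunnable is
# a stand-in invented for illustration, not langchain's Runnable.
from typing import Any, Callable, Dict, Union


class ToyRunnable:
    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn

    def invoke(self, value: Any) -> Any:
        return self.fn(value)

    def __or__(
        self, other: Union["ToyRunnable", Dict[str, "ToyRunnable"]]
    ) -> "ToyRunnable":
        if isinstance(other, dict):
            # A dict on the right-hand side fans the piped value out to every
            # entry "in parallel" (sequentially here, for simplicity) and
            # collects the results under the same keys.
            branches = dict(other)
            return ToyRunnable(
                lambda value: {
                    key: r.invoke(self.invoke(value)) for key, r in branches.items()
                }
            )
        return ToyRunnable(lambda value: other.invoke(self.invoke(value)))


double = ToyRunnable(lambda x: x * 2)
stringify = ToyRunnable(lambda x: f"value={x}")
negate = ToyRunnable(lambda x: -x)

# Because + binds tighter than | in Python, an expression such as
# prompt + other_prompt | model groups as (prompt + other_prompt) | model.
chain = double | {"text": stringify, "negated": negate}
print(chain.invoke(3))  # {'text': 'value=6', 'negated': -6}
```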
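The BaseMessageChunk item can be pictured with a toy dataclass: adding two chunks concatenates their string content, so a stream of deltas folds into a single message. This is only a sketch of the behaviour described above, not the actual BaseMessageChunk code.

```python
# Toy sketch of chunk concatenation; not the real BaseMessageChunk.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyMessageChunk:
    content: str

    def __add__(self, other: "ToyMessageChunk") -> "ToyMessageChunk":
        # Concatenate the string content, mirroring how streamed deltas
        # are folded into one message.
        return ToyMessageChunk(content=self.content + other.content)


message: Optional[ToyMessageChunk] = None
for chunk in [ToyMessageChunk("Hel"), ToyMessageChunk("lo"), ToyMessageChunk("!")]:
    message = chunk if message is None else message + chunk
print(message)  # ToyMessageChunk(content='Hello!')
```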
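The override check behind the streaming fallback can be sketched as below: comparing the function object looked up on `type(self)` with the one defined on the base class reveals whether `_stream` was overridden, which mirrors the linked Stack Overflow approach. Class names here are illustrative only.

```python
# Sketch of the "is it overridden?" check used to fall back when a model
# does not implement streaming; class names are illustrative only.
from typing import Iterator


class FakeBaseModel:
    def _stream(self) -> Iterator[str]:
        raise NotImplementedError()

    def stream(self) -> Iterator[str]:
        if type(self)._stream is FakeBaseModel._stream:
            # Subclass did not override _stream: fall back to one-shot output.
            yield "full response"
        else:
            yield from self._stream()


class NonStreaming(FakeBaseModel):
    pass


class Streaming(FakeBaseModel):
    def _stream(self) -> Iterator[str]:
        yield from ["tok1", "tok2"]


print(list(NonStreaming().stream()))  # ['full response']
print(list(Streaming().stream()))     # ['tok1', 'tok2']
```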
---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Nuno Campos 2023-07-26 20:16:46 +01:00 committed by GitHub
parent 04a4d3e312
commit a612800ef0
60 changed files with 3564 additions and 762 deletions

View File

@@ -9,10 +9,7 @@ from typing import (
     Tuple,
 )
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
     ChatGeneration,
@@ -116,15 +113,6 @@ class ChatLlamaAPI(BaseChatModel):
             generations.append(gen)
         return ChatResult(generations=generations)
 
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        raise NotImplementedError
-
     @property
     def _client_params(self) -> Mapping[str, Any]:
         """Get the parameters used for the client."""

View File

@@ -1,13 +1,14 @@
 """Base callback handler that can be used to handle callbacks in langchain."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from uuid import UUID
 
-from langchain.schema.agent import AgentAction, AgentFinish
-from langchain.schema.document import Document
-from langchain.schema.messages import BaseMessage
-from langchain.schema.output import LLMResult
+if TYPE_CHECKING:
+    from langchain.schema.agent import AgentAction, AgentFinish
+    from langchain.schema.document import Document
+    from langchain.schema.messages import BaseMessage
+    from langchain.schema.output import LLMResult
 
 
 class RetrieverManagerMixin:
@@ -543,3 +544,6 @@ class BaseCallbackManager(CallbackManagerMixin):
         for key in keys:
             self.metadata.pop(key)
             self.inheritable_metadata.pop(key)
+
+
+Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

View File

@ -4,6 +4,7 @@ import asyncio
import functools import functools
import logging import logging
import os import os
import uuid
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import ( from typing import (
@ -20,12 +21,13 @@ from typing import (
Union, Union,
cast, cast,
) )
from uuid import UUID, uuid4 from uuid import UUID
import langchain import langchain
from langchain.callbacks.base import ( from langchain.callbacks.base import (
BaseCallbackHandler, BaseCallbackHandler,
BaseCallbackManager, BaseCallbackManager,
Callbacks,
ChainManagerMixin, ChainManagerMixin,
LLMManagerMixin, LLMManagerMixin,
RetrieverManagerMixin, RetrieverManagerMixin,
@ -50,7 +52,6 @@ if TYPE_CHECKING:
from langsmith import Client as LangSmithClient from langsmith import Client as LangSmithClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None "openai_callback", default=None
@ -437,7 +438,7 @@ class BaseRunManager(RunManagerMixin):
BaseRunManager: The noop manager. BaseRunManager: The noop manager.
""" """
return cls( return cls(
run_id=uuid4(), run_id=uuid.uuid4(),
handlers=[], handlers=[],
inheritable_handlers=[], inheritable_handlers=[],
tags=[], tags=[],
@ -1024,7 +1025,7 @@ class CallbackManager(BaseCallbackManager):
""" """
managers = [] managers = []
for prompt in prompts: for prompt in prompts:
run_id_ = uuid4() run_id_ = uuid.uuid4()
_handle_event( _handle_event(
self.handlers, self.handlers,
"on_llm_start", "on_llm_start",
@ -1073,7 +1074,7 @@ class CallbackManager(BaseCallbackManager):
managers = [] managers = []
for message_list in messages: for message_list in messages:
run_id_ = uuid4() run_id_ = uuid.uuid4()
_handle_event( _handle_event(
self.handlers, self.handlers,
"on_chat_model_start", "on_chat_model_start",
@ -1120,7 +1121,7 @@ class CallbackManager(BaseCallbackManager):
CallbackManagerForChainRun: The callback manager for the chain run. CallbackManagerForChainRun: The callback manager for the chain run.
""" """
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
_handle_event( _handle_event(
self.handlers, self.handlers,
@ -1166,7 +1167,7 @@ class CallbackManager(BaseCallbackManager):
CallbackManagerForToolRun: The callback manager for the tool run. CallbackManagerForToolRun: The callback manager for the tool run.
""" """
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
_handle_event( _handle_event(
self.handlers, self.handlers,
@ -1202,7 +1203,7 @@ class CallbackManager(BaseCallbackManager):
) -> CallbackManagerForRetrieverRun: ) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running.""" """Run when retriever starts running."""
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
_handle_event( _handle_event(
self.handlers, self.handlers,
@ -1302,7 +1303,7 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = [] managers = []
for prompt in prompts: for prompt in prompts:
run_id_ = uuid4() run_id_ = uuid.uuid4()
tasks.append( tasks.append(
_ahandle_event( _ahandle_event(
@ -1341,7 +1342,7 @@ class AsyncCallbackManager(BaseCallbackManager):
serialized: Dict[str, Any], serialized: Dict[str, Any],
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running. """Run when LLM starts running.
Args: Args:
@ -1358,7 +1359,7 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = [] managers = []
for message_list in messages: for message_list in messages:
run_id_ = uuid4() run_id_ = uuid.uuid4()
tasks.append( tasks.append(
_ahandle_event( _ahandle_event(
@ -1410,7 +1411,7 @@ class AsyncCallbackManager(BaseCallbackManager):
for the chain run. for the chain run.
""" """
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
await _ahandle_event( await _ahandle_event(
self.handlers, self.handlers,
@ -1458,7 +1459,7 @@ class AsyncCallbackManager(BaseCallbackManager):
for the tool run. for the tool run.
""" """
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
await _ahandle_event( await _ahandle_event(
self.handlers, self.handlers,
@ -1494,7 +1495,7 @@ class AsyncCallbackManager(BaseCallbackManager):
) -> AsyncCallbackManagerForRetrieverRun: ) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running.""" """Run when retriever starts running."""
if run_id is None: if run_id is None:
run_id = uuid4() run_id = uuid.uuid4()
await _ahandle_event( await _ahandle_event(
self.handlers, self.handlers,

View File

@@ -4,7 +4,7 @@ import asyncio
 from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
 
 from langchain.callbacks.base import AsyncCallbackHandler
-from langchain.schema import LLMResult
+from langchain.schema.output import LLMResult
 
 
 # TODO If used by two LLM runs in parallel this won't work as expected

View File

@@ -22,6 +22,7 @@ from langchain.callbacks.manager import (
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
 from langchain.schema import RUN_KEY, BaseMemory, RunInfo
+from langchain.schema.runnable import Runnable, RunnableConfig
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +31,7 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-class Chain(Serializable, ABC):
+class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     """Abstract base class for creating structured sequences of calls to components.
 
     Chains should be used to encode a sequence of calls to components like
@@ -53,6 +54,20 @@ class Chain(Serializable, ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
+    def invoke(
+        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+    ) -> Dict[str, Any]:
+        return self(input, **(config or {}))
+
+    async def ainvoke(
+        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+    ) -> Dict[str, Any]:
+        if type(self)._acall == Chain._acall:
+            # If the chain does not implement async, fall back to default implementation
+            return await super().ainvoke(input, config)
+
+        return await self.acall(input, **(config or {}))
+
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
     Memory is a class that gets called at the start

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -12,11 +12,13 @@ from langchain.schema import (
) )
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
AIMessageChunk,
BaseMessage, BaseMessage,
ChatMessage, ChatMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain.schema.output import ChatGenerationChunk
class ChatAnthropic(BaseChatModel, _AnthropicCommon): class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@ -94,6 +96,44 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
text.rstrip() text.rstrip()
) # trim off the trailing ' ' that might come from the "Assistant: " ) # trim off the trailing ' ' that might come from the "Assistant: "
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = await self.async_client.completions.create(**params, stream=True)
async for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta)
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -101,22 +141,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
if self.streaming: if self.streaming:
completion = "" completion = ""
stream_resp = self.client.completions.create(**params, stream=True) for chunk in self._stream(messages, stop, run_manager, **kwargs):
for data in stream_resp: completion += chunk.text
delta = data.completion
completion += delta
if run_manager:
run_manager.on_llm_new_token(
delta,
)
else: else:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params) response = self.client.completions.create(**params)
completion = response.completion completion = response.completion
message = AIMessage(content=completion) message = AIMessage(content=completion)
@ -129,24 +166,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
if self.streaming: if self.streaming:
completion = "" completion = ""
stream_resp = await self.async_client.completions.create( async for chunk in self._astream(messages, stop, run_manager, **kwargs):
**params, stream=True completion += chunk.text
)
async for data in stream_resp:
delta = data.completion
completion += delta
if run_manager:
await run_manager.on_llm_new_token(
delta,
)
else: else:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = await self.async_client.completions.create(**params) response = await self.async_client.completions.create(**params)
completion = response.completion completion = response.completion
message = AIMessage(content=completion) message = AIMessage(content=completion)

View File

@ -3,7 +3,16 @@ import inspect
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional, Sequence from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
cast,
)
from pydantic import Field, root_validator from pydantic import Field, root_validator
@ -17,6 +26,8 @@ from langchain.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain.load.dump import dumpd, dumps from langchain.load.dump import dumpd, dumps
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import ( from langchain.schema import (
ChatGeneration, ChatGeneration,
ChatResult, ChatResult,
@ -24,17 +35,22 @@ from langchain.schema import (
PromptValue, PromptValue,
RunInfo, RunInfo,
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.schema.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
def _get_verbosity() -> bool: def _get_verbosity() -> bool:
return langchain.verbose return langchain.verbose
class BaseChatModel(BaseLanguageModel, ABC): class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"""Base class for chat models."""
cache: Optional[bool] = None cache: Optional[bool] = None
"""Whether to cache the response.""" """Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity) verbose: bool = Field(default_factory=_get_verbosity)
@ -64,6 +80,154 @@ class BaseChatModel(BaseLanguageModel, ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
# --- Runnable methods ---
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> BaseMessageChunk:
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
).generations[0][0],
).message,
)
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> BaseMessageChunk:
if type(self)._agenerate == BaseChatModel._agenerate:
# model doesn't implement async generation, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
)
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
)
try:
message: Optional[BaseMessageChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if message is None:
message = chunk.message
else:
message += chunk.message
assert message is not None
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[ChatGeneration(message=message)]]),
)
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
)
try:
message: Optional[BaseMessageChunk] = None
async for chunk in self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if message is None:
message = chunk.message
else:
message += chunk.message
assert message is not None
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(
LLMResult(generations=[[ChatGeneration(message=message)]]),
)
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {} return {}
@ -334,7 +498,6 @@ class BaseChatModel(BaseLanguageModel, ABC):
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call"""
@abstractmethod
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -343,6 +506,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call"""
raise NotImplementedError()
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
def __call__( def __call__(
self, self,

View File

@@ -25,7 +25,10 @@ class FakeListChatModel(SimpleChatModel):
     ) -> str:
         """First try to lookup in queries, else return 'foo' or 'bar'."""
         response = self.responses[self.i]
-        self.i += 1
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
         return response
 
     @property

View File

@ -4,8 +4,10 @@ from __future__ import annotations
import logging import logging
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict, Dict,
Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
@ -36,6 +38,14 @@ from langchain.schema import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain.schema.messages import (
AIMessageChunk,
BaseMessageChunk,
ChatMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -75,6 +85,24 @@ async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"] role = _dict["role"]
if role == "user": if role == "user":
@ -258,6 +286,25 @@ class JinaChat(BaseChatModel):
overall_token_usage[k] = v overall_token_usage[k] = v
return {"token_usage": overall_token_usage} return {"token_usage": overall_token_usage}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -265,27 +312,20 @@ class JinaChat(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params) response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
@ -309,6 +349,27 @@ class JinaChat(BaseChatModel):
llm_output = {"token_usage": response["usage"]} llm_output = {"token_usage": response["usage"]}
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -316,31 +377,21 @@ class JinaChat(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming: response = await acompletion_with_retry(self, messages=message_dicts, **params)
inner_completion = ""
role = "assistant"
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=message_dicts, **params
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token or ""
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
return self._create_chat_result(response) return self._create_chat_result(response)
@property @property

View File

@ -6,8 +6,10 @@ import sys
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Callable, Callable,
Dict, Dict,
Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
@ -35,12 +37,19 @@ from langchain.schema import (
) )
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk,
ChatMessage, ChatMessage,
ChatMessageChunk,
FunctionMessage, FunctionMessage,
FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk,
SystemMessage, SystemMessage,
SystemMessageChunk,
) )
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
if TYPE_CHECKING: if TYPE_CHECKING:
@ -95,6 +104,30 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"] role = _dict["role"]
if role == "user": if role == "user":
@ -313,6 +346,27 @@ class ChatOpenAI(BaseChatModel):
overall_token_usage[k] = v overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model_name} return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -320,40 +374,20 @@ class ChatOpenAI(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
_function_call = stream_resp["choices"][0]["delta"].get(
"function_call"
)
if _function_call:
if function_call is None:
function_call = _function_call
elif "arguments" in function_call:
function_call["arguments"] += _function_call["arguments"]
else:
function_call["arguments"] = _function_call["arguments"]
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params) response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
@ -381,6 +415,29 @@ class ChatOpenAI(BaseChatModel):
llm_output = {"token_usage": token_usage, "model_name": self.model_name} llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -388,44 +445,21 @@ class ChatOpenAI(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming: response = await acompletion_with_retry(self, messages=message_dicts, **params)
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
async for stream_resp in await acompletion_with_retry(
self, messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token or ""
_function_call = stream_resp["choices"][0]["delta"].get(
"function_call"
)
if _function_call:
if function_call is None:
function_call = _function_call
elif "arguments" in function_call:
function_call["arguments"] += _function_call["arguments"]
else:
function_call["arguments"] = _function_call["arguments"]
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
return self._create_chat_result(response) return self._create_chat_result(response)
@property @property

View File

@@ -4,10 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from pydantic import root_validator
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
@@ -162,14 +159,3 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
-
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        raise NotImplementedError(
-            """Vertex AI doesn't support async requests at the moment."""
-        )

View File

@@ -6,7 +6,8 @@ from pydantic import BaseModel
 from langchain.chains.llm import LLMChain
 from langchain.chains.openai_functions import create_tagging_chain
 from langchain.prompts import ChatPromptTemplate
-from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
+from langchain.schema import BaseDocumentTransformer, Document
+from langchain.schema.language_model import BaseLanguageModel
 
 
 class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):

View File

@ -1,18 +1,20 @@
import re import re
import warnings import warnings
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
from pydantic import BaseModel, root_validator from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env from langchain.utils import check_package_version, get_from_dict_or_env
class _AnthropicCommon(BaseModel): class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private: client: Any = None #: :meta private:
async_client: Any = None #: :meta private: async_client: Any = None #: :meta private:
model: str = "claude-2" model: str = "claude-2"
@ -193,24 +195,16 @@ class Anthropic(LLM, _AnthropicCommon):
response = model(prompt) response = model(prompt)
""" """
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop) stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
for data in stream_resp:
delta = data.completion
current_completion += delta
if run_manager:
run_manager.on_llm_new_token(
delta,
)
return current_completion
response = self.client.completions.create( response = self.client.completions.create(
prompt=self._wrap_prompt(prompt), prompt=self._wrap_prompt(prompt),
stop_sequences=stop, stop_sequences=stop,
@ -226,22 +220,17 @@ class Anthropic(LLM, _AnthropicCommon):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Call out to Anthropic's completion endpoint asynchronously.""" """Call out to Anthropic's completion endpoint asynchronously."""
if self.streaming:
completion = ""
async for chunk in self._astream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop) stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
async for data in stream_resp:
delta = data.completion
current_completion += delta
if run_manager:
await run_manager.on_llm_new_token(delta)
return current_completion
response = await self.async_client.completions.create( response = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt), prompt=self._wrap_prompt(prompt),
stop_sequences=stop, stop_sequences=stop,
@ -249,23 +238,23 @@ class Anthropic(LLM, _AnthropicCommon):
) )
return response.completion return response.completion
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator: def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator. r"""Call Anthropic completion_stream and return the resulting generator.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
Args: Args:
prompt: The prompt to pass into the model. prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating. stop: Optional list of stop words to use when generating.
Returns: Returns:
A generator representing the stream of tokens from Anthropic. A generator representing the stream of tokens from Anthropic.
Example: Example:
.. code-block:: python .. code-block:: python
prompt = "Write a poem about a stream." prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:" prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt) generator = anthropic.stream(prompt)
@ -273,12 +262,49 @@ class Anthropic(LLM, _AnthropicCommon):
yield token yield token
""" """
stop = self._get_anthropic_stop(stop) stop = self._get_anthropic_stop(stop)
return self.client.completions.create( params = {**self._default_params, **kwargs}
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
):
yield GenerationChunk(text=token.completion)
if run_manager:
run_manager.on_llm_new_token(token.completion)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Anthropic.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt)
for token in generator:
yield token
"""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
async for token in await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt), prompt=self._wrap_prompt(prompt),
stop_sequences=stop, stop_sequences=stop,
stream=True, stream=True,
**self._default_params, **params,
) ):
yield GenerationChunk(text=token.completion)
if run_manager:
await run_manager.on_llm_new_token(token.completion)
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens.""" """Calculate number of tokens."""

View File

@ -7,11 +7,14 @@ import json
import logging import logging
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict, Dict,
Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
@ -19,6 +22,7 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
) )
import yaml import yaml
@ -42,14 +46,18 @@ from langchain.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import ( from langchain.schema import (
Generation, Generation,
LLMResult, LLMResult,
PromptValue, PromptValue,
RunInfo, RunInfo,
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -115,7 +123,7 @@ def update_cache(
return llm_output return llm_output
class BaseLLM(BaseLanguageModel, ABC): class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface. """Base LLM abstract interface.
It should take in a prompt and return a string.""" It should take in a prompt and return a string."""
@ -157,6 +165,204 @@ class BaseLLM(BaseLanguageModel, ABC):
else: else:
return verbose return verbose
# --- Runnable methods ---
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> str:
return (
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
.generations[0][0]
.text
)
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
return llm_result.generations[0][0].text
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
) -> List[str]:
config = self._get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
)
return [g[0].text for g in llm_result.generations]
else:
batches = [
inputs[i : i + max_concurrency]
for i in range(0, len(inputs), max_concurrency)
]
return [
output
for batch in batches
for output in self.batch(batch, config=config)
]
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
) -> List[str]:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, self.batch, inputs, config, max_concurrency
)
config = self._get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
)
return [g[0].text for g in llm_result.generations]
else:
batches = [
inputs[i : i + max_concurrency]
for i in range(0, len(inputs), max_concurrency)
]
return [
output
for batch in batches
for output in await self.abatch(batch, config=config)
]
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)
try:
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(LLMResult(generations=[[generation]]))
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> AsyncIterator[str]:
if type(self)._astream == BaseLLM._astream:
# model doesn't implement streaming, so use default implementation
yield await self.ainvoke(input, config=config, stop=stop)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)
try:
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(
prompt, stop=stop, run_manager=run_manager
):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
# --- Custom methods ---
@abstractmethod @abstractmethod
def _generate( def _generate(
self, self,
@ -167,7 +373,6 @@ class BaseLLM(BaseLanguageModel, ABC):
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
@abstractmethod
async def _agenerate( async def _agenerate(
self, self,
prompts: List[str], prompts: List[str],
@ -176,12 +381,31 @@ class BaseLLM(BaseLanguageModel, ABC):
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
raise NotImplementedError()
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
raise NotImplementedError()
def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
raise NotImplementedError()
def generate_prompt( def generate_prompt(
self, self,
prompts: List[PromptValue], prompts: List[PromptValue],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts] prompt_strings = [p.to_string() for p in prompts]
@ -191,7 +415,7 @@ class BaseLLM(BaseLanguageModel, ABC):
self, self,
prompts: List[PromptValue], prompts: List[PromptValue],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts] prompt_strings = [p.to_string() for p in prompts]
@ -236,10 +460,10 @@ class BaseLLM(BaseLanguageModel, ABC):
self, self,
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
@ -248,6 +472,50 @@ class BaseLLM(BaseLanguageModel, ABC):
"Argument 'prompts' is expected to be of type List[str], received" "Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}." f" argument of type {type(prompts)}."
) )
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
or callbacks[0] is None
):
# We've received a list of callbacks args to apply to each input
assert len(callbacks) == len(prompts)
assert tags is None or (
isinstance(tags, list) and len(tags) == len(prompts)
)
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
callback_managers = [
CallbackManager.configure(
callback,
self.callbacks,
self.verbose,
tag,
self.tags,
meta,
self.metadata,
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
else:
# We've received a single callbacks arg to apply to all inputs
callback_managers = [
CallbackManager.configure(
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
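The branch above either broadcasts a single callbacks/tags/metadata value to every prompt or zips per-prompt lists against the inputs. The same normalization idea in isolation, as a hedged sketch (the `broadcast` helper below is hypothetical, not part of the library):

```python
from typing import Any, List, Optional, Sequence, Union


def broadcast(value: Union[Any, List[Any], None], n: int) -> List[Optional[Any]]:
    """Return one value per input, accepting either a single value or a list of length n."""
    if isinstance(value, list):
        if len(value) != n:
            raise ValueError(f"expected {n} values, got {len(value)}")
        return value
    return [value] * n


prompts: Sequence[str] = ["a", "b", "c"]
tags = broadcast(["t1", "t2", "t3"], len(prompts))      # per-prompt tags
metadata = broadcast({"source": "test"}, len(prompts))  # one dict reused for all prompts
print(list(zip(prompts, tags, metadata)))
```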
params = self.dict() params = self.dict()
params["stop"] = stop params["stop"] = stop
options = {"stop": stop} options = {"stop": stop}
@ -258,15 +526,6 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts, missing_prompts,
) = get_prompts(params, prompts) ) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._generate).parameters.get( new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager" "run_manager"
) )
@ -275,17 +534,26 @@ class BaseLLM(BaseLanguageModel, ABC):
raise ValueError( raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`." "Asked to cache, but no cache found at `langchain.cache`."
) )
-            run_managers = callback_manager.on_llm_start(
-                dumpd(self), prompts, invocation_params=params, options=options
-            )
+            run_managers = [
+                callback_manager.on_llm_start(
+                    dumpd(self), [prompt], invocation_params=params, options=options
+                )[0]
+                for callback_manager, prompt in zip(callback_managers, prompts)
+            ]
             output = self._generate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
             return output
         if len(missing_prompts) > 0:
-            run_managers = callback_manager.on_llm_start(
-                dumpd(self), missing_prompts, invocation_params=params, options=options
-            )
+            run_managers = [
+                callback_managers[idx].on_llm_start(
+                    dumpd(self),
+                    [prompts[idx]],
+                    invocation_params=params,
+                    options=options,
+                )[0]
+                for idx in missing_prompt_idxs
+            ]
             new_results = self._generate_helper(
                 missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
@ -346,13 +614,57 @@ class BaseLLM(BaseLanguageModel, ABC):
self, self,
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
or callbacks[0] is None
):
# We've received a list of callbacks args to apply to each input
assert len(callbacks) == len(prompts)
assert tags is None or (
isinstance(tags, list) and len(tags) == len(prompts)
)
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
callback_managers = [
AsyncCallbackManager.configure(
callback,
self.callbacks,
self.verbose,
tag,
self.tags,
meta,
self.metadata,
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
else:
# We've received a single callbacks arg to apply to all inputs
callback_managers = [
AsyncCallbackManager.configure(
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
params = self.dict() params = self.dict()
params["stop"] = stop params["stop"] = stop
options = {"stop": stop} options = {"stop": stop}
@ -363,15 +675,6 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts, missing_prompts,
) = get_prompts(params, prompts) ) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get( new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager" "run_manager"
) )
@ -380,17 +683,32 @@ class BaseLLM(BaseLanguageModel, ABC):
raise ValueError( raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`." "Asked to cache, but no cache found at `langchain.cache`."
) )
-            run_managers = await callback_manager.on_llm_start(
-                dumpd(self), prompts, invocation_params=params, options=options
-            )
+            run_managers = await asyncio.gather(
+                *[
+                    callback_manager.on_llm_start(
+                        dumpd(self), [prompt], invocation_params=params, options=options
+                    )
+                    for callback_manager, prompt in zip(callback_managers, prompts)
+                ]
+            )
+            run_managers = [r[0] for r in run_managers]
             output = await self._agenerate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
             return output
         if len(missing_prompts) > 0:
-            run_managers = await callback_manager.on_llm_start(
-                dumpd(self), missing_prompts, invocation_params=params, options=options
-            )
+            run_managers = await asyncio.gather(
+                *[
+                    callback_managers[idx].on_llm_start(
+                        dumpd(self),
+                        [prompts[idx]],
+                        invocation_params=params,
+                        options=options,
+                    )
+                    for idx in missing_prompt_idxs
+                ]
+            )
+            run_managers = [r[0] for r in run_managers]
             new_results = await self._agenerate_helper(
                 missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
@ -586,7 +904,7 @@ class LLM(BaseLLM):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.") raise NotImplementedError()
def _generate( def _generate(
self, self,
@ -615,6 +933,12 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
if type(self)._acall == LLM._acall:
# model doesn't implement async call, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
)
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
generations = [] generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
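`new_arg_supported` here is a backward-compatibility probe: older subclasses implemented `_call`/`_acall` without a `run_manager` parameter, so the base class only forwards it when the signature declares it. A standalone sketch of that inspection trick (helper name and functions are made up for illustration):

```python
import inspect


def call_with_optional_kwarg(fn, *args, **optional):
    """Pass each optional kwarg only if fn's signature declares it."""
    params = inspect.signature(fn).parameters
    accepted = {k: v for k, v in optional.items() if k in params}
    return fn(*args, **accepted)


def old_style(prompt):  # predates the run_manager argument
    return f"old:{prompt}"


def new_style(prompt, run_manager=None):
    return f"new:{prompt}:{run_manager}"


print(call_with_optional_kwarg(old_style, "hi", run_manager="rm"))  # old:hi
print(call_with_optional_kwarg(new_style, "hi", run_manager="rm"))  # new:hi:rm
```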

View File

@ -27,7 +27,10 @@ class FakeListLLM(LLM):
) -> str: ) -> str:
"""Return next response""" """Return next response"""
response = self.responses[self.i] response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1 self.i += 1
else:
self.i = 0
return response return response
async def _acall( async def _acall(
@ -39,7 +42,10 @@ class FakeListLLM(LLM):
) -> str: ) -> str:
"""Return next response""" """Return next response"""
response = self.responses[self.i] response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1 self.i += 1
else:
self.i = 0
return response return response
@property @property
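The fake LLM in this test now wraps around instead of running off the end of its canned responses. The same behaviour can be had with a modulo index; a quick sketch, assuming a simplified fake of my own naming:

```python
from typing import List


class CannedResponder:
    """Returns canned responses in order, wrapping back to the first one."""

    def __init__(self, responses: List[str]) -> None:
        self.responses = responses
        self.i = 0

    def __call__(self, prompt: str) -> str:
        response = self.responses[self.i % len(self.responses)]
        self.i += 1
        return response


fake = CannedResponder(["foo", "bar"])
print([fake("x") for _ in range(5)])  # ['foo', 'bar', 'foo', 'bar', 'foo']
```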

View File

@ -12,10 +12,7 @@ from tenacity import (
wait_exponential, wait_exponential,
) )
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForLLMRun
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms import BaseLLM from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -161,15 +158,6 @@ class GooglePalm(BaseLLM, BaseModel):
return LLMResult(generations=generations) return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError()
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
"""Return type of llm.""" """Return type of llm."""

View File

@ -1,5 +1,4 @@
from functools import partial from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator from pydantic import Extra, Field, root_validator
@ -8,6 +7,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
class HuggingFaceTextGenInference(LLM): class HuggingFaceTextGenInference(LLM):
@ -69,7 +69,7 @@ class HuggingFaceTextGenInference(LLM):
temperature = 0.01, temperature = 0.01,
repetition_penalty = 1.03, repetition_penalty = 1.03,
callbacks = callbacks, callbacks = callbacks,
stream = True streaming = True
) )
print(llm("What is Deep Learning?")) print(llm("What is Deep Learning?"))
@ -87,7 +87,7 @@ class HuggingFaceTextGenInference(LLM):
inference_server_url: str = "" inference_server_url: str = ""
timeout: int = 120 timeout: int = 120
server_kwargs: Dict[str, Any] = Field(default_factory=dict) server_kwargs: Dict[str, Any] = Field(default_factory=dict)
stream: bool = False streaming: bool = False
client: Any client: Any
async_client: Any async_client: Any
@ -154,8 +154,13 @@ class HuggingFaceTextGenInference(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
invocation_params = self._invocation_params(stop, **kwargs) invocation_params = self._invocation_params(stop, **kwargs)
if not self.stream:
res = self.client.generate(prompt, **invocation_params) res = self.client.generate(prompt, **invocation_params)
# remove stop sequences from the end of the generated text # remove stop sequences from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop_sequences"]:
@ -163,28 +168,7 @@ class HuggingFaceTextGenInference(LLM):
res.generated_text = res.generated_text[ res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq) : res.generated_text.index(stop_seq)
] ]
text = res.generated_text return res.generated_text
else:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
text = ""
for res in self.client.generate_stream(prompt, **invocation_params):
token = res.token
is_stop = False
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in token.text:
is_stop = True
break
if is_stop:
break
if not token.special:
if text_callback:
text_callback(token.text)
text += token.text
return text
async def _acall( async def _acall(
self, self,
@ -193,39 +177,90 @@ class HuggingFaceTextGenInference(LLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if self.streaming:
completion = ""
async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
invocation_params = self._invocation_params(stop, **kwargs) invocation_params = self._invocation_params(stop, **kwargs)
if not self.stream: res = await self.async_client.generate(prompt, **invocation_params)
res = await self.async_client.generate(
prompt,
**invocation_params,
)
# remove stop sequences from the end of the generated text # remove stop sequences from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.generated_text: if stop_seq in res.generated_text:
res.generated_text = res.generated_text[ res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq) : res.generated_text.index(stop_seq)
] ]
text: str = res.generated_text return res.generated_text
else:
text_callback = None def _stream(
if run_manager: self,
text_callback = partial( prompt: str,
run_manager.on_llm_new_token, verbose=self.verbose stop: Optional[List[str]] = None,
) run_manager: Optional[CallbackManagerForLLMRun] = None,
text = "" **kwargs: Any,
async for res in self.async_client.generate_stream( ) -> Iterator[GenerationChunk]:
prompt, **invocation_params invocation_params = self._invocation_params(stop, **kwargs)
):
token = res.token for res in self.client.generate_stream(prompt, **invocation_params):
is_stop = False # identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in token.text: if stop_seq in res.token.text:
is_stop = True stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if res.token.special:
text = None
elif stop_seq_found:
text = res.token.text[: res.token.text.index(stop_seq_found)]
else:
text = res.token.text
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break break
if is_stop:
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
async for res in self.async_client.generate_stream(prompt, **invocation_params):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.token.text:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if res.token.special:
text = None
elif stop_seq_found:
text = res.token.text[: res.token.text.index(stop_seq_found)]
else:
text = res.token.text
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break break
if not token.special:
if text_callback:
await text_callback(token.text)
text += token.text
return text
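Both new streaming methods in this file share the same stop-sequence handling: truncate the token that contains the stop string, then end the stream. The logic in isolation, as a hedged sketch over plain strings:

```python
from typing import Iterable, Iterator, List


def stream_until_stop(tokens: Iterable[str], stop: List[str]) -> Iterator[str]:
    """Yield tokens until a stop sequence appears, truncating the final token."""
    for token in tokens:
        stop_seq_found = next((s for s in stop if s in token), None)
        if stop_seq_found is not None:
            head = token[: token.index(stop_seq_found)]
            if head:
                yield head
            return  # stop streaming once a stop sequence is seen
        yield token


print(list(stream_until_stop(["Deep", " learning", " is", "\nUser:", " hi"], ["\nUser:"])))
# -> ['Deep', ' learning', ' is']
```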

View File

@ -1,10 +1,11 @@
import logging import logging
from typing import Any, Dict, Generator, List, Optional from typing import Any, Dict, Iterator, List, Optional
from pydantic import Field, root_validator from pydantic import Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -226,8 +227,10 @@ class LlamaCpp(LLM):
# method that yields as they are generated # method that yields as they are generated
# and return the combined strings from the first choices's text: # and return the combined strings from the first choices's text:
combined_text_output = "" combined_text_output = ""
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager): for chunk in self._stream(
combined_text_output += token["choices"][0]["text"] prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
combined_text_output += chunk.text
return combined_text_output return combined_text_output
else: else:
params = self._get_parameters(stop) params = self._get_parameters(stop)
@ -235,17 +238,15 @@ class LlamaCpp(LLM):
result = self.client(prompt=prompt, **params) result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"] return result["choices"][0]["text"]
def stream( def _stream(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Generator[Dict, None, None]: **kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Yields results objects as they are generated in real time. """Yields results objects as they are generated in real time.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
It also calls the callback manager's on_llm_new_token event with It also calls the callback manager's on_llm_new_token event with
similar parameters to the OpenAI LLM class method of the same name. similar parameters to the OpenAI LLM class method of the same name.
@ -274,16 +275,19 @@ class LlamaCpp(LLM):
print(result["text"], end='', flush=True) print(result["text"], end='', flush=True)
""" """
-        params = self._get_parameters(stop)
+        params = {**self._get_parameters(stop), **kwargs}
         result = self.client(prompt=prompt, stream=True, **params)
-        for chunk in result:
-            token = chunk["choices"][0]["text"]
-            log_probs = chunk["choices"][0].get("logprobs", None)
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    token=token, verbose=self.verbose, log_probs=log_probs
-                )
-            yield chunk
+        for part in result:
+            logprobs = part["choices"][0].get("logprobs", None)
+            chunk = GenerationChunk(
+                text=part["choices"][0]["text"],
+                generation_info={"logprobs": logprobs},
+            )
+            yield chunk
+            if run_manager:
if run_manager:
run_manager.on_llm_new_token(
token=chunk.text, verbose=self.verbose, log_probs=logprobs
)
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
tokenized_text = self.client.tokenize(text.encode("utf-8")) tokenized_text = self.client.tokenize(text.encode("utf-8"))
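Routing the blocking call through the streaming generator, as the `_call` change at the top of this file does when streaming is enabled, keeps a single code path for per-token handling. A minimal sketch of the pattern, with made-up names:

```python
from typing import Iterator, List


class TinyStreamer:
    """Single source of truth: generate() is just stream() joined together."""

    def __init__(self, tokens: List[str]) -> None:
        self.tokens = tokens

    def stream(self, prompt: str) -> Iterator[str]:
        for token in self.tokens:
            yield token  # per-token callbacks or logging would hook in here

    def generate(self, prompt: str) -> str:
        # Reuse the streaming path so the two code paths cannot drift apart.
        return "".join(self.stream(prompt))


print(TinyStreamer(["He", "llo"]).generate("ignored"))  # Hello
```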

View File

@ -6,10 +6,11 @@ import warnings
from typing import ( from typing import (
AbstractSet, AbstractSet,
Any, Any,
AsyncIterator,
Callable, Callable,
Collection, Collection,
Dict, Dict,
Generator, Iterator,
List, List,
Literal, Literal,
Mapping, Mapping,
@ -27,6 +28,7 @@ from langchain.callbacks.manager import (
) )
from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,6 +46,19 @@ def update_token_usage(
token_usage[_key] += response["usage"][_key] token_usage[_key] += response["usage"][_key]
def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
return GenerationChunk(
text=stream_response["choices"][0]["text"],
generation_info=dict(
finish_reason=stream_response["choices"][0].get("finish_reason", None),
logprobs=stream_response["choices"][0].get("logprobs", None),
),
)
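A small usage sketch of a converter like the one above, assuming completions-style streaming payloads shaped like `{"choices": [{"text": ..., "finish_reason": ..., "logprobs": ...}]}`; the dicts below are fabricated for illustration:

```python
from typing import Any, Dict, List

fake_stream: List[Dict[str, Any]] = [
    {"choices": [{"text": "Hello", "finish_reason": None, "logprobs": None}]},
    {"choices": [{"text": ", world", "finish_reason": "stop", "logprobs": None}]},
]

text = ""
finish_reason = None
for event in fake_stream:
    choice = event["choices"][0]
    text += choice["text"]  # accumulate the streamed text
    finish_reason = choice.get("finish_reason") or finish_reason  # keep last non-null value

print(text, finish_reason)  # Hello, world stop
```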
def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
"""Update response from the stream response.""" """Update response from the stream response."""
response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["text"] += stream_response["choices"][0]["text"]
@ -268,6 +283,50 @@ class BaseOpenAI(BaseLLM):
return {**normal_params, **self.model_kwargs} return {**normal_params, **self.model_kwargs}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop)  # this mutates params
for stream_resp in completion_with_retry(self, prompt=prompt, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop)  # this mutates params
async for stream_resp in await acompletion_with_retry(
self, prompt=prompt, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
def _generate( def _generate(
self, self,
prompts: List[str], prompts: List[str],
@ -302,24 +361,28 @@ class BaseOpenAI(BaseLLM):
         if self.streaming:
             if len(_prompts) > 1:
                 raise ValueError("Cannot stream results with multiple prompts.")
-            params["stream"] = True
-            response = _streaming_response_template()
-            for stream_resp in completion_with_retry(
-                self, prompt=_prompts, **params
-            ):
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        stream_resp["choices"][0]["text"],
-                        verbose=self.verbose,
-                        logprobs=stream_resp["choices"][0]["logprobs"],
-                    )
-                _update_response(response, stream_resp)
-            choices.extend(response["choices"])
+            generation: Optional[GenerationChunk] = None
+            for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            choices.append(
+                {
+                    "text": generation.text,
+                    "finish_reason": generation.generation_info.get("finish_reason")
+                    if generation.generation_info
+                    else None,
+                    "logprobs": generation.generation_info.get("logprobs")
+                    if generation.generation_info
+                    else None,
+                }
+            )
         else:
             response = completion_with_retry(self, prompt=_prompts, **params)
             choices.extend(response["choices"])
+        if not self.streaming:
+            # Can't update token usage if streaming
             update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage) return self.create_llm_result(choices, prompts, token_usage)
@ -343,24 +406,30 @@ class BaseOpenAI(BaseLLM):
         if self.streaming:
             if len(_prompts) > 1:
                 raise ValueError("Cannot stream results with multiple prompts.")
-            params["stream"] = True
-            response = _streaming_response_template()
-            async for stream_resp in await acompletion_with_retry(
-                self, prompt=_prompts, **params
-            ):
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        stream_resp["choices"][0]["text"],
-                        verbose=self.verbose,
-                        logprobs=stream_resp["choices"][0]["logprobs"],
-                    )
-                _update_response(response, stream_resp)
-            choices.extend(response["choices"])
+            generation: Optional[GenerationChunk] = None
+            async for chunk in self._astream(
+                _prompts[0], stop, run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            choices.append(
+                {
+                    "text": generation.text,
+                    "finish_reason": generation.generation_info.get("finish_reason")
+                    if generation.generation_info
+                    else None,
+                    "logprobs": generation.generation_info.get("logprobs")
+                    if generation.generation_info
+                    else None,
+                }
+            )
         else:
             response = await acompletion_with_retry(self, prompt=_prompts, **params)
             choices.extend(response["choices"])
+        if not self.streaming:
+            # Can't update token usage if streaming
             update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage) return self.create_llm_result(choices, prompts, token_usage)
@ -409,43 +478,6 @@ class BaseOpenAI(BaseLLM):
llm_output = {"token_usage": token_usage, "model_name": self.model_name} llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return LLMResult(generations=generations, llm_output=llm_output) return LLMResult(generations=generations, llm_output=llm_output)
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
"""Call OpenAI with streaming flag and return the resulting generator.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
Args:
prompt: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from OpenAI.
Example:
.. code-block:: python
generator = openai.stream("Tell me a joke.")
for token in generator:
yield token
"""
params = self.prep_streaming_params(stop)
generator = self.client.create(prompt=prompt, **params)
return generator
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""Prepare the params for streaming."""
params = self._invocation_params
if "best_of" in params and params["best_of"] != 1:
raise ValueError("OpenAI only supports best_of == 1 for streaming")
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
params["stream"] = True
return params
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model.""" """Get the parameters used to invoke the model."""
@ -777,6 +809,38 @@ class OpenAIChat(BaseLLM):
del params["max_tokens"] del params["max_tokens"]
return messages, params return messages, params
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
if run_manager:
run_manager.on_llm_new_token(token)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
if run_manager:
await run_manager.on_llm_new_token(token)
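The chat variant streams `delta` events whose `content` key may be missing on some frames (for example a first role-only frame), which is why `.get("content", "")` is used above. A hedged sketch with fabricated delta payloads:

```python
from typing import Any, Dict, List

deltas: List[Dict[str, Any]] = [
    {"choices": [{"delta": {"role": "assistant"}}]},  # first frame: role only, no content
    {"choices": [{"delta": {"content": "Hi"}}]},
    {"choices": [{"delta": {"content": " there"}}]},
    {"choices": [{"delta": {}}]},                     # final frame: empty delta
]

message = "".join(d["choices"][0]["delta"].get("content", "") for d in deltas)
print(message)  # Hi there
```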
def _generate( def _generate(
self, self,
prompts: List[str], prompts: List[str],
@ -784,22 +848,18 @@ class OpenAIChat(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
if self.streaming:
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming:
response = ""
params["stream"] = True
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if run_manager:
run_manager.on_llm_new_token(
token,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = completion_with_retry(self, messages=messages, **params) full_response = completion_with_retry(self, messages=messages, **params)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
@ -819,27 +879,19 @@ class OpenAIChat(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
if self.streaming:
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
if self.streaming: full_response = await acompletion_with_retry(self, messages=messages, **params)
response = ""
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if run_manager:
await run_manager.on_llm_new_token(
token,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = await acompletion_with_retry(
self, messages=messages, **params
)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
"model_name": self.model_name, "model_name": self.model_name,

View File

@ -13,10 +13,7 @@ from tenacity import (
wait_exponential, wait_exponential,
) )
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForLLMRun
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -250,12 +247,3 @@ class Tongyi(LLM):
] ]
) )
return LLMResult(generations=generations) return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError()

View File

@ -6,7 +6,7 @@ from typing import List
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
class ListOutputParser(BaseOutputParser): class ListOutputParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a list.""" """Parse the output of an LLM call to a list."""
@property @property

View File

@ -4,14 +4,14 @@ from typing import Any, Dict, List, Type, Union
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain.schema import ( from langchain.schema import (
BaseLLMOutputParser,
ChatGeneration, ChatGeneration,
Generation, Generation,
OutputParserException, OutputParserException,
) )
from langchain.schema.output_parser import BaseGenerationOutputParser
class OutputFunctionsParser(BaseLLMOutputParser[Any]): class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Parse an output that is one of sets of values.""" """Parse an output that is one of sets of values."""
args_only: bool = True args_only: bool = True

View File

@ -5,10 +5,10 @@ import warnings
from abc import ABC from abc import ABC
from typing import Any, Callable, Dict, List, Set from typing import Any, Callable, Dict, List, Set
from langchain.schema import BasePromptTemplate from langchain.formatting import formatter
from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.utils import formatter from langchain.schema.prompt_template import BasePromptTemplate
def jinja2_formatter(template: str, **kwargs: Any) -> str: def jinja2_formatter(template: str, **kwargs: Any) -> str:

View File

@ -446,7 +446,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
for message in messages: for message in messages:
if isinstance(message, BaseMessagePromptTemplate): if isinstance(message, BaseMessagePromptTemplate):
input_vars.update(message.input_variables) input_vars.update(message.input_variables)
return cls(input_variables=list(input_vars), messages=messages) return cls(input_variables=sorted(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string. """Format the chat template into a string.

View File

@ -1,9 +1,6 @@
from typing import List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.arxiv import ArxivAPIWrapper
@ -20,8 +17,3 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
return self.load(query=query) return self.load(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError
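Deleting these per-retriever `NotImplementedError` stubs (here and in the retrievers that follow) works because a base class can supply a generic async path that delegates to the sync implementation. A rough sketch of that idea, simplified and not the library's actual BaseRetriever:

```python
import asyncio
from functools import partial
from typing import List


class SimpleRetriever:
    def _get_relevant_documents(self, query: str) -> List[str]:
        # Blocking, sync lookup; subclasses override this.
        return [f"doc about {query}"]

    async def aget_relevant_documents(self, query: str) -> List[str]:
        # Default async path: run the sync lookup in the default executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, partial(self._get_relevant_documents, query)
        )


print(asyncio.run(SimpleRetriever().aget_relevant_documents("langchain")))
```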

View File

@ -7,10 +7,7 @@ from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -108,8 +105,3 @@ class BM25Retriever(BaseRetriever):
processed_query = self.preprocess_func(query) processed_query = self.preprocess_func(query)
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k) return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
return return_docs return return_docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np import numpy as np
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance from langchain.vectorstores.utils import maximal_marginal_relevance
@ -208,11 +205,3 @@ class DocArrayRetriever(BaseRetriever):
lc_doc.metadata[name] = value lc_doc.metadata[name] = value
return lc_doc return lc_doc
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError

View File

@ -5,10 +5,7 @@ from __future__ import annotations
import uuid import uuid
from typing import Any, Iterable, List from typing import Any, Iterable, List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
@ -138,8 +135,3 @@ class ElasticSearchBM25Retriever(BaseRetriever):
for r in res["hits"]["hits"]: for r in res["hits"]["hits"]:
docs.append(Document(page_content=r["_source"]["content"])) docs.append(Document(page_content=r["_source"]["content"]))
return docs return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from pydantic import Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -184,8 +181,3 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
documents = self._convert_search_response(response.results) documents = self._convert_search_response(response.results)
return documents return documents
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
@ -411,11 +408,3 @@ class AmazonKendraRetriever(BaseRetriever):
""" """
docs = self._kendra_query(query, self.top_k, self.attribute_filter) docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")

View File

@ -9,10 +9,7 @@ from typing import Any, List, Optional
import numpy as np import numpy as np
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -82,8 +79,3 @@ class KNNRetriever(BaseRetriever):
) )
] ]
return top_k_results return top_k_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -2,10 +2,7 @@ from typing import Any, Dict, List, cast
from pydantic import Field from pydantic import Field
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -42,11 +39,6 @@ class LlamaIndexRetriever(BaseRetriever):
) )
return docs return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async")
class LlamaIndexGraphRetriever(BaseRetriever): class LlamaIndexGraphRetriever(BaseRetriever):
"""Retriever for question-answering with sources over an LlamaIndex """Retriever for question-answering with sources over an LlamaIndex
@ -88,8 +80,3 @@ class LlamaIndexGraphRetriever(BaseRetriever):
Document(page_content=source_node.source_text, metadata=metadata) Document(page_content=source_node.source_text, metadata=metadata)
) )
return docs return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

View File

@ -2,10 +2,7 @@ from typing import Any, List, Optional
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -43,8 +40,3 @@ class MetalRetriever(BaseRetriever):
metadata = {k: v for k, v in r.items() if k != "text"} metadata = {k: v for k, v in r.items() if k != "text"}
final_results.append(Document(page_content=r["text"], metadata=metadata)) final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results return final_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.milvus import Milvus from langchain.vectorstores.milvus import Milvus
@ -63,15 +60,6 @@ class MilvusRetriever(BaseRetriever):
query, run_manager=run_manager.get_child(), **kwargs query, run_manager=run_manager.get_child(), **kwargs
) )
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever: def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
"""Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead. """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.

View File

@ -3,10 +3,7 @@ from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser from langchain.output_parsers.pydantic import PydanticOutputParser
@ -101,14 +98,6 @@ class MultiQueryRetriever(BaseRetriever):
unique_documents = self.unique_union(documents) unique_documents = self.unique_union(documents)
return unique_documents return unique_documents
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError
def generate_queries( def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun self, question: str, run_manager: CallbackManagerForRetrieverRun
) -> List[str]: ) -> List[str]:

View File

@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -175,8 +172,3 @@ class PineconeHybridSearchRetriever(BaseRetriever):
) )
# return search results as json # return search results as json
return final_result return final_result
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -1,9 +1,6 @@
from typing import List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.utilities.pupmed import PubMedAPIWrapper from langchain.utilities.pupmed import PubMedAPIWrapper
@ -19,8 +16,3 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
return self.load_docs(query=query) return self.load_docs(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain from langchain import LLMChain
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_chain from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo from langchain.chains.query_constructor.schema import AttributeInfo
@ -119,11 +116,6 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs) docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,

View File

@ -5,10 +5,7 @@ from typing import Any, Iterable, List, Optional
import numpy as np import numpy as np
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -113,8 +110,3 @@ class SVMRetriever(BaseRetriever):
): ):
top_k_results.append(Document(page_content=self.texts[row - 1])) top_k_results.append(Document(page_content=self.texts[row - 1]))
return top_k_results return top_k_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -2,10 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
@ -79,8 +76,3 @@ class TFIDFRetriever(BaseRetriever):
) # Op -- (n_docs,1) -- Cosine Sim with each doc ) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]] return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs return return_docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field from pydantic import Field
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
@ -109,12 +106,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
result.append(buffered_doc) result.append(buffered_doc)
return result return result
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Return documents that are relevant to the query."""
raise NotImplementedError
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore.""" """Add documents to vectorstore."""
current_time = kwargs.get("current_time") current_time = kwargs.get("current_time")

View File

@ -3,10 +3,7 @@ from __future__ import annotations
import json import json
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING: if TYPE_CHECKING:
@ -57,11 +54,6 @@ class VespaRetriever(BaseRetriever):
body["query"] = query body["query"] = query
return self._query(body) return self._query(body)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError
def get_relevant_documents_with_filter( def get_relevant_documents_with_filter(
self, query: str, *, _filter: Optional[str] = None self, query: str, *, _filter: Optional[str] = None
) -> List[Document]: ) -> List[Document]:

View File

@ -5,10 +5,7 @@ from uuid import uuid4
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
@ -118,8 +115,3 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
text = res.pop(self.text_key) text = res.pop(self.text_key)
docs.append(Document(page_content=text, metadata=res)) docs.append(Document(page_content=text, metadata=res))
return docs return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -1,9 +1,6 @@
from typing import List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.utilities.wikipedia import WikipediaAPIWrapper from langchain.utilities.wikipedia import WikipediaAPIWrapper
@ -19,8 +16,3 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
return self.load(query=query) return self.load(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

View File

@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import CallbackManagerForRetrieverRun
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.zilliz import Zilliz from langchain.vectorstores.zilliz import Zilliz
@ -67,15 +64,6 @@ class ZillizRetriever(BaseRetriever):
query, run_manager=run_manager.get_child(), **kwargs query, run_manager=run_manager.get_child(), **kwargs
) )
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever: def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
"""Deprecated ZillizRetreiver. """Deprecated ZillizRetreiver.

View File

@ -1,6 +1,5 @@
from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import BaseDocumentTransformer, Document from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
from langchain.schema.messages import (
AIMessage,
@@ -67,6 +66,5 @@ __all__ = [
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"BaseLanguageModel",
"format_document", "format_document",
] ]

View File

@@ -1,12 +1,22 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING:
@@ -32,7 +42,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
class BaseLanguageModel(Serializable, ABC):
LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models. """Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel. All language model wrappers inherit from BaseLanguageModel.
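Because `BaseLanguageModel` now mixes in `Runnable[LanguageModelInput, LanguageModelOutput]`, every LLM and chat model gains `invoke`/`ainvoke`/`batch`/`stream` from the shared protocol. A minimal sketch of what the new input union allows, using the fake chat model from the test suite (the exact return value is assumed from the tests further down):

from langchain.chat_models.fake import FakeListChatModel

chat = FakeListChatModel(responses=["hello!"])

# A plain string is one accepted LanguageModelInput form; a PromptValue or a
# list of BaseMessage objects would be passed the same way.
result = chat.invoke("hi there")
print(result.content)  # -> "hello!"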

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List, Sequence
from typing import Any, Dict, List, Sequence
from pydantic import Field
@@ -78,6 +78,49 @@ class BaseMessage(Serializable):
return True
class BaseMessageChunk(BaseMessage):
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk:
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage):
"""A Message from a human."""
@@ -92,6 +135,10 @@ class HumanMessage(BaseMessage):
return "human"
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
pass
class AIMessage(BaseMessage):
"""A Message from an AI."""
@@ -106,6 +153,10 @@ class AIMessage(BaseMessage):
return "ai"
class AIMessageChunk(AIMessage, BaseMessageChunk):
pass
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
@@ -117,6 +168,10 @@ class SystemMessage(BaseMessage):
return "system"
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
pass
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
@@ -129,6 +184,10 @@ class FunctionMessage(BaseMessage):
return "function"
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
pass
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
@@ -141,6 +200,10 @@ class ChatMessage(BaseMessage):
return "chat"
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
pass
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from pydantic import BaseModel, root_validator
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessage, BaseMessageChunk
class Generation(Serializable):
@@ -28,6 +28,24 @@ class Generation(Serializable):
return True
class GenerationChunk(Generation):
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class ChatGeneration(Generation):
"""A single chat generation output."""
@@ -43,6 +61,26 @@ class ChatGeneration(Generation):
return values
class ChatGenerationChunk(ChatGeneration):
message: BaseMessageChunk
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""

View File

@@ -1,16 +1,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from langchain.load.serializable import Serializable
from langchain.schema.output import Generation
from langchain.schema.messages import BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
T = TypeVar("T") T = TypeVar("T")
class BaseLLMOutputParser(Serializable, ABC, Generic[T]): class BaseLLMOutputParser(Serializable, Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model.""" """Abstract base class for parsing the outputs of a model."""
@abstractmethod @abstractmethod
@ -26,7 +28,19 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
""" """
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return self.parse_result([ChatGeneration(message=input)])
else:
return self.parse_result([Generation(text=input)])
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
"""Base class to parse the output of an LLM call. """Base class to parse the output of an LLM call.
Output parsers help structure language model responses. Output parsers help structure language model responses.
@ -53,6 +67,14 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
return "boolean_output_parser" return "boolean_output_parser"
""" # noqa: E501 """ # noqa: E501
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return self.parse_result([ChatGeneration(message=input)])
else:
return self.parse_result([Generation(text=input)])
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.

View File

@@ -12,9 +12,10 @@ from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
class BasePromptTemplate(Serializable, ABC):
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@@ -34,6 +35,11 @@ class BasePromptTemplate(Serializable, ABC):
arbitrary_types_allowed = True
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(**inner_input), input, config
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
if TYPE_CHECKING:
from langchain.callbacks.manager import (
@@ -17,7 +18,7 @@ if TYPE_CHECKING:
)
class BaseRetriever(Serializable, ABC):
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
@@ -43,9 +44,6 @@ class BaseRetriever(Serializable, ABC):
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501 """ # noqa: E501
class Config: class Config:
@ -106,6 +104,20 @@ class BaseRetriever(Serializable, ABC):
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
) )
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
return self.get_relevant_documents(input, **(config or {}))
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
return await self.aget_relevant_documents(input, **(config or {}))
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@@ -118,7 +130,6 @@ class BaseRetriever(Serializable, ABC):
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
@@ -129,6 +140,7 @@ class BaseRetriever(Serializable, ABC):
Returns:
List of relevant documents
"""
raise NotImplementedError()
def get_relevant_documents(
self,
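With `BaseRetriever` now a `Runnable[str, List[Document]]`, a retriever can be dropped straight into a pipeline: `invoke` forwards to `get_relevant_documents`, and `ainvoke` falls back to the threaded default whenever a subclass does not override the async path. A sketch with a hypothetical toy retriever:

from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class ToyRetriever(BaseRetriever):
    """Hypothetical retriever that always returns one canned document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"result for {query!r}")]


docs = ToyRetriever().invoke("what is a runnable?")
assert docs[0].page_content == "result for 'what is a runnable?'"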

View File

@@ -0,0 +1,705 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
TypedDict,
TypeVar,
Union,
cast,
)
from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
class RunnableConfig(TypedDict, total=False):
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC):
def __or__(
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
...
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
return await asyncio.get_running_loop().run_in_executor(
None, self.invoke, input, config
)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(self.invoke, inputs, configs))
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
return await _gather_with_concurrency(max_concurrency, *coros)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
yield self.invoke(input, config)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
yield await self.ainvoke(input, config)
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
config
if isinstance(config, list)
else [config.copy() if config is not None else {} for _ in range(length)]
)
def _call_with_config(
self,
func: Callable[[Input], Output],
input: Input,
config: Optional[RunnableConfig],
) -> Output:
from langchain.callbacks.manager import CallbackManager
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
try:
output = func(input)
except Exception as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
class RunnableSequence(Serializable, Runnable[Input, Output]):
first: Runnable[Input, Any]
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
last: Runnable[Any, Output]
@property
def steps(self) -> List[Runnable[Any, Any]]:
return [self.first] + self.middle + [self.last]
@property
def lc_serializable(self) -> bool:
return True
class Config:
arbitrary_types_allowed = True
def __or__(
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last] + other.middle,
last=other.last,
)
else:
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last],
last=_coerce_to_runnable(other),
)
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=other.first,
middle=other.middle + [other.last] + self.middle,
last=self.last,
)
else:
return RunnableSequence(
first=_coerce_to_runnable(other),
middle=[self.first] + self.middle,
last=self.last,
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke all steps in sequence
try:
for step in self.steps:
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
return cast(Output, input)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke all steps in sequence
try:
for step in self.steps:
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
return cast(Output, input)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
]
# invoke
try:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
for rm in run_managers:
rm.on_chain_error(e)
raise
else:
for rm, input in zip(run_managers, inputs):
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
return cast(List[Output], inputs)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
)
)
# invoke .batch() on each step
# this uses batching optimizations in Runnable subclasses, like LLM
try:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
raise
else:
await asyncio.gather(
*(
rm.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
for rm, input in zip(run_managers, inputs)
)
)
return cast(List[Output], inputs)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke the first steps
try:
for step in [self.first] + self.middle:
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
# stream the last step
final: Union[Output, None] = None
final_supported = True
try:
for output in self.last.stream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
pass
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke the first steps
try:
for step in [self.first] + self.middle:
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
# stream the last step
final: Union[Output, None] = None
final_supported = True
try:
async for output in self.last.astream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
pass
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
steps: Dict[str, Runnable[Input, Any]]
@property
def lc_serializable(self) -> bool:
return True
class Config:
arbitrary_types_allowed = True
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = self.steps.copy()
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
step.invoke,
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
for step in steps.values()
]
output = {key: future.result() for key, future in zip(steps, futures)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input}
)
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = self.steps.copy()
results = await asyncio.gather(
*(
step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
for step in steps.values()
)
)
output = {key: value for key, value in zip(steps, results)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
return output
class RunnableLambda(Runnable[Input, Output]):
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
self.func = func
else:
raise TypeError(
"Expected a callable type for `func`."
f"Instead got an unsupported type: {type(func)}"
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda):
return self.func == other.func
else:
return False
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
return self._call_with_config(self.func, input, config)
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
@property
def lc_serializable(self) -> bool:
return True
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
return config
def _coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
]
) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else:
raise TypeError(
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
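Taken together, `|` (and a plain dict, which `_coerce_to_runnable` turns into a `RunnableMap`) is all that is needed to compose the primitives above. A minimal sketch that runs without any model or API key, using only `RunnableLambda` and `RunnablePassthrough`; the same chain also gets `stream`, `ainvoke`, `abatch` and `astream` from the base-class defaults:

from langchain.schema.runnable import (
    RunnableLambda,
    RunnablePassthrough,
    RunnableSequence,
)

# Piping two Runnables produces a RunnableSequence.
chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
assert isinstance(chain, RunnableSequence)
assert chain.invoke(3) == 8
assert chain.batch([1, 2, 3]) == [4, 6, 8]

# A dict on either side of | is coerced into a RunnableMap, which runs its
# values in parallel and returns a dict of their results.
fanout = RunnablePassthrough() | {
    "same": RunnablePassthrough(),
    "double": RunnableLambda(lambda x: x * 2),
}
assert fanout.invoke(5) == {"same": 5, "double": 10}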

View File

@@ -177,3 +177,68 @@ def test_chat_openai_extra_kwargs() -> None:
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_abatch() -> None:
"""Test streaming tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_abatch_tags() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_openai_batch() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_ainvoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_openai_invoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)

View File

@@ -93,7 +93,64 @@ def test_openai_streaming() -> None:
assert isinstance(generator, Generator)
for token in generator:
assert isinstance(token["choices"][0]["text"], str)
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_abatch() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_abatch_tags() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
def test_openai_batch() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_ainvoke() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)
def test_openai_invoke() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)
def test_openai_multiple_prompts() -> None:
@@ -105,13 +162,6 @@ def test_openai_multiple_prompts() -> None:
assert len(output.generations) == 2
def test_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = OpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")
def test_openai_streaming_best_of_error() -> None:
"""Test validation for streaming fails if best_of is not 1."""
with pytest.raises(ValueError):

View File

@@ -67,10 +67,3 @@ def test_promptlayer_openai_streaming() -> None:
for token in generator:
assert isinstance(token["choices"][0]["text"], str)
def test_promptlayer_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = PromptLayerOpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,38 @@
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
def test_message_chunks() -> None:
assert AIMessageChunk(content="I am") + AIMessageChunk(
content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "MessageChunk + MessageChunk should be a MessageChunk"
assert AIMessageChunk(content="I am") + HumanMessageChunk(
content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
assert AIMessageChunk(
content="", additional_kwargs={"foo": "bar"}
) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk(
content="", additional_kwargs={"foo": "bar", "baz": "foo"}
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
assert AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "web_search"}}
) + AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
) + AIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}},
) == AIMessageChunk(
content="",
additional_kwargs={
"function_call": {
"name": "web_search",
"arguments": '{\n "query": "turtles"\n}',
}
},
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501

View File

@@ -0,0 +1,52 @@
from langchain.schema.messages import HumanMessageChunk
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
def test_generation_chunk() -> None:
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!"
) == GenerationChunk(
text="Hello, world!"
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!", generation_info={"foo": "bar"}
) == GenerationChunk(
text="Hello, world!", generation_info={"foo": "bar"}
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!", generation_info={"foo": "bar"}
) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk(
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
def test_chat_generation_chunk() -> None:
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!")
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!")
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!"),
generation_info={"foo": "bar"},
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
) + ChatGenerationChunk(
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!!"),
generation_info={"foo": "bar", "baz": "foo"},
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501

View File

@@ -0,0 +1,547 @@
from typing import Any, Dict, List, Optional
from uuid import UUID
import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.load.dump import dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
RunnableSequence,
)
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Run] = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(run)
class FakeRunnable(Runnable[str, int]):
def invoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
) -> int:
return len(input)
class FakeRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
async def _aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@pytest.fixture()
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
"""Note this mock only works with `import uuid; uuid.uuid4()`,
it does not work with `from uuid import uuid4; uuid4()`."""
# Disable tracing to avoid fixed UUIDs causing tracing errors.
mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
return mocker.patch("uuid.uuid4", side_effect=side_effect)
@pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[])),
]
spy.reset_mock()
assert [part async for part in fake.astream("hello")] == [5]
assert spy.call_args_list == [
mocker.call("hello", None),
]
spy.reset_mock()
assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
5,
7,
]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]
@pytest.mark.asyncio
async def test_prompt() -> None:
prompt = ChatPromptTemplate.from_messages(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessagePromptTemplate.from_template("{question}"),
]
)
expected = ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert prompt.invoke({"question": "What is your name?"}) == expected
assert prompt.batch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
) == [
expected,
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert [*prompt.stream({"question": "What is your name?"})] == [expected]
assert await prompt.ainvoke({"question": "What is your name?"}) == expected
assert await prompt.abatch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
) == [
expected,
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert [
part async for part in prompt.astream({"question": "What is your name?"})
] == [expected]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo", "bar"])
chain = prompt | chat
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
assert chain.last == chat
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == AIMessage(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test batch
prompt_spy = mocker.spy(prompt.__class__, "batch")
chat_spy = mocker.spy(chat.__class__, "batch")
tracer = FakeTracer()
assert chain.batch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
) == [
AIMessage(content="bar"),
AIMessage(content="foo"),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
assert chat_spy.call_args.args[1] == [
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "stream")
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [AIMessage(content="bar")]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])
chain = prompt | llm
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
assert chain.last == llm
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test batch
prompt_spy = mocker.spy(prompt.__class__, "abatch")
llm_spy = mocker.spy(llm.__class__, "abatch")
tracer = FakeTracer()
assert await chain.abatch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
) == ["bar", "foo"]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
assert llm_spy.call_args.args[1] == [
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "astream")
tracer = FakeTracer()
assert [
token
async for token in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == ["bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()
chain = prompt | chat | parser
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [chat]
assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == ["foo", "bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
retriever = FakeRetriever()
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ """Context:
{documents}
Question:
{question}"""
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()
chain = (
{
"question": RunnablePassthrough[str]() | passthrough,
"documents": passthrough | retriever,
"just_to_test_lambda": passthrough,
}
| prompt
| chat
| parser
)
assert isinstance(chain, RunnableSequence)
assert isinstance(chain.first, RunnableMap)
assert chain.middle == [prompt, chat]
assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
"foo",
"bar",
]
assert prompt_spy.call_args.args[1] == {
"documents": [Document(page_content="foo"), Document(page_content="bar")],
"question": "What is your name?",
"just_to_test_lambda": "What is your name?",
}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(
content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
Question:
What is your name?"""
),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_prompt_dict(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["i'm a chatbot"])
llm = FakeListLLM(responses=["i'm a textbot"])
chain = (
prompt
| passthrough
| { # type: ignore
"chat": chat,
"llm": llm,
}
)
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableMap)
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
llm_spy = mocker.spy(llm.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": AIMessage(content="i'm a chatbot"),
"llm": "i'm a textbot",
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot