Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-26 08:33:49 +00:00)
Runnable single protocol (#7800)
Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel, Chain, Retriever, OutputParser

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported operands
- [x] Implement a dict construct that calls its values in parallel and outputs a dict of the results
- [x] Merge in `+` for prompts
- [x] Confirm precedence order for operators; ideally `+` binds tighter than `|` (see https://docs.python.org/3/reference/expressions.html#operator-precedence)
- [x] Add support for OpenAI functions, i.e. chat models must return messages
- [x] Implement a BaseMessageChunk return type for BaseChatModel: a subclass of BaseMessage whose `__add__` returns a BaseMessageChunk, concatenating all str args
- [x] Update the stream/astream implementations for LLMs and chat models to use the new optional `_stream`/`_astream` methods, whose default implementations in the base class `raise NotImplementedError`; use https://stackoverflow.com/a/59762827 to detect whether a subclass has overridden them
- [x] Delete the IteratorCallbackHandler (leave the async one because people are using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str or BaseMessage

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent 04a4d3e312
commit a612800ef0
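Two ideas from the task list above are easy to illustrate without LangChain itself. The sketch below uses simplified stand-ins (none of these classes are the real `langchain.schema` types): a message chunk whose `__add__` concatenates string content, and the attribute-comparison trick from the linked Stack Overflow answer for detecting whether a subclass actually overrides the optional `_stream` hook.

```python
from typing import Iterator, Optional


class MessageChunk:
    """Toy stand-in for BaseMessageChunk: chunks concatenate via `+`."""

    def __init__(self, content: str) -> None:
        self.content = content

    def __add__(self, other: "MessageChunk") -> "MessageChunk":
        # Concatenate the string content of both chunks.
        return MessageChunk(self.content + other.content)


class ToyChatModel:
    def _stream(self, prompt: str) -> Iterator[MessageChunk]:
        # Optional hook: the base-class default signals "not implemented".
        raise NotImplementedError

    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"

    def stream(self, prompt: str) -> Iterator[MessageChunk]:
        # Compare the subclass attribute with the base attribute to see whether
        # `_stream` was overridden (the stackoverflow.com/a/59762827 trick).
        if type(self)._stream is ToyChatModel._stream:
            yield MessageChunk(self.invoke(prompt))  # fall back to one big chunk
        else:
            yield from self._stream(prompt)


class StreamingChatModel(ToyChatModel):
    def _stream(self, prompt: str) -> Iterator[MessageChunk]:
        for word in prompt.split():
            yield MessageChunk(word + " ")


full: Optional[MessageChunk] = None
for chunk in StreamingChatModel().stream("hello runnable world"):
    full = chunk if full is None else full + chunk
print(full.content)  # "hello runnable world " (trailing space from the toy tokenizer)
```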
@@ -9,10 +9,7 @@ from typing import (
    Tuple,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    ChatGeneration,
@@ -116,15 +113,6 @@ class ChatLlamaAPI(BaseChatModel):
            generations.append(gen)
        return ChatResult(generations=generations)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError

    @property
    def _client_params(self) -> Mapping[str, Any]:
        """Get the parameters used for the client."""
@@ -1,13 +1,14 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
if TYPE_CHECKING:
    from langchain.schema.agent import AgentAction, AgentFinish
    from langchain.schema.document import Document
    from langchain.schema.messages import BaseMessage
    from langchain.schema.output import LLMResult


class RetrieverManagerMixin:
@@ -543,3 +544,6 @@ class BaseCallbackManager(CallbackManagerMixin):
        for key in keys:
            self.metadata.pop(key)
            self.inheritable_metadata.pop(key)


Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
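The hunk above moves the schema imports behind a `TYPE_CHECKING` guard, so they stay visible to type checkers without being imported at runtime (which avoids import cycles). A generic, self-contained illustration of the pattern (the imported module here is hypothetical, not part of this diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; never executed at runtime.
    from some_heavy_module import Document  # hypothetical module


def first_line(doc: Document) -> str:
    # With `from __future__ import annotations`, the annotation above is just a
    # string at runtime, so the real class is never needed here.
    return doc.page_content.splitlines()[0]
```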
@@ -4,6 +4,7 @@ import asyncio
import functools
import logging
import os
import uuid
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import (
@@ -20,12 +21,13 @@ from typing import (
    Union,
    cast,
)
from uuid import UUID, uuid4
from uuid import UUID

import langchain
from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    Callbacks,
    ChainManagerMixin,
    LLMManagerMixin,
    RetrieverManagerMixin,
@@ -50,7 +52,6 @@ if TYPE_CHECKING:
    from langsmith import Client as LangSmithClient

logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
@@ -437,7 +438,7 @@ class BaseRunManager(RunManagerMixin):
            BaseRunManager: The noop manager.
        """
        return cls(
            run_id=uuid4(),
            run_id=uuid.uuid4(),
            handlers=[],
            inheritable_handlers=[],
            tags=[],
@@ -1024,7 +1025,7 @@ class CallbackManager(BaseCallbackManager):
        """
        managers = []
        for prompt in prompts:
            run_id_ = uuid4()
            run_id_ = uuid.uuid4()
            _handle_event(
                self.handlers,
                "on_llm_start",
@@ -1073,7 +1074,7 @@ class CallbackManager(BaseCallbackManager):

        managers = []
        for message_list in messages:
            run_id_ = uuid4()
            run_id_ = uuid.uuid4()
            _handle_event(
                self.handlers,
                "on_chat_model_start",
@@ -1120,7 +1121,7 @@ class CallbackManager(BaseCallbackManager):
            CallbackManagerForChainRun: The callback manager for the chain run.
        """
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        _handle_event(
            self.handlers,
@@ -1166,7 +1167,7 @@ class CallbackManager(BaseCallbackManager):
            CallbackManagerForToolRun: The callback manager for the tool run.
        """
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        _handle_event(
            self.handlers,
@@ -1202,7 +1203,7 @@ class CallbackManager(BaseCallbackManager):
    ) -> CallbackManagerForRetrieverRun:
        """Run when retriever starts running."""
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        _handle_event(
            self.handlers,
@@ -1302,7 +1303,7 @@ class AsyncCallbackManager(BaseCallbackManager):
        managers = []

        for prompt in prompts:
            run_id_ = uuid4()
            run_id_ = uuid.uuid4()

            tasks.append(
                _ahandle_event(
@@ -1341,7 +1342,7 @@ class AsyncCallbackManager(BaseCallbackManager):
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> Any:
    ) -> List[AsyncCallbackManagerForLLMRun]:
        """Run when LLM starts running.

        Args:
@@ -1358,7 +1359,7 @@ class AsyncCallbackManager(BaseCallbackManager):
        managers = []

        for message_list in messages:
            run_id_ = uuid4()
            run_id_ = uuid.uuid4()

            tasks.append(
                _ahandle_event(
@@ -1410,7 +1411,7 @@ class AsyncCallbackManager(BaseCallbackManager):
                for the chain run.
        """
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        await _ahandle_event(
            self.handlers,
@@ -1458,7 +1459,7 @@ class AsyncCallbackManager(BaseCallbackManager):
                for the tool run.
        """
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        await _ahandle_event(
            self.handlers,
@@ -1494,7 +1495,7 @@ class AsyncCallbackManager(BaseCallbackManager):
    ) -> AsyncCallbackManagerForRetrieverRun:
        """Run when retriever starts running."""
        if run_id is None:
            run_id = uuid4()
            run_id = uuid.uuid4()

        await _ahandle_event(
            self.handlers,
@@ -4,7 +4,7 @@ import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult
from langchain.schema.output import LLMResult

# TODO If used by two LLM runs in parallel this won't work as expected
@@ -22,6 +22,7 @@ from langchain.callbacks.manager import (
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig

logger = logging.getLogger(__name__)

@@ -30,7 +31,7 @@ def _get_verbosity() -> bool:
    return langchain.verbose


class Chain(Serializable, ABC):
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
    """Abstract base class for creating structured sequences of calls to components.

    Chains should be used to encode a sequence of calls to components like
@@ -53,6 +54,20 @@ class Chain(Serializable, ABC):
    chains and cannot return as rich of an output as `__call__`.
    """

    def invoke(
        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        return self(input, **(config or {}))

    async def ainvoke(
        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        if type(self)._acall == Chain._acall:
            # If the chain does not implement async, fall back to default implementation
            return await super().ainvoke(input, config)

        return await self.acall(input, **(config or {}))

    memory: Optional[BaseMemory] = None
    """Optional memory object. Defaults to None.
    Memory is a class that gets called at the start
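With this hunk, `Chain` itself becomes a `Runnable[Dict[str, Any], Dict[str, Any]]`: `invoke` forwards to `__call__` with any config entries spread as keyword arguments, and `ainvoke` falls back to the inherited default when a chain never overrode `_acall`. A self-contained sketch of that delegation (the `ToyChain` class is a stand-in, not a LangChain class):

```python
import asyncio
from typing import Any, Dict, Optional


class ToyChain:
    """Illustrative stand-in for Chain: dict in, dict out."""

    def __call__(self, inputs: Dict[str, Any], callbacks: Any = None) -> Dict[str, Any]:
        # Stand-in for Chain.__call__ running the chain's logic.
        return {"question": inputs["question"], "text": inputs["question"].upper()}

    def invoke(
        self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Mirrors the hunk above: invoke() forwards to __call__, spreading any
        # config entries (callbacks, tags, ...) as keyword arguments.
        return self(input, **(config or {}))

    async def ainvoke(
        self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Chains without a real async path run the sync version in an executor,
        # roughly what the inherited default implementation does.
        return await asyncio.get_running_loop().run_in_executor(
            None, lambda: self.invoke(input, config)
        )


print(ToyChain().invoke({"question": "what is a runnable?"}))
print(asyncio.run(ToyChain().ainvoke({"question": "and async?"})))
```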
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
@@ -12,11 +12,13 @@ from langchain.schema import (
)
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk


class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -94,6 +96,44 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
            text.rstrip()
        )  # trim off the trailing ' ' that might come from the "Assistant: "

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
        if stop:
            params["stop_sequences"] = stop

        stream_resp = self.client.completions.create(**params, stream=True)
        for data in stream_resp:
            delta = data.completion
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
        if stop:
            params["stop_sequences"] = stop

        stream_resp = await self.async_client.completions.create(**params, stream=True)
        async for data in stream_resp:
            delta = data.completion
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
@@ -101,22 +141,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
        if stop:
            params["stop_sequences"] = stop

        if self.streaming:
            completion = ""
            stream_resp = self.client.completions.create(**params, stream=True)
            for data in stream_resp:
                delta = data.completion
                completion += delta
                if run_manager:
                    run_manager.on_llm_new_token(
                        delta,
                    )
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            prompt = self._convert_messages_to_prompt(messages)
            params: Dict[str, Any] = {
                "prompt": prompt,
                **self._default_params,
                **kwargs,
            }
            if stop:
                params["stop_sequences"] = stop
            response = self.client.completions.create(**params)
            completion = response.completion
        message = AIMessage(content=completion)
@@ -129,24 +166,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
        if stop:
            params["stop_sequences"] = stop

        if self.streaming:
            completion = ""
            stream_resp = await self.async_client.completions.create(
                **params, stream=True
            )
            async for data in stream_resp:
                delta = data.completion
                completion += delta
                if run_manager:
                    await run_manager.on_llm_new_token(
                        delta,
                    )
            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            prompt = self._convert_messages_to_prompt(messages)
            params: Dict[str, Any] = {
                "prompt": prompt,
                **self._default_params,
                **kwargs,
            }
            if stop:
                params["stop_sequences"] = stop
            response = await self.async_client.completions.create(**params)
            completion = response.completion
        message = AIMessage(content=completion)
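Note how the streaming branches of `_generate`/`_agenerate` above shrink to a single loop over the new `_stream`/`_astream` generators, concatenating `chunk.text` instead of duplicating the SDK call. A self-contained sketch of that consumption pattern (the `Chunk` dataclass and `fake_stream` are stand-ins for `ChatGenerationChunk` and `_stream`, not the real classes):

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Chunk:
    text: str  # stand-in for ChatGenerationChunk.text


def fake_stream() -> Iterator[Chunk]:
    # Stand-in for ChatAnthropic._stream(): one chunk per streamed delta.
    for delta in ["Why ", "did ", "the ", "chicken ", "cross..."]:
        yield Chunk(text=delta)


# The entire streaming branch of _generate() after this refactor:
completion = ""
for chunk in fake_stream():
    completion += chunk.text
print(completion)
```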
@@ -3,7 +3,16 @@ import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Sequence
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    cast,
)

from pydantic import Field, root_validator

@@ -17,6 +26,8 @@ from langchain.callbacks.manager import (
    Callbacks,
)
from langchain.load.dump import dumpd, dumps
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
    ChatGeneration,
    ChatResult,
@@ -24,17 +35,22 @@ from langchain.schema import (
    PromptValue,
    RunInfo,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig


def _get_verbosity() -> bool:
    return langchain.verbose


class BaseChatModel(BaseLanguageModel, ABC):
    """Base class for chat models."""

class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
    cache: Optional[bool] = None
    """Whether to cache the response."""
    verbose: bool = Field(default_factory=_get_verbosity)
@@ -64,6 +80,154 @@ class BaseChatModel(BaseLanguageModel, ABC):

        arbitrary_types_allowed = True

    # --- Runnable methods ---

    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
        if isinstance(input, PromptValue):
            return input
        elif isinstance(input, str):
            return StringPromptValue(text=input)
        elif isinstance(input, list):
            return ChatPromptValue(messages=input)
        else:
            raise ValueError(
                f"Invalid input type {type(input)}. "
                "Must be a PromptValue, str, or list of BaseMessages."
            )

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> BaseMessageChunk:
        return cast(
            BaseMessageChunk,
            cast(
                ChatGeneration,
                self.generate_prompt(
                    [self._convert_input(input)], stop=stop, **(config or {})
                ).generations[0][0],
            ).message,
        )

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> BaseMessageChunk:
        if type(self)._agenerate == BaseChatModel._agenerate:
            # model doesn't implement async generation, so use default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, stop=stop)
            )

        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)], stop=stop, **(config or {})
        )
        return cast(
            BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
        )

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        if type(self)._stream == BaseChatModel._stream:
            # model doesn't implement streaming, so use default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            config = config or {}
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_chat_model_start(
                dumpd(self), [messages], invocation_params=params, options=options
            )
            try:
                message: Optional[BaseMessageChunk] = None
                for chunk in self._stream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.message
                    if message is None:
                        message = chunk.message
                    else:
                        message += chunk.message
                assert message is not None
            except (KeyboardInterrupt, Exception) as e:
                run_manager.on_llm_error(e)
                raise e
            else:
                run_manager.on_llm_end(
                    LLMResult(generations=[[ChatGeneration(message=message)]]),
                )

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[BaseMessageChunk]:
        if type(self)._astream == BaseChatModel._astream:
            # model doesn't implement streaming, so use default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            config = config or {}
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = AsyncCallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = await callback_manager.on_chat_model_start(
                dumpd(self), [messages], invocation_params=params, options=options
            )
            try:
                message: Optional[BaseMessageChunk] = None
                async for chunk in self._astream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.message
                    if message is None:
                        message = chunk.message
                    else:
                        message += chunk.message
                assert message is not None
            except (KeyboardInterrupt, Exception) as e:
                await run_manager.on_llm_error(e)
                raise e
            else:
                await run_manager.on_llm_end(
                    LLMResult(generations=[[ChatGeneration(message=message)]]),
                )

    # --- Custom methods ---

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return {}

@@ -334,7 +498,6 @@ class BaseChatModel(BaseLanguageModel, ABC):
    ) -> ChatResult:
        """Top Level call"""

    @abstractmethod
    async def _agenerate(
        self,
        messages: List[BaseMessage],
@@ -343,6 +506,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        raise NotImplementedError()

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError()

    def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        raise NotImplementedError()

    def __call__(
        self,
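`BaseChatModel` now normalizes every Runnable input through `_convert_input`: a `PromptValue` passes through, a plain string becomes a `StringPromptValue`, and a message list becomes a `ChatPromptValue`; `invoke`, `stream`, and friends all build on that. A self-contained restatement of just that normalization step (the two dataclasses are simplified stand-ins for the real prompt-value types):

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class StringPromptValue:
    """Simplified stand-in for the real StringPromptValue."""

    text: str


@dataclass
class ChatPromptValue:
    """Simplified stand-in for the real ChatPromptValue."""

    messages: List[str]


LanguageModelInput = Union[str, List[str], StringPromptValue, ChatPromptValue]


def convert_input(value: LanguageModelInput) -> Union[StringPromptValue, ChatPromptValue]:
    # Mirrors BaseChatModel._convert_input in the hunk above.
    if isinstance(value, (StringPromptValue, ChatPromptValue)):
        return value
    if isinstance(value, str):
        return StringPromptValue(text=value)
    if isinstance(value, list):
        return ChatPromptValue(messages=value)
    raise ValueError(f"Invalid input type {type(value)}.")


print(convert_input("Tell me a joke"))
print(convert_input(["system: be terse", "human: hi"]))
```

In practice this is what lets callers pass a plain string, a list of messages, or a prompt value interchangeably to any of the chat models touched in this diff.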
@@ -25,7 +25,10 @@ class FakeListChatModel(SimpleChatModel):
    ) -> str:
        """First try to lookup in queries, else return 'foo' or 'bar'."""
        response = self.responses[self.i]
        self.i += 1
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    @property
@@ -4,8 +4,10 @@ from __future__ import annotations
import logging
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
@@ -36,6 +38,14 @@ from langchain.schema import (
    HumanMessage,
    SystemMessage,
)
from langchain.schema.messages import (
    AIMessageChunk,
    BaseMessageChunk,
    ChatMessageChunk,
    HumanMessageChunk,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

logger = logging.getLogger(__name__)
@@ -75,6 +85,24 @@ async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
    return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
@@ -258,6 +286,25 @@ class JinaChat(BaseChatModel):
                overall_token_usage[k] = v
        return {"token_usage": overall_token_usage}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(messages=message_dicts, **params):
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content)

    def _generate(
        self,
        messages: List[BaseMessage],
@@ -265,27 +312,20 @@ class JinaChat(BaseChatModel):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                role = stream_resp["choices"][0]["delta"].get("role", role)
                token = stream_resp["choices"][0]["delta"].get("content") or ""
                inner_completion += token
                if run_manager:
                    run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

@@ -309,6 +349,27 @@ class JinaChat(BaseChatModel):
        llm_output = {"token_usage": response["usage"]}
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, **params
        ):
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
@@ -316,32 +377,22 @@ class JinaChat(BaseChatModel):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            async for chunk in self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            async for stream_resp in await acompletion_with_retry(
                self, messages=message_dicts, **params
            ):
                role = stream_resp["choices"][0]["delta"].get("role", role)
                token = stream_resp["choices"][0]["delta"].get("content", "")
                inner_completion += token or ""
                if run_manager:
                    await run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        else:
            response = await acompletion_with_retry(
                self, messages=message_dicts, **params
            )
            return self._create_chat_result(response)
        response = await acompletion_with_retry(self, messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
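`_convert_delta_to_message_chunk` above relies on a detail of OpenAI-style streaming: only the first delta carries a `role`, so the caller threads the previous chunk's class through as `default_class` for later deltas. A self-contained sketch of that hand-off (the chunk classes are stand-ins and the delta payloads are made up):

```python
from typing import Any, Iterator, Mapping, Type


class BaseChunk:
    def __init__(self, content: str) -> None:
        self.content = content


class AIChunk(BaseChunk):  # stand-ins for AIMessageChunk etc.
    pass


class HumanChunk(BaseChunk):
    pass


def convert_delta(delta: Mapping[str, Any], default_class: Type[BaseChunk]) -> BaseChunk:
    role = delta.get("role")
    content = delta.get("content") or ""
    if role == "assistant" or default_class is AIChunk:
        return AIChunk(content)
    if role == "user" or default_class is HumanChunk:
        return HumanChunk(content)
    return default_class(content)


def stream_deltas(deltas: Iterator[Mapping[str, Any]]) -> Iterator[BaseChunk]:
    # Only the first streamed delta includes "role"; later deltas rely on
    # default_class being updated to whatever the previous chunk was.
    default_class: Type[BaseChunk] = AIChunk
    for delta in deltas:
        chunk = convert_delta(delta, default_class)
        default_class = type(chunk)
        yield chunk


deltas = [{"role": "assistant", "content": "Hel"}, {"content": "lo"}, {"content": "!"}]
print("".join(c.content for c in stream_deltas(iter(deltas))))  # Hello!
```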
@@ -6,8 +6,10 @@ import sys
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
@@ -35,12 +37,19 @@ from langchain.schema import (
)
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

if TYPE_CHECKING:
@@ -95,6 +104,30 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
    return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""
    if _dict.get("function_call"):
        additional_kwargs = {"function_call": dict(_dict["function_call"])}
    else:
        additional_kwargs = {}

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
@@ -313,6 +346,27 @@ class ChatOpenAI(BaseChatModel):
                overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model_name}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(messages=message_dicts, **params):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content)

    def _generate(
        self,
        messages: List[BaseMessage],
@@ -320,40 +374,20 @@ class ChatOpenAI(BaseChatModel):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                if len(stream_resp["choices"]) > 0:
                    role = stream_resp["choices"][0]["delta"].get("role", role)
                    token = stream_resp["choices"][0]["delta"].get("content") or ""
                    inner_completion += token
                    _function_call = stream_resp["choices"][0]["delta"].get(
                        "function_call"
                    )
                    if _function_call:
                        if function_call is None:
                            function_call = _function_call
                        elif "arguments" in function_call:
                            function_call["arguments"] += _function_call["arguments"]
                        else:
                            function_call["arguments"] = _function_call["arguments"]
                    if run_manager:
                        run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

@@ -381,6 +415,29 @@ class ChatOpenAI(BaseChatModel):
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
@@ -388,45 +445,22 @@ class ChatOpenAI(BaseChatModel):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            async for chunk in self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            async for stream_resp in await acompletion_with_retry(
                self, messages=message_dicts, **params
            ):
                if len(stream_resp["choices"]) > 0:
                    role = stream_resp["choices"][0]["delta"].get("role", role)
                    token = stream_resp["choices"][0]["delta"].get("content", "")
                    inner_completion += token or ""
                    _function_call = stream_resp["choices"][0]["delta"].get(
                        "function_call"
                    )
                    if _function_call:
                        if function_call is None:
                            function_call = _function_call
                        elif "arguments" in function_call:
                            function_call["arguments"] += _function_call["arguments"]
                        else:
                            function_call["arguments"] = _function_call["arguments"]
                    if run_manager:
                        await run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        else:
            response = await acompletion_with_retry(
                self, messages=message_dicts, **params
            )
            return self._create_chat_result(response)
        response = await acompletion_with_retry(self, messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
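The removed inline streaming branch above accumulated OpenAI function-call deltas by string-concatenating their `arguments`; after this change the equivalent merging happens as the `AIMessageChunk`s produced by `_convert_delta_to_message_chunk` are added together. A self-contained sketch of what that accumulation looks like on raw deltas (the delta payloads are made up, shaped like OpenAI's streaming function_call format):

```python
import json
from typing import Dict, List, Optional

# Streamed deltas for a function call: the JSON arguments arrive in pieces.
deltas: List[Dict[str, Dict[str, str]]] = [
    {"function_call": {"name": "get_weather", "arguments": ""}},
    {"function_call": {"arguments": '{"city": "Par'}},
    {"function_call": {"arguments": 'is"}'}},
]

merged: Optional[Dict[str, str]] = None
for delta in deltas:
    call = delta.get("function_call")
    if not call:
        continue
    if merged is None:
        merged = dict(call)
    else:
        # Later deltas only extend the argument string.
        merged["arguments"] = merged.get("arguments", "") + call.get("arguments", "")

print(json.loads(merged["arguments"]))  # {'city': 'Paris'}
```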
@@ -4,10 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.schema import (
@@ -162,14 +159,3 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
        response = chat.send_message(question.content)
        text = self._enforce_stop_words(response.text, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError(
            """Vertex AI doesn't support async requests at the moment."""
        )
@@ -6,7 +6,8 @@ from pydantic import BaseModel
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions import create_tagging_chain
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
from langchain.schema import BaseDocumentTransformer, Document
from langchain.schema.language_model import BaseLanguageModel


class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
@@ -1,18 +1,20 @@
import re
import warnings
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional

from pydantic import BaseModel, root_validator
from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env


class _AnthropicCommon(BaseModel):
class _AnthropicCommon(BaseLanguageModel):
    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    model: str = "claude-2"
@@ -193,24 +195,16 @@ class Anthropic(LLM, _AnthropicCommon):
                response = model(prompt)

        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion

        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}
        if self.streaming:
            stream_resp = self.client.completions.create(
                prompt=self._wrap_prompt(prompt),
                stop_sequences=stop,
                stream=True,
                **params,
            )
            current_completion = ""
            for data in stream_resp:
                delta = data.completion
                current_completion += delta
                if run_manager:
                    run_manager.on_llm_new_token(
                        delta,
                    )
            return current_completion
        response = self.client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
@@ -226,22 +220,17 @@ class Anthropic(LLM, _AnthropicCommon):
        **kwargs: Any,
    ) -> str:
        """Call out to Anthropic's completion endpoint asynchronously."""
        if self.streaming:
            completion = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion

        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}
        if self.streaming:
            stream_resp = await self.async_client.completions.create(
                prompt=self._wrap_prompt(prompt),
                stop_sequences=stop,
                stream=True,
                **params,
            )
            current_completion = ""
            async for data in stream_resp:
                delta = data.completion
                current_completion += delta
                if run_manager:
                    await run_manager.on_llm_new_token(delta)
            return current_completion

        response = await self.async_client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
@@ -249,23 +238,23 @@ class Anthropic(LLM, _AnthropicCommon):
        )
        return response.completion

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        r"""Call Anthropic completion_stream and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from Anthropic.

        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = anthropic.stream(prompt)
@@ -273,12 +262,49 @@ class Anthropic(LLM, _AnthropicCommon):
                    yield token
        """
        stop = self._get_anthropic_stop(stop)
        return self.client.completions.create(
        params = {**self._default_params, **kwargs}

        for token in self.client.completions.create(
            prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
        ):
            yield GenerationChunk(text=token.completion)
            if run_manager:
                run_manager.on_llm_new_token(token.completion)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        r"""Call Anthropic completion_stream and return the resulting generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            A generator representing the stream of tokens from Anthropic.
        Example:
            .. code-block:: python
                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = anthropic.stream(prompt)
                for token in generator:
                    yield token
        """
        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}

        async for token in await self.async_client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            stream=True,
            **self._default_params,
        )
            **params,
        ):
            yield GenerationChunk(text=token.completion)
            if run_manager:
                await run_manager.on_llm_new_token(token.completion)

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
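The old public `stream()` above returned the raw SDK response object; the new `_stream()` wraps every streamed token in a `GenerationChunk` and reports it to the run manager as a side effect, which is what `BaseLLM.stream()` consumes. A self-contained sketch of that shape (the chunk type and callback are stand-ins, not the real LangChain objects):

```python
from dataclasses import dataclass
from typing import Callable, Iterator, List, Optional


@dataclass
class GenChunk:  # stand-in for GenerationChunk
    text: str


def _stream(
    tokens: List[str],
    on_new_token: Optional[Callable[[str], None]] = None,
) -> Iterator[GenChunk]:
    """Shape of the new _stream(): wrap each provider token in a chunk and
    notify the callback manager as a side effect."""
    for token in tokens:
        yield GenChunk(text=token)
        if on_new_token:
            on_new_token(token)


seen: List[str] = []
completion = "".join(
    chunk.text for chunk in _stream(["It", " was", " a", " dark..."], seen.append)
)
print(completion, seen)
```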
@@ -7,11 +7,14 @@ import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
@@ -19,6 +22,7 @@ from typing import (
    Tuple,
    Type,
    Union,
    cast,
)

import yaml
@@ -42,14 +46,18 @@ from langchain.callbacks.manager import (
    Callbacks,
)
from langchain.load.dump import dumpd
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
    Generation,
    LLMResult,
    PromptValue,
    RunInfo,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig

logger = logging.getLogger(__name__)

@@ -115,7 +123,7 @@ def update_cache(
    return llm_output


class BaseLLM(BaseLanguageModel, ABC):
class BaseLLM(BaseLanguageModel[str], ABC):
    """Base LLM abstract interface.

    It should take in a prompt and return a string."""
@@ -157,6 +165,204 @@ class BaseLLM(BaseLanguageModel, ABC):
        else:
            return verbose

    # --- Runnable methods ---

    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
        if isinstance(input, PromptValue):
            return input
        elif isinstance(input, str):
            return StringPromptValue(text=input)
        elif isinstance(input, list):
            return ChatPromptValue(messages=input)
        else:
            raise ValueError(
                f"Invalid input type {type(input)}. "
                "Must be a PromptValue, str, or list of BaseMessages."
            )

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> str:
        return (
            self.generate_prompt(
                [self._convert_input(input)], stop=stop, **(config or {})
            )
            .generations[0][0]
            .text
        )

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> str:
        if type(self)._agenerate == BaseLLM._agenerate:
            # model doesn't implement async invoke, so use default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, stop=stop)
            )

        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)], stop=stop, **(config or {})
        )
        return llm_result.generations[0][0].text

    def batch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        max_concurrency: Optional[int] = None,
    ) -> List[str]:
        config = self._get_config_list(config, len(inputs))

        if max_concurrency is None:
            llm_result = self.generate_prompt(
                [self._convert_input(input) for input in inputs],
                callbacks=[c.get("callbacks") for c in config],
                tags=[c.get("tags") for c in config],
                metadata=[c.get("metadata") for c in config],
            )
            return [g[0].text for g in llm_result.generations]
        else:
            batches = [
                inputs[i : i + max_concurrency]
                for i in range(0, len(inputs), max_concurrency)
            ]
            return [
                output
                for batch in batches
                for output in self.batch(batch, config=config)
            ]

    async def abatch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        max_concurrency: Optional[int] = None,
    ) -> List[str]:
        if type(self)._agenerate == BaseLLM._agenerate:
            # model doesn't implement async batch, so use default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, self.batch, inputs, config, max_concurrency
            )

        config = self._get_config_list(config, len(inputs))

        if max_concurrency is None:
            llm_result = await self.agenerate_prompt(
                [self._convert_input(input) for input in inputs],
                callbacks=[c.get("callbacks") for c in config],
                tags=[c.get("tags") for c in config],
                metadata=[c.get("metadata") for c in config],
            )
            return [g[0].text for g in llm_result.generations]
        else:
            batches = [
                inputs[i : i + max_concurrency]
                for i in range(0, len(inputs), max_concurrency)
            ]
            return [
                output
                for batch in batches
                for output in await self.abatch(batch, config=config)
            ]

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        if type(self)._stream == BaseLLM._stream:
            # model doesn't implement streaming, so use default implementation
            yield self.invoke(input, config=config, stop=stop)
        else:
            prompt = self._convert_input(input).to_string()
            config = config or {}
            params = self.dict()
            params["stop"] = stop
            options = {"stop": stop}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_llm_start(
                dumpd(self), [prompt], invocation_params=params, options=options
            )
            try:
                generation: Optional[GenerationChunk] = None
                for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
                    yield chunk.text
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except (KeyboardInterrupt, Exception) as e:
                run_manager.on_llm_error(e)
                raise e
            else:
                run_manager.on_llm_end(LLMResult(generations=[[generation]]))

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
    ) -> AsyncIterator[str]:
        if type(self)._astream == BaseLLM._astream:
            # model doesn't implement streaming, so use default implementation
            yield await self.ainvoke(input, config=config, stop=stop)
        else:
            prompt = self._convert_input(input).to_string()
            config = config or {}
            params = self.dict()
            params["stop"] = stop
            options = {"stop": stop}
            callback_manager = AsyncCallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = await callback_manager.on_llm_start(
                dumpd(self), [prompt], invocation_params=params, options=options
            )
            try:
                generation: Optional[GenerationChunk] = None
                async for chunk in self._astream(
                    prompt, stop=stop, run_manager=run_manager
                ):
                    yield chunk.text
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except (KeyboardInterrupt, Exception) as e:
                await run_manager.on_llm_error(e)
                raise e
            else:
                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))

    # --- Custom methods ---

    @abstractmethod
    def _generate(
        self,
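`BaseLLM.batch` above either issues one bulk `generate_prompt` call or, when `max_concurrency` is set, slices the inputs into `max_concurrency`-sized groups and recurses on each slice. The core of that splitting logic, restated as a self-contained function (the `run_all` callable stands in for the bulk generate call):

```python
from typing import Callable, List, Optional


def batch(
    inputs: List[str],
    run_all: Callable[[List[str]], List[str]],
    max_concurrency: Optional[int] = None,
) -> List[str]:
    """No limit -> one bulk call; otherwise split into max_concurrency-sized
    slices and process each slice as its own bulk call."""
    if max_concurrency is None:
        return run_all(inputs)
    batches = [
        inputs[i : i + max_concurrency] for i in range(0, len(inputs), max_concurrency)
    ]
    return [output for b in batches for output in batch(b, run_all)]


print(batch(["a", "b", "c", "d", "e"], lambda xs: [x.upper() for x in xs], max_concurrency=2))
# ['A', 'B', 'C', 'D', 'E']
```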
@@ -167,7 +373,6 @@ class BaseLLM(BaseLanguageModel, ABC):
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    @abstractmethod
    async def _agenerate(
        self,
        prompts: List[str],
@@ -176,12 +381,31 @@ class BaseLLM(BaseLanguageModel, ABC):
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""
        raise NotImplementedError()

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        raise NotImplementedError()

    def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        raise NotImplementedError()

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
@@ -191,7 +415,7 @@ class BaseLLM(BaseLanguageModel, ABC):
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
@@ -236,10 +460,10 @@ class BaseLLM(BaseLanguageModel, ABC):
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        tags: Optional[Union[List[str], List[List[str]]]] = None,
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
@@ -248,6 +472,50 @@ class BaseLLM(BaseLanguageModel, ABC):
                "Argument 'prompts' is expected to be of type List[str], received"
                f" argument of type {type(prompts)}."
            )
        # Create callback managers
        if isinstance(callbacks, list) and (
            isinstance(callbacks[0], (list, BaseCallbackManager))
            or callbacks[0] is None
        ):
            # We've received a list of callbacks args to apply to each input
            assert len(callbacks) == len(prompts)
            assert tags is None or (
                isinstance(tags, list) and len(tags) == len(prompts)
            )
            assert metadata is None or (
                isinstance(metadata, list) and len(metadata) == len(prompts)
            )
            callbacks = cast(List[Callbacks], callbacks)
            tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
            metadata_list = cast(
                List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
            )
            callback_managers = [
                CallbackManager.configure(
                    callback,
                    self.callbacks,
                    self.verbose,
                    tag,
                    self.tags,
                    meta,
                    self.metadata,
                )
                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
            ]
        else:
            # We've received a single callbacks arg to apply to all inputs
            callback_managers = [
                CallbackManager.configure(
                    cast(Callbacks, callbacks),
                    self.callbacks,
                    self.verbose,
                    cast(List[str], tags),
                    self.tags,
                    cast(Dict[str, Any], metadata),
                    self.metadata,
                )
            ] * len(prompts)

        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}
@@ -258,15 +526,6 @@ class BaseLLM(BaseLanguageModel, ABC):
            missing_prompts,
        ) = get_prompts(params, prompts)
        disregard_cache = self.cache is not None and not self.cache
        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
@@ -275,17 +534,26 @@ class BaseLLM(BaseLanguageModel, ABC):
            raise ValueError(
                "Asked to cache, but no cache found at `langchain.cache`."
            )
        run_managers = callback_manager.on_llm_start(
            dumpd(self), prompts, invocation_params=params, options=options
        )
        run_managers = [
            callback_manager.on_llm_start(
                dumpd(self), [prompt], invocation_params=params, options=options
            )[0]
            for callback_manager, prompt in zip(callback_managers, prompts)
        ]
        output = self._generate_helper(
            prompts, stop, run_managers, bool(new_arg_supported), **kwargs
        )
        return output
        if len(missing_prompts) > 0:
            run_managers = callback_manager.on_llm_start(
                dumpd(self), missing_prompts, invocation_params=params, options=options
            )
            run_managers = [
                callback_managers[idx].on_llm_start(
                    dumpd(self),
                    [prompts[idx]],
                    invocation_params=params,
                    options=options,
                )[0]
                for idx in missing_prompt_idxs
            ]
            new_results = self._generate_helper(
                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
@@ -346,13 +614,57 @@ class BaseLLM(BaseLanguageModel, ABC):
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# Create callback managers
|
||||
if isinstance(callbacks, list) and (
|
||||
isinstance(callbacks[0], (list, BaseCallbackManager))
|
||||
or callbacks[0] is None
|
||||
):
|
||||
# We've received a list of callbacks args to apply to each input
|
||||
assert len(callbacks) == len(prompts)
|
||||
assert tags is None or (
|
||||
isinstance(tags, list) and len(tags) == len(prompts)
|
||||
)
|
||||
assert metadata is None or (
|
||||
isinstance(metadata, list) and len(metadata) == len(prompts)
|
||||
)
|
||||
callbacks = cast(List[Callbacks], callbacks)
|
||||
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
|
||||
metadata_list = cast(
|
||||
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
|
||||
)
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
callback,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tag,
|
||||
self.tags,
|
||||
meta,
|
||||
self.metadata,
|
||||
)
|
||||
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
|
||||
]
|
||||
else:
|
||||
# We've received a single callbacks arg to apply to all inputs
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
cast(Callbacks, callbacks),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
cast(List[str], tags),
|
||||
self.tags,
|
||||
cast(Dict[str, Any], metadata),
|
||||
self.metadata,
|
||||
)
|
||||
] * len(prompts)
|
||||
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
@ -363,15 +675,6 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
missing_prompts,
|
||||
) = get_prompts(params, prompts)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
@ -380,17 +683,32 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_managers = await callback_manager.on_llm_start(
|
||||
dumpd(self), prompts, invocation_params=params, options=options
|
||||
run_managers = await asyncio.gather(
|
||||
*[
|
||||
callback_manager.on_llm_start(
|
||||
dumpd(self), [prompt], invocation_params=params, options=options
|
||||
)
|
||||
for callback_manager, prompt in zip(callback_managers, prompts)
|
||||
]
|
||||
)
|
||||
run_managers = [r[0] for r in run_managers]
|
||||
output = await self._agenerate_helper(
|
||||
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
return output
|
||||
if len(missing_prompts) > 0:
|
||||
run_managers = await callback_manager.on_llm_start(
|
||||
dumpd(self), missing_prompts, invocation_params=params, options=options
|
||||
run_managers = await asyncio.gather(
|
||||
*[
|
||||
callback_managers[idx].on_llm_start(
|
||||
dumpd(self),
|
||||
[prompts[idx]],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
)
|
||||
for idx in missing_prompt_idxs
|
||||
]
|
||||
)
|
||||
run_managers = [r[0] for r in run_managers]
|
||||
new_results = await self._agenerate_helper(
|
||||
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
|
||||
)
|
||||
@@ -586,7 +904,7 @@ class LLM(BaseLLM):
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given prompt and input."""
-        raise NotImplementedError("Async generation not implemented for this LLM.")
+        raise NotImplementedError()

    def _generate(
        self,
@@ -615,6 +933,12 @@ class LLM(BaseLLM):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
+        if type(self)._acall == LLM._acall:
+            # model doesn't implement async call, so use default implementation
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self._generate, prompts, stop, run_manager, **kwargs)
+            )
+
        """Run the LLM on the given prompt and input."""
        generations = []
        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")

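A small usage sketch, not part of the diff, reusing the hypothetical EchoStreamLLM from the earlier aside: because it never overrides `_acall`, the async path above transparently runs the synchronous `_generate` in the default executor.

```python
import asyncio


async def demo() -> None:
    llm = EchoStreamLLM()
    # No _acall override, so agenerate() defers to the sync _generate via
    # run_in_executor, as implemented above.
    result = await llm.agenerate(["hello async world"])
    print(result.generations[0][0].text)


asyncio.run(demo())
```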
@@ -27,7 +27,10 @@ class FakeListLLM(LLM):
    ) -> str:
        """Return next response"""
        response = self.responses[self.i]
-        self.i += 1
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
        return response

    async def _acall(
@@ -39,7 +42,10 @@ class FakeListLLM(LLM):
    ) -> str:
        """Return next response"""
        response = self.responses[self.i]
-        self.i += 1
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
        return response

    @property

@ -12,10 +12,7 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@ -161,15 +158,6 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
|
@ -1,5 +1,4 @@
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
@ -8,6 +7,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema.output import GenerationChunk
|
||||
|
||||
|
||||
class HuggingFaceTextGenInference(LLM):
|
||||
@ -69,7 +69,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
temperature = 0.01,
|
||||
repetition_penalty = 1.03,
|
||||
callbacks = callbacks,
|
||||
stream = True
|
||||
streaming = True
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
@ -87,7 +87,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
inference_server_url: str = ""
|
||||
timeout: int = 120
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
stream: bool = False
|
||||
streaming: bool = False
|
||||
client: Any
|
||||
async_client: Any
|
||||
|
||||
@ -154,37 +154,21 @@ class HuggingFaceTextGenInference(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if not self.stream:
|
||||
res = self.client.generate(prompt, **invocation_params)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
text = res.generated_text
|
||||
else:
|
||||
text_callback = None
|
||||
if run_manager:
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
text = ""
|
||||
for res in self.client.generate_stream(prompt, **invocation_params):
|
||||
token = res.token
|
||||
is_stop = False
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in token.text:
|
||||
is_stop = True
|
||||
break
|
||||
if is_stop:
|
||||
break
|
||||
if not token.special:
|
||||
if text_callback:
|
||||
text_callback(token.text)
|
||||
text += token.text
|
||||
return text
|
||||
res = self.client.generate(prompt, **invocation_params)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
return res.generated_text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
@ -193,39 +177,90 @@ class HuggingFaceTextGenInference(LLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if not self.stream:
|
||||
res = await self.async_client.generate(
|
||||
prompt,
|
||||
**invocation_params,
|
||||
)
|
||||
# remove stop sequences from the end of the generated text
|
||||
res = await self.async_client.generate(prompt, **invocation_params)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
return res.generated_text
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
for res in self.client.generate_stream(prompt, **invocation_params):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
text: str = res.generated_text
|
||||
else:
|
||||
text_callback = None
|
||||
if run_manager:
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
text = ""
|
||||
async for res in self.async_client.generate_stream(
|
||||
prompt, **invocation_params
|
||||
):
|
||||
token = res.token
|
||||
is_stop = False
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in token.text:
|
||||
is_stop = True
|
||||
break
|
||||
if is_stop:
|
||||
break
|
||||
if not token.special:
|
||||
if text_callback:
|
||||
await text_callback(token.text)
|
||||
text += token.text
|
||||
return text
|
||||
if stop_seq in res.token.text:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if res.token.special:
|
||||
text = None
|
||||
elif stop_seq_found:
|
||||
text = res.token.text[: res.token.text.index(stop_seq_found)]
|
||||
else:
|
||||
text = res.token.text
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
chunk = GenerationChunk(text=text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        invocation_params = self._invocation_params(stop, **kwargs)

        async for res in self.async_client.generate_stream(prompt, **invocation_params):
            # identify stop sequence in generated text, if any
            stop_seq_found: Optional[str] = None
            for stop_seq in invocation_params["stop_sequences"]:
                if stop_seq in res.token.text:
                    stop_seq_found = stop_seq

            # identify text to yield
            text: Optional[str] = None
            if res.token.special:
                text = None
            elif stop_seq_found:
                text = res.token.text[: res.token.text.index(stop_seq_found)]
            else:
                text = res.token.text

            # yield text, if any
            if text:
                chunk = GenerationChunk(text=text)
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)

            # break if stop sequence found
            if stop_seq_found:
                break

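A usage sketch (not part of the diff) of the async streaming path above; the server URL is a placeholder for a running text-generation-inference endpoint.

```python
import asyncio

from langchain.llms import HuggingFaceTextGenInference


async def demo() -> None:
    llm = HuggingFaceTextGenInference(
        inference_server_url="http://localhost:8010/",  # placeholder endpoint
    )
    # astream() yields plain text chunks produced by _astream above.
    async for token in llm.astream("What is Deep Learning?"):
        print(token, end="", flush=True)


asyncio.run(demo())
```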
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema.output import GenerationChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -226,8 +227,10 @@ class LlamaCpp(LLM):
|
||||
# method that yields as they are generated
|
||||
# and return the combined strings from the first choices's text:
|
||||
combined_text_output = ""
|
||||
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
|
||||
combined_text_output += token["choices"][0]["text"]
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
combined_text_output += chunk.text
|
||||
return combined_text_output
|
||||
else:
|
||||
params = self._get_parameters(stop)
|
||||
@ -235,17 +238,15 @@ class LlamaCpp(LLM):
|
||||
result = self.client(prompt=prompt, **params)
|
||||
return result["choices"][0]["text"]
|
||||
|
||||
def stream(
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> Generator[Dict, None, None]:
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Yields results objects as they are generated in real time.
|
||||
|
||||
BETA: this is a beta feature while we figure out the right abstraction.
|
||||
Once that happens, this interface could change.
|
||||
|
||||
It also calls the callback manager's on_llm_new_token event with
|
||||
similar parameters to the OpenAI LLM class method of the same name.
|
||||
|
||||
@ -274,16 +275,19 @@ class LlamaCpp(LLM):
|
||||
print(result["text"], end='', flush=True)
|
||||
|
||||
"""
|
||||
params = self._get_parameters(stop)
|
||||
params = {**self._get_parameters(stop), **kwargs}
|
||||
result = self.client(prompt=prompt, stream=True, **params)
|
||||
for chunk in result:
|
||||
token = chunk["choices"][0]["text"]
|
||||
log_probs = chunk["choices"][0].get("logprobs", None)
|
||||
for part in result:
|
||||
logprobs = part["choices"][0].get("logprobs", None)
|
||||
chunk = GenerationChunk(
|
||||
text=part["choices"][0]["text"],
|
||||
generation_info={"logprobs": logprobs},
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
token=chunk.text, verbose=self.verbose, log_probs=logprobs
|
||||
)
|
||||
yield chunk
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
tokenized_text = self.client.tokenize(text.encode("utf-8"))
|
||||
|
@ -6,10 +6,11 @@ import warnings
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
@ -27,6 +28,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -44,6 +46,19 @@ def update_token_usage(
|
||||
token_usage[_key] += response["usage"][_key]
|
||||
|
||||
|
||||
def _stream_response_to_generation_chunk(
    stream_response: Dict[str, Any],
) -> GenerationChunk:
    """Convert a stream response to a generation chunk."""
    return GenerationChunk(
        text=stream_response["choices"][0]["text"],
        generation_info=dict(
            finish_reason=stream_response["choices"][0].get("finish_reason", None),
            logprobs=stream_response["choices"][0].get("logprobs", None),
        ),
    )

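For illustration only (not part of the diff): the helper above maps one raw streaming payload onto a GenerationChunk, assuming the usual `choices[0]` shape of the completions API. Inside this module it can be exercised like so:

```python
from langchain.schema.output import GenerationChunk

# A single streamed payload shaped like the responses handled above
# (field values are made up).
stream_resp = {
    "choices": [{"text": "Hello", "finish_reason": None, "logprobs": None}]
}

chunk = _stream_response_to_generation_chunk(stream_resp)
assert isinstance(chunk, GenerationChunk)
assert chunk.text == "Hello"
assert chunk.generation_info == {"finish_reason": None, "logprobs": None}
```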
def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["choices"][0]["text"] += stream_response["choices"][0]["text"]
@@ -268,6 +283,50 @@ class BaseOpenAI(BaseLLM):

        return {**normal_params, **self.model_kwargs}

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    verbose=self.verbose,
                    logprobs=chunk.generation_info["logprobs"]
                    if chunk.generation_info
                    else None,
                )

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        async for stream_resp in await acompletion_with_retry(
            self, prompt=prompt, **params
        ):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk.text,
                    verbose=self.verbose,
                    logprobs=chunk.generation_info["logprobs"]
                    if chunk.generation_info
                    else None,
                )

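Not part of the diff: GenerationChunk supports `+`, which is what `_generate` below relies on when it folds the streamed chunks back into a single choice. A tiny illustration:

```python
from langchain.schema.output import GenerationChunk

a = GenerationChunk(text="Hello, ")
b = GenerationChunk(text="world")

# Chunks concatenate their text, so a stream can be reduced back into
# one generation with repeated `+=`, exactly as _generate does.
assert (a + b).text == "Hello, world"
```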
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@ -302,24 +361,28 @@ class BaseOpenAI(BaseLLM):
|
||||
if self.streaming:
|
||||
if len(_prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
params["stream"] = True
|
||||
response = _streaming_response_template()
|
||||
for stream_resp in completion_with_retry(
|
||||
self, prompt=_prompts, **params
|
||||
):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp["choices"][0]["text"],
|
||||
verbose=self.verbose,
|
||||
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||
)
|
||||
_update_response(response, stream_resp)
|
||||
choices.extend(response["choices"])
|
||||
|
||||
generation: Optional[GenerationChunk] = None
|
||||
for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
choices.append(
|
||||
{
|
||||
"text": generation.text,
|
||||
"finish_reason": generation.generation_info.get("finish_reason")
|
||||
if generation.generation_info
|
||||
else None,
|
||||
"logprobs": generation.generation_info.get("logprobs")
|
||||
if generation.generation_info
|
||||
else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = completion_with_retry(self, prompt=_prompts, **params)
|
||||
choices.extend(response["choices"])
|
||||
if not self.streaming:
|
||||
# Can't update token usage if streaming
|
||||
update_token_usage(_keys, response, token_usage)
|
||||
return self.create_llm_result(choices, prompts, token_usage)
|
||||
|
||||
@ -343,24 +406,30 @@ class BaseOpenAI(BaseLLM):
|
||||
if self.streaming:
|
||||
if len(_prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
params["stream"] = True
|
||||
response = _streaming_response_template()
|
||||
async for stream_resp in await acompletion_with_retry(
|
||||
self, prompt=_prompts, **params
|
||||
|
||||
generation: Optional[GenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
_prompts[0], stop, run_manager, **kwargs
|
||||
):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
stream_resp["choices"][0]["text"],
|
||||
verbose=self.verbose,
|
||||
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||
)
|
||||
_update_response(response, stream_resp)
|
||||
choices.extend(response["choices"])
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
choices.append(
|
||||
{
|
||||
"text": generation.text,
|
||||
"finish_reason": generation.generation_info.get("finish_reason")
|
||||
if generation.generation_info
|
||||
else None,
|
||||
"logprobs": generation.generation_info.get("logprobs")
|
||||
if generation.generation_info
|
||||
else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await acompletion_with_retry(self, prompt=_prompts, **params)
|
||||
choices.extend(response["choices"])
|
||||
if not self.streaming:
|
||||
# Can't update token usage if streaming
|
||||
update_token_usage(_keys, response, token_usage)
|
||||
return self.create_llm_result(choices, prompts, token_usage)
|
||||
|
||||
@ -409,43 +478,6 @@ class BaseOpenAI(BaseLLM):
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
|
||||
"""Call OpenAI with streaming flag and return the resulting generator.
|
||||
|
||||
BETA: this is a beta feature while we figure out the right abstraction.
|
||||
Once that happens, this interface could change.
|
||||
|
||||
Args:
|
||||
prompt: The prompts to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
A generator representing the stream of tokens from OpenAI.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
generator = openai.stream("Tell me a joke.")
|
||||
for token in generator:
|
||||
yield token
|
||||
"""
|
||||
params = self.prep_streaming_params(stop)
|
||||
generator = self.client.create(prompt=prompt, **params)
|
||||
|
||||
return generator
|
||||
|
||||
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Prepare the params for streaming."""
|
||||
params = self._invocation_params
|
||||
if "best_of" in params and params["best_of"] != 1:
|
||||
raise ValueError("OpenAI only supports best_of == 1 for streaming")
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
params["stream"] = True
|
||||
return params
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
@ -777,6 +809,38 @@ class OpenAIChat(BaseLLM):
|
||||
del params["max_tokens"]
|
||||
return messages, params
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
messages, params = self._get_chat_params([prompt], stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
for stream_resp in completion_with_retry(self, messages=messages, **params):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
yield GenerationChunk(text=token)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
messages, params = self._get_chat_params([prompt], stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
async for stream_resp in await acompletion_with_retry(
|
||||
self, messages=messages, **params
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
yield GenerationChunk(text=token)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@ -784,33 +848,29 @@ class OpenAIChat(BaseLLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
if self.streaming:
|
||||
generation: Optional[GenerationChunk] = None
|
||||
for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return LLMResult(generations=[[generation]])
|
||||
|
||||
messages, params = self._get_chat_params(prompts, stop)
|
||||
params = {**params, **kwargs}
|
||||
if self.streaming:
|
||||
response = ""
|
||||
params["stream"] = True
|
||||
for stream_resp in completion_with_retry(self, messages=messages, **params):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
response += token
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token,
|
||||
)
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=response)]],
|
||||
)
|
||||
else:
|
||||
full_response = completion_with_retry(self, messages=messages, **params)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
full_response = completion_with_retry(self, messages=messages, **params)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
@ -819,37 +879,29 @@ class OpenAIChat(BaseLLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
if self.streaming:
|
||||
generation: Optional[GenerationChunk] = None
|
||||
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return LLMResult(generations=[[generation]])
|
||||
|
||||
messages, params = self._get_chat_params(prompts, stop)
|
||||
params = {**params, **kwargs}
|
||||
if self.streaming:
|
||||
response = ""
|
||||
params["stream"] = True
|
||||
async for stream_resp in await acompletion_with_retry(
|
||||
self, messages=messages, **params
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
response += token
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token,
|
||||
)
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=response)]],
|
||||
)
|
||||
else:
|
||||
full_response = await acompletion_with_retry(
|
||||
self, messages=messages, **params
|
||||
)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
full_response = await acompletion_with_retry(self, messages=messages, **params)
|
||||
llm_output = {
|
||||
"token_usage": full_response["usage"],
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
return LLMResult(
|
||||
generations=[
|
||||
[Generation(text=full_response["choices"][0]["message"]["content"])]
|
||||
],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
|
@ -13,10 +13,7 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@ -250,12 +247,3 @@ class Tongyi(LLM):
|
||||
]
|
||||
)
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
raise NotImplementedError()
|
||||
|
@ -6,7 +6,7 @@ from typing import List
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class ListOutputParser(BaseOutputParser):
|
||||
class ListOutputParser(BaseOutputParser[List[str]]):
|
||||
"""Parse the output of an LLM call to a list."""
|
||||
|
||||
@property
|
||||
|
@ -4,14 +4,14 @@ from typing import Any, Dict, List, Type, Union
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.schema import (
|
||||
BaseLLMOutputParser,
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseLLMOutputParser[Any]):
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
"""Parse an output that is one of sets of values."""
|
||||
|
||||
args_only: bool = True
|
||||
|
@ -5,10 +5,10 @@ import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.formatting import formatter
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.utils import formatter
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
@ -446,7 +446,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessagePromptTemplate):
|
||||
input_vars.update(message.input_variables)
|
||||
return cls(input_variables=list(input_vars), messages=messages)
|
||||
return cls(input_variables=sorted(input_vars), messages=messages)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
@ -1,9 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
@ -20,8 +17,3 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -7,10 +7,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@ -108,8 +105,3 @@ class BM25Retriever(BaseRetriever):
|
||||
processed_query = self.preprocess_func(query)
|
||||
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
|
||||
return return_docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
@ -208,11 +205,3 @@ class DocArrayRetriever(BaseRetriever):
|
||||
lc_doc.metadata[name] = value
|
||||
|
||||
return lc_doc
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -5,10 +5,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@ -138,8 +135,3 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
for r in res["hits"]["hits"]:
|
||||
docs.append(Document(page_content=r["_source"]["content"]))
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@ -184,8 +181,3 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
|
||||
documents = self._convert_search_response(response.results)
|
||||
|
||||
return documents
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -4,10 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@ -411,11 +408,3 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
"""
|
||||
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
||||
|
@ -9,10 +9,7 @@ from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@ -82,8 +79,3 @@ class KNNRetriever(BaseRetriever):
|
||||
)
|
||||
]
|
||||
return top_k_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -2,10 +2,7 @@ from typing import Any, Dict, List, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@ -42,11 +39,6 @@ class LlamaIndexRetriever(BaseRetriever):
|
||||
)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||
|
||||
|
||||
class LlamaIndexGraphRetriever(BaseRetriever):
|
||||
"""Retriever for question-answering with sources over an LlamaIndex
|
||||
@ -88,8 +80,3 @@ class LlamaIndexGraphRetriever(BaseRetriever):
|
||||
Document(page_content=source_node.source_text, metadata=metadata)
|
||||
)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||
|
@ -2,10 +2,7 @@ from typing import Any, List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@ -43,8 +40,3 @@ class MetalRetriever(BaseRetriever):
|
||||
metadata = {k: v for k, v in r.items() if k != "text"}
|
||||
final_results.append(Document(page_content=r["text"], metadata=metadata))
|
||||
return final_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.milvus import Milvus
|
||||
@ -63,15 +60,6 @@ class MilvusRetriever(BaseRetriever):
|
||||
query, run_manager=run_manager.get_child(), **kwargs
|
||||
)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
|
||||
"""Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.
|
||||
|
@ -3,10 +3,7 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.output_parsers.pydantic import PydanticOutputParser
|
||||
@ -101,14 +98,6 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
unique_documents = self.unique_union(documents)
|
||||
return unique_documents
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_queries(
|
||||
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[str]:
|
||||
|
@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@ -175,8 +172,3 @@ class PineconeHybridSearchRetriever(BaseRetriever):
|
||||
)
|
||||
# return search results as json
|
||||
return final_result
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,9 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
@ -19,8 +16,3 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional, Type, cast
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
@ -119,11 +116,6 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
|
@ -5,10 +5,7 @@ from typing import Any, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@ -113,8 +110,3 @@ class SVMRetriever(BaseRetriever):
|
||||
):
|
||||
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
||||
return top_k_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -2,10 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
@ -79,8 +76,3 @@ class TFIDFRetriever(BaseRetriever):
|
||||
) # Op -- (n_docs,1) -- Cosine Sim with each doc
|
||||
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
return return_docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
@ -109,12 +106,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
||||
result.append(buffered_doc)
|
||||
return result
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Return documents that are relevant to the query."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
current_time = kwargs.get("current_time")
|
||||
|
@ -3,10 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -57,11 +54,6 @@ class VespaRetriever(BaseRetriever):
|
||||
body["query"] = query
|
||||
return self._query(body)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_relevant_documents_with_filter(
|
||||
self, query: str, *, _filter: Optional[str] = None
|
||||
) -> List[Document]:
|
||||
|
@ -5,10 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
@ -118,8 +115,3 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
text = res.pop(self.text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,9 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||
|
||||
@ -19,8 +16,3 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.zilliz import Zilliz
|
||||
@ -67,15 +64,6 @@ class ZillizRetriever(BaseRetriever):
|
||||
query, run_manager=run_manager.get_child(), **kwargs
|
||||
)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
|
||||
"""Deprecated ZillizRetreiver.
|
||||
|
@ -1,6 +1,5 @@
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
from langchain.schema.document import BaseDocumentTransformer, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
@ -67,6 +66,5 @@ __all__ = [
|
||||
"BaseOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
"BasePromptTemplate",
|
||||
"BaseLanguageModel",
|
||||
"format_document",
|
||||
]
|
||||
|
@ -1,12 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
from langchain.schema.output import LLMResult
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.utils import get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -32,7 +42,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
    return tokenizer.encode(text)


-class BaseLanguageModel(Serializable, ABC):
+LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
+LanguageModelOutput = TypeVar("LanguageModelOutput")
+
+
+class BaseLanguageModel(
+    Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
+):
    """Abstract base class for interfacing with language models.

    All language model wrappers inherit from BaseLanguageModel.

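Not part of the diff: with language models now subclassing Runnable[LanguageModelInput, LanguageModelOutput], one calling convention covers plain strings, PromptValues, and message lists. A sketch, assuming an OpenAI API key is configured and that `invoke`/`stream` mirror the `ainvoke`/`astream` entry points used earlier in this diff:

```python
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

llm = OpenAI()       # completion model; output is a string
chat = ChatOpenAI()  # chat model; output is a BaseMessage

# A bare string is a valid LanguageModelInput for either model.
print(llm.invoke("Say hi"))
print(chat.invoke("Say hi"))

# Streaming takes the same input and yields incremental output.
for piece in llm.stream("Tell me a joke"):
    print(piece, end="", flush=True)
```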
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import abstractmethod
from typing import List, Sequence
from typing import Any, Dict, List, Sequence

from pydantic import Field

@@ -78,6 +78,49 @@ class BaseMessage(Serializable):
        return True


class BaseMessageChunk(BaseMessage):
    def _merge_kwargs_dict(
        self, left: Dict[str, Any], right: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge additional_kwargs from another BaseMessageChunk into this one."""
        merged = left.copy()
        for k, v in right.items():
            if k not in merged:
                merged[k] = v
            elif type(merged[k]) != type(v):
                raise ValueError(
                    f'additional_kwargs["{k}"] already exists in this message,'
                    " but with a different type."
                )
            elif isinstance(merged[k], str):
                merged[k] += v
            elif isinstance(merged[k], dict):
                merged[k] = self._merge_kwargs_dict(merged[k], v)
            else:
                raise ValueError(
                    f"Additional kwargs key {k} already exists in this message."
                )
        return merged

    def __add__(self, other: Any) -> BaseMessageChunk:
        if isinstance(other, BaseMessageChunk):
            # If both are (subclasses of) BaseMessageChunk,
            # concat into a single BaseMessageChunk

            return self.__class__(
                content=self.content + other.content,
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        else:
            raise TypeError(
                'unsupported operand type(s) for +: "'
                f"{self.__class__.__name__}"
                f'" and "{other.__class__.__name__}"'
            )


class HumanMessage(BaseMessage):
    """A Message from a human."""

@@ -92,6 +135,10 @@ class HumanMessage(BaseMessage):
        return "human"


class HumanMessageChunk(HumanMessage, BaseMessageChunk):
    pass


class AIMessage(BaseMessage):
    """A Message from an AI."""

@@ -106,6 +153,10 @@ class AIMessage(BaseMessage):
        return "ai"


class AIMessageChunk(AIMessage, BaseMessageChunk):
    pass


class SystemMessage(BaseMessage):
    """A Message for priming AI behavior, usually passed in as the first of a sequence
    of input messages.
@@ -117,6 +168,10 @@ class SystemMessage(BaseMessage):
        return "system"


class SystemMessageChunk(SystemMessage, BaseMessageChunk):
    pass


class FunctionMessage(BaseMessage):
    """A Message for passing the result of executing a function back to a model."""

@@ -129,6 +184,10 @@ class FunctionMessage(BaseMessage):
        return "function"


class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
    pass


class ChatMessage(BaseMessage):
    """A Message that can be assigned an arbitrary speaker (i.e. role)."""

@@ -141,6 +200,10 @@ class ChatMessage(BaseMessage):
        return "chat"


class ChatMessageChunk(ChatMessage, BaseMessageChunk):
    pass


def _message_to_dict(message: BaseMessage) -> dict:
    return {"type": message.type, "data": message.dict()}
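# --- Illustrative usage sketch (not part of the diff) ----------------------
# The new BaseMessageChunk.__add__ concatenates content, keeps the class of
# the left operand, and merges additional_kwargs (e.g. streamed function_call
# fragments) recursively. Mirrors tests/unit_tests/schema/test_messages.py
# further below.
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk

assert AIMessageChunk(content="I am") + AIMessageChunk(content=" indeed.") == AIMessageChunk(
    content="I am indeed."
)
# The left-hand class wins when chunk subclasses are mixed.
assert AIMessageChunk(content="Hel") + HumanMessageChunk(content="lo") == AIMessageChunk(
    content="Hello"
)
# Nested dict kwargs are merged key by key.
merged = AIMessageChunk(
    content="", additional_kwargs={"function_call": {"name": "web_search"}}
) + AIMessageChunk(content="", additional_kwargs={"function_call": {"arguments": "{}"}})
assert merged.additional_kwargs == {
    "function_call": {"name": "web_search", "arguments": "{}"}
}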
@@ -7,7 +7,7 @@ from uuid import UUID
from pydantic import BaseModel, root_validator

from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessage, BaseMessageChunk


class Generation(Serializable):
@@ -28,6 +28,24 @@ class Generation(Serializable):
        return True


class GenerationChunk(Generation):
    def __add__(self, other: GenerationChunk) -> GenerationChunk:
        if isinstance(other, GenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return GenerationChunk(
                text=self.text + other.text,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )


class ChatGeneration(Generation):
    """A single chat generation output."""

@@ -43,6 +61,26 @@ class ChatGeneration(Generation):
        return values


class ChatGenerationChunk(ChatGeneration):
    message: BaseMessageChunk

    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
        if isinstance(other, ChatGenerationChunk):
            generation_info = (
                {**(self.generation_info or {}), **(other.generation_info or {})}
                if self.generation_info is not None or other.generation_info is not None
                else None
            )
            return ChatGenerationChunk(
                message=self.message + other.message,
                generation_info=generation_info,
            )
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
            )


class RunInfo(BaseModel):
    """Class that contains metadata for a single execution of a Chain or model."""
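# --- Illustrative usage sketch (not part of the diff) ----------------------
# GenerationChunk and ChatGenerationChunk also support `+`, concatenating
# text/messages and merging generation_info, which is what lets streamed
# output be re-assembled. Mirrors tests/unit_tests/schema/test_output.py below.
from langchain.schema.messages import HumanMessageChunk
from langchain.schema.output import ChatGenerationChunk, GenerationChunk

assert GenerationChunk(text="Hello, ") + GenerationChunk(
    text="world!", generation_info={"foo": "bar"}
) == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})

assert ChatGenerationChunk(
    message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(message=HumanMessageChunk(content="world!")) == ChatGenerationChunk(
    message=HumanMessageChunk(content="Hello, world!")
)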
@@ -1,16 +1,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from langchain.load.serializable import Serializable
from langchain.schema.output import Generation
from langchain.schema.messages import BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig

T = TypeVar("T")


class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
    """Abstract base class for parsing the outputs of a model."""

    @abstractmethod
@@ -26,7 +28,19 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
        """


class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
class BaseGenerationOutputParser(
    BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
    def invoke(
        self, input: str | BaseMessage, config: RunnableConfig | None = None
    ) -> T:
        if isinstance(input, BaseMessage):
            return self.parse_result([ChatGeneration(message=input)])
        else:
            return self.parse_result([Generation(text=input)])


class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
    """Base class to parse the output of an LLM call.

    Output parsers help structure language model responses.
@@ -53,6 +67,14 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
                return "boolean_output_parser"
    """  # noqa: E501

    def invoke(
        self, input: str | BaseMessage, config: RunnableConfig | None = None
    ) -> T:
        if isinstance(input, BaseMessage):
            return self.parse_result([ChatGeneration(message=input)])
        else:
            return self.parse_result([Generation(text=input)])

    def parse_result(self, result: List[Generation]) -> T:
        """Parse a list of candidate model Generations into a specific format.
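# --- Illustrative usage sketch (not part of the diff) ----------------------
# Because BaseOutputParser is now Runnable[Union[str, BaseMessage], T], the
# same parser instance accepts either a raw string or a chat message.
# CommaSeparatedListOutputParser is the parser used in the runnable unit
# tests further below.
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.schema.messages import AIMessage

parser = CommaSeparatedListOutputParser()
assert parser.invoke("foo, bar") == ["foo", "bar"]
assert parser.invoke(AIMessage(content="foo, bar")) == ["foo", "bar"]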
@@ -12,9 +12,10 @@ from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig


class BasePromptTemplate(Serializable, ABC):
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
@@ -34,6 +35,11 @@ class BasePromptTemplate(Serializable, ABC):

        arbitrary_types_allowed = True

    def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
        return self._call_with_config(
            lambda inner_input: self.format_prompt(**inner_input), input, config
        )

    @abstractmethod
    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Chat Messages."""
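# --- Illustrative usage sketch (not part of the diff) ----------------------
# BasePromptTemplate is now Runnable[Dict, PromptValue], so formatting is a
# plain .invoke(dict) that routes through format_prompt() via
# _call_with_config(). Mirrors test_prompt in the runnable unit tests below.
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema.messages import SystemMessage

prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content="You are a nice assistant."),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)
prompt_value = prompt.invoke({"question": "What is your name?"})
print(prompt_value.to_messages())  # [SystemMessage(...), HumanMessage(...)]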
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig

if TYPE_CHECKING:
    from langchain.callbacks.manager import (
@@ -17,7 +18,7 @@ if TYPE_CHECKING:
    )


class BaseRetriever(Serializable, ABC):
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
@@ -43,9 +44,6 @@ class BaseRetriever(Serializable, ABC):
                # Op -- (n_docs,1) -- Cosine Sim with each doc
                results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]

            async def aget_relevant_documents(self, query: str) -> List[Document]:
                raise NotImplementedError
    """  # noqa: E501

    class Config:
@@ -106,6 +104,20 @@ class BaseRetriever(Serializable, ABC):
            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
        return self.get_relevant_documents(input, **(config or {}))

    async def ainvoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
        if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
            # If the retriever doesn't implement async, use default implementation
            return await super().ainvoke(input, config)

        return await self.aget_relevant_documents(input, **(config or {}))

    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@@ -118,7 +130,6 @@ class BaseRetriever(Serializable, ABC):
            List of relevant documents
        """

    @abstractmethod
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
@@ -129,6 +140,7 @@ class BaseRetriever(Serializable, ABC):
        Returns:
            List of relevant documents
        """
        raise NotImplementedError()

    def get_relevant_documents(
        self,
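# --- Illustrative usage sketch (not part of the diff) ----------------------
# With BaseRetriever implementing Runnable[str, List[Document]], a retriever
# can be dropped straight into a chain and called with .invoke(query).
# TinyRetriever below is hypothetical, modelled on the FakeRetriever used in
# the runnable unit tests further below.
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class TinyRetriever(BaseRetriever):
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"docs about {query}")]


assert TinyRetriever().invoke("turtles") == [Document(page_content="docs about turtles")]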
libs/langchain/langchain/schema/runnable.py (new file, 705 lines)
@@ -0,0 +1,705 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Coroutine,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

from pydantic import Field

from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable


async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    async with semaphore:
        return await coro


async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    if n is None:
        return await asyncio.gather(*coros)

    semaphore = asyncio.Semaphore(n)

    return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))


class RunnableConfig(TypedDict, total=False):
    tags: List[str]
    """
    Tags for this call and any sub-calls (eg. a Chain calling an LLM).
    You can use these to filter calls.
    """

    metadata: Dict[str, Any]
    """
    Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
    Keys should be strings, values should be JSON-serializable.
    """

    callbacks: Callbacks
    """
    Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
    Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
    """


Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other")


class Runnable(Generic[Input, Output], ABC):
    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        ],
    ) -> RunnableSequence[Input, Other]:
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        ],
    ) -> RunnableSequence[Other, Output]:
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    @abstractmethod
    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        ...

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        return await asyncio.get_running_loop().run_in_executor(
            None, self.invoke, input, config
        )

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        configs = self._get_config_list(config, len(inputs))

        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(executor.map(self.invoke, inputs, configs))

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        configs = self._get_config_list(config, len(inputs))
        coros = map(self.ainvoke, inputs, configs)

        return await _gather_with_concurrency(max_concurrency, *coros)

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        yield self.invoke(input, config)

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        yield await self.ainvoke(input, config)

    def _get_config_list(
        self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
    ) -> List[RunnableConfig]:
        if isinstance(config, list) and len(config) != length:
            raise ValueError(
                f"config must be a list of the same length as inputs, "
                f"but got {len(config)} configs for {length} inputs"
            )

        return (
            config
            if isinstance(config, list)
            else [config.copy() if config is not None else {} for _ in range(length)]
        )

    def _call_with_config(
        self,
        func: Callable[[Input], Output],
        input: Input,
        config: Optional[RunnableConfig],
    ) -> Output:
        from langchain.callbacks.manager import CallbackManager

        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            inheritable_tags=config.get("tags"),
            inheritable_metadata=config.get("metadata"),
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )
        try:
            output = func(input)
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(
                output if isinstance(output, dict) else {"output": output}
            )
            return output

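# --- Illustrative usage sketch (not part of the diff) ----------------------
# Any two Runnables compose with `|` into the RunnableSequence defined next;
# plain callables (and dicts) are coerced via _coerce_to_runnable at the end
# of this file. The default batch()/stream() implementations shown above then
# apply to the whole sequence.
from langchain.schema.runnable import RunnableLambda

chain = RunnableLambda(lambda x: x + 1) | (lambda x: x * 2)

assert chain.invoke(3) == 8
assert chain.batch([1, 2, 3]) == [4, 6, 8]
assert list(chain.stream(3)) == [8]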
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
first: Runnable[Input, Any]
|
||||
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
|
||||
last: Runnable[Any, Output]
|
||||
|
||||
@property
|
||||
def steps(self) -> List[Runnable[Any, Any]]:
|
||||
return [self.first] + self.middle + [self.last]
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last] + other.middle,
|
||||
last=other.last,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last],
|
||||
last=_coerce_to_runnable(other),
|
||||
)
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=other.first,
|
||||
middle=other.middle + [other.last] + self.middle,
|
||||
last=self.last,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
first=_coerce_to_runnable(other),
|
||||
middle=[self.first] + self.middle,
|
||||
last=self.last,
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for step in self.steps:
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(
|
||||
input if isinstance(input, dict) else {"output": input}
|
||||
)
|
||||
return cast(Output, input)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for step in self.steps:
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(
|
||||
input if isinstance(input, dict) else {"output": input}
|
||||
)
|
||||
return cast(Output, input)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
]
|
||||
|
||||
# invoke
|
||||
try:
|
||||
for step in self.steps:
|
||||
inputs = step.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
# finish the root runs
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
for rm, input in zip(run_managers, inputs):
|
||||
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
|
||||
return cast(List[Output], inputs)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
)
|
||||
)
|
||||
|
||||
# invoke .batch() on each step
|
||||
# this uses batching optimizations in Runnable subclasses, like LLM
|
||||
try:
|
||||
for step in self.steps:
|
||||
inputs = await step.abatch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
# finish the root runs
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
raise
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
rm.on_chain_end(
|
||||
input if isinstance(input, dict) else {"output": input}
|
||||
)
|
||||
for rm, input in zip(run_managers, inputs)
|
||||
)
|
||||
)
|
||||
return cast(List[Output], inputs)
|
||||
|
||||
def stream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in [self.first] + self.middle:
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last step
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
for output in self.last.stream(
|
||||
input,
|
||||
# mark the last step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
):
|
||||
yield output
|
||||
# Accumulate output if possible, otherwise disable accumulation
|
||||
if final_supported:
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
try:
|
||||
final += output # type: ignore[operator]
|
||||
except TypeError:
|
||||
final = None
|
||||
final_supported = False
|
||||
pass
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(
|
||||
final if isinstance(final, dict) else {"output": final}
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
)
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in [self.first] + self.middle:
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last step
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
async for output in self.last.astream(
|
||||
input,
|
||||
# mark the last step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
):
|
||||
yield output
|
||||
# Accumulate output if possible, otherwise disable accumulation
|
||||
if final_supported:
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
try:
|
||||
final += output # type: ignore[operator]
|
||||
except TypeError:
|
||||
final = None
|
||||
final_supported = False
|
||||
pass
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(
|
||||
final if isinstance(final, dict) else {"output": final}
|
||||
)
|
||||
|
||||
|
||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
steps: Dict[str, Runnable[Input, Any]]
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = self.steps.copy()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
step.invoke,
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
]
|
||||
output = {key: future.result() for key, future in zip(steps, futures)}
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
return output
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), {"input": input}
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = self.steps.copy()
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
)
|
||||
)
|
||||
output = {key: value for key, value in zip(steps, results)}
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
return output
|
||||
|
||||
|
||||
class RunnableLambda(Runnable[Input, Output]):
|
||||
def __init__(self, func: Callable[[Input], Output]) -> None:
|
||||
if callable(func):
|
||||
self.func = func
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected a callable type for `func`."
|
||||
f"Instead got an unsupported type: {type(func)}"
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, RunnableLambda):
|
||||
return self.func == other.func
|
||||
else:
|
||||
return False
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
return self._call_with_config(self.func, input, config)
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
|
||||
|
||||
def _patch_config(
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||
) -> RunnableConfig:
|
||||
config = config.copy()
|
||||
config["callbacks"] = callback_manager
|
||||
return config
|
||||
|
||||
|
||||
def _coerce_to_runnable(
|
||||
thing: Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
||||
]
|
||||
) -> Runnable[Input, Output]:
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif callable(thing):
|
||||
return RunnableLambda(thing)
|
||||
elif isinstance(thing, dict):
|
||||
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
|
||||
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected a Runnable, callable or dict."
|
||||
f"Instead got an unsupported type: {type(thing)}"
|
||||
)
|
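# --- Illustrative usage sketch (not part of the diff) ----------------------
# A plain dict in a composition is coerced into a RunnableMap, whose branches
# run in parallel and whose output is a dict of their results; this is the
# {"question": ..., "documents": retriever} pattern used in the unit tests
# further below.
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

fan_out = RunnablePassthrough() | {
    "doubled": RunnableLambda(lambda x: x * 2),
    "as_text": RunnableLambda(str),
}

assert fan_out.invoke(21) == {"doubled": 42, "as_text": "21"}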
@ -177,3 +177,68 @@ def test_chat_openai_extra_kwargs() -> None:
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
|
||||
|
||||
|
||||
def test_openai_streaming() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_abatch() -> None:
|
||||
"""Test streaming tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_openai_batch() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_openai_invoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
|
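# --- Illustrative usage sketch (not part of the diff) ----------------------
# Like the integration tests above, this assumes a configured OPENAI_API_KEY.
# Because chat models now stream message chunks, a full message can be
# re-assembled from a stream with `+`.
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(max_tokens=10)

full = None
for chunk in llm.stream("I'm Pickle Rick"):
    full = chunk if full is None else full + chunk
print(full.content if full is not None else "")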
@ -93,7 +93,64 @@ def test_openai_streaming() -> None:
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for token in generator:
|
||||
assert isinstance(token["choices"][0]["text"], str)
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_abatch() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_abatch_tags() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_openai_batch() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_ainvoke() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_openai_invoke() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_openai_multiple_prompts() -> None:
|
||||
@ -105,13 +162,6 @@ def test_openai_multiple_prompts() -> None:
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_openai_streaming_error() -> None:
|
||||
"""Test error handling in stream."""
|
||||
llm = OpenAI(best_of=2)
|
||||
with pytest.raises(ValueError):
|
||||
llm.stream("I'm Pickle Rick")
|
||||
|
||||
|
||||
def test_openai_streaming_best_of_error() -> None:
|
||||
"""Test validation for streaming fails if best_of is not 1."""
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -67,10 +67,3 @@ def test_promptlayer_openai_streaming() -> None:
|
||||
|
||||
for token in generator:
|
||||
assert isinstance(token["choices"][0]["text"], str)
|
||||
|
||||
|
||||
def test_promptlayer_openai_streaming_error() -> None:
|
||||
"""Test error handling in stream."""
|
||||
llm = PromptLayerOpenAI(best_of=2)
|
||||
with pytest.raises(ValueError):
|
||||
llm.stream("I'm Pickle Rick")
|
||||
|
File diff suppressed because one or more lines are too long
libs/langchain/tests/unit_tests/schema/test_messages.py (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
|
||||
|
||||
|
||||
def test_message_chunks() -> None:
|
||||
assert AIMessageChunk(content="I am") + AIMessageChunk(
|
||||
content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
content="I am indeed."
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk"
|
||||
|
||||
assert AIMessageChunk(content="I am") + HumanMessageChunk(
|
||||
content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
content="I am indeed."
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
|
||||
|
||||
assert AIMessageChunk(
|
||||
content="", additional_kwargs={"foo": "bar"}
|
||||
) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk(
|
||||
content="", additional_kwargs={"foo": "bar", "baz": "foo"}
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||
|
||||
assert AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "web_search"}}
|
||||
) + AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
|
||||
) + AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}},
|
||||
) == AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "web_search",
|
||||
"arguments": '{\n "query": "turtles"\n}',
|
||||
}
|
||||
},
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
libs/langchain/tests/unit_tests/schema/test_output.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
from langchain.schema.messages import HumanMessageChunk
|
||||
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
|
||||
def test_generation_chunk() -> None:
|
||||
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||
text="world!"
|
||||
) == GenerationChunk(
|
||||
text="Hello, world!"
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
|
||||
|
||||
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||
text="world!", generation_info={"foo": "bar"}
|
||||
) == GenerationChunk(
|
||||
text="Hello, world!", generation_info={"foo": "bar"}
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||
text="world!", generation_info={"foo": "bar"}
|
||||
) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk(
|
||||
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
|
||||
def test_chat_generation_chunk() -> None:
|
||||
assert ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, ")
|
||||
) + ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!")
|
||||
) == ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!")
|
||||
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
|
||||
|
||||
assert ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, ")
|
||||
) + ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||
) == ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!"),
|
||||
generation_info={"foo": "bar"},
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||
|
||||
assert ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, ")
|
||||
) + ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||
) + ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
|
||||
) == ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="Hello, world!!"),
|
||||
generation_info={"foo": "bar", "baz": "foo"},
|
||||
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
libs/langchain/tests/unit_tests/schema/test_runnable.py (new file, 547 lines)
@@ -0,0 +1,547 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.chat_models.fake import FakeListChatModel
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
)
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Run] = []
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
|
||||
class FakeRunnable(Runnable[str, int]):
|
||||
def invoke(
|
||||
self,
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> int:
|
||||
return len(input)
|
||||
|
||||
|
||||
class FakeRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
|
||||
"""Note this mock only works with `import uuid; uuid.uuid4()`,
|
||||
it does not work with `from uuid import uuid4; uuid4()`."""
|
||||
|
||||
# Disable tracing to avoid fixed UUIDs causing tracing errors.
|
||||
mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})
|
||||
|
||||
side_effect = (
|
||||
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
|
||||
)
|
||||
return mocker.patch("uuid.uuid4", side_effect=side_effect)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
fake = FakeRunnable()
|
||||
spy = mocker.spy(fake, "invoke")
|
||||
|
||||
assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"])),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(metadata={"key": "value"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert fake.batch(
|
||||
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||
) == [5, 7]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"])),
|
||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"])),
|
||||
mocker.call("wooorld", dict(tags=["a-tag"])),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(callbacks=[])),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert [part async for part in fake.astream("hello")] == [5]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", None),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
|
||||
5,
|
||||
7,
|
||||
]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(metadata={"key": "value"})),
|
||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt() -> None:
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessagePromptTemplate.from_template("{question}"),
|
||||
]
|
||||
)
|
||||
expected = ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
|
||||
assert prompt.invoke({"question": "What is your name?"}) == expected
|
||||
|
||||
assert prompt.batch(
|
||||
[
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
]
|
||||
) == [
|
||||
expected,
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your favorite color?"),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
assert [*prompt.stream({"question": "What is your name?"})] == [expected]
|
||||
|
||||
assert await prompt.ainvoke({"question": "What is your name?"}) == expected
|
||||
|
||||
assert await prompt.abatch(
|
||||
[
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
]
|
||||
) == [
|
||||
expected,
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your favorite color?"),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
assert [
|
||||
part async for part in prompt.astream({"question": "What is your name?"})
|
||||
] == [expected]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_chat_model(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo", "bar"])
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == []
|
||||
assert chain.last == chat
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||
tracer = FakeTracer()
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == AIMessage(content="foo")
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert tracer.runs == snapshot
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(chat_spy)
|
||||
|
||||
# Test batch
|
||||
prompt_spy = mocker.spy(prompt.__class__, "batch")
|
||||
chat_spy = mocker.spy(chat.__class__, "batch")
|
||||
tracer = FakeTracer()
|
||||
assert chain.batch(
|
||||
[
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
],
|
||||
dict(callbacks=[tracer]),
|
||||
) == [
|
||||
AIMessage(content="bar"),
|
||||
AIMessage(content="foo"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == [
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
]
|
||||
assert chat_spy.call_args.args[1] == [
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
),
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your favorite color?"),
|
||||
]
|
||||
),
|
||||
]
|
||||
assert tracer.runs == snapshot
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(chat_spy)
|
||||
|
||||
# Test stream
|
||||
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||
chat_spy = mocker.spy(chat.__class__, "stream")
|
||||
tracer = FakeTracer()
|
||||
assert [
|
||||
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||
] == [AIMessage(content="bar")]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_llm(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
llm = FakeListLLM(responses=["foo", "bar"])
|
||||
|
||||
chain = prompt | llm
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == []
|
||||
assert chain.last == llm
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
llm_spy = mocker.spy(llm.__class__, "ainvoke")
|
||||
tracer = FakeTracer()
|
||||
assert (
|
||||
await chain.ainvoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
)
|
||||
== "foo"
|
||||
)
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert tracer.runs == snapshot
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(llm_spy)
|
||||
|
||||
# Test batch
|
||||
prompt_spy = mocker.spy(prompt.__class__, "abatch")
|
||||
llm_spy = mocker.spy(llm.__class__, "abatch")
|
||||
tracer = FakeTracer()
|
||||
assert await chain.abatch(
|
||||
[
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
],
|
||||
dict(callbacks=[tracer]),
|
||||
) == ["bar", "foo"]
|
||||
assert prompt_spy.call_args.args[1] == [
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
]
|
||||
assert llm_spy.call_args.args[1] == [
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
),
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your favorite color?"),
|
||||
]
|
||||
),
|
||||
]
|
||||
assert tracer.runs == snapshot
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(llm_spy)
|
||||
|
||||
# Test stream
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
llm_spy = mocker.spy(llm.__class__, "astream")
|
||||
tracer = FakeTracer()
|
||||
assert [
|
||||
token
|
||||
async for token in chain.astream(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
)
|
||||
] == ["bar"]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo, bar"])
    parser = CommaSeparatedListOutputParser()

    chain = prompt | chat | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == ["foo", "bar"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
    assert tracer.runs == snapshot


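# A usage sketch of the three-step sequence above (illustrative only; the
# CommaSeparatedListOutputParser splits the chat model's "foo, bar" reply):
#
#     chain = prompt | chat | parser
#     chain.invoke({"question": "..."})  # -> ["foo", "bar"]

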
@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    retriever = FakeRetriever()

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + """Context:
{documents}

Question:
{question}"""
    )

    chat = FakeListChatModel(responses=["foo, bar"])

    parser = CommaSeparatedListOutputParser()

    chain = (
        {
            "question": RunnablePassthrough[str]() | passthrough,
            "documents": passthrough | retriever,
            "just_to_test_lambda": passthrough,
        }
        | prompt
        | chat
        | parser
    )

    assert isinstance(chain, RunnableSequence)
    assert isinstance(chain.first, RunnableMap)
    assert chain.middle == [prompt, chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
        "foo",
        "bar",
    ]
    assert prompt_spy.call_args.args[1] == {
        "documents": [Document(page_content="foo"), Document(page_content="bar")],
        "question": "What is your name?",
        "just_to_test_lambda": "What is your name?",
    }
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(
                content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]

Question:
What is your name?"""
            ),
        ]
    )
    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
    assert tracer.runs == snapshot


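# A usage sketch of the dict-first sequence above (illustrative only): a plain
# dict of runnables at the head of a sequence is coerced into a RunnableMap, so
# each value runs on the same input and the prompt receives the merged results:
#
#     chain = {"question": ..., "documents": passthrough | retriever} | prompt | chat | parser
#     chain.invoke("What is your name?")  # -> ["foo", "bar"]

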
@freeze_time("2023-01-01")
def test_seq_prompt_dict(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat = FakeListChatModel(responses=["i'm a chatbot"])

    llm = FakeListLLM(responses=["i'm a textbot"])

    chain = (
        prompt
        | passthrough
        | {  # type: ignore
            "chat": chat,
            "llm": llm,
        }
    )

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [RunnableLambda(passthrough)]
    assert isinstance(chain.last, RunnableMap)
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    llm_spy = mocker.spy(llm.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == {
        "chat": AIMessage(content="i'm a chatbot"),
        "llm": "i'm a textbot",
    }
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert tracer.runs == snapshot

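
# A usage sketch of the dict-last sequence above (illustrative only): a trailing
# dict fans the prompt value out to several models and collects their outputs
# under the corresponding keys:
#
#     chain = prompt | passthrough | {"chat": chat, "llm": llm}
#     chain.invoke({"question": "..."})
#     # -> {"chat": AIMessage(content="i'm a chatbot"), "llm": "i'm a textbot"}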