Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-26 16:43:35 +00:00)
Runnable single protocol (#7800)
Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel, Chain, Retriever, OutputParser.

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported things
- [x] Implement dict which calls values in parallel and outputs dict with results
- [x] Merge in `+` for prompts
- [x] Confirm precedence order for operators, ideal would be `+` `|`, https://docs.python.org/3/reference/expressions.html#operator-precedence
- [x] Add support for openai functions, i.e. Chat Models must return messages
- [x] Implement BaseMessageChunk return type for BaseChatModel, a subclass of BaseMessage which implements `__add__` to return BaseMessageChunk, concatenating all str args
- [x] Update implementation of stream/astream for llm and chat models to use new `_stream`, `_astream` optional methods, with default implementation in base class `raise NotImplementedError`; use https://stackoverflow.com/a/59762827 to check whether the method is still the base-class implementation
- [x] Delete the IteratorCallbackHandler (leave the async one because people are using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str or BaseMessage

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
parent 04a4d3e312
commit a612800ef0
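The checklist above describes the Runnable protocol only in outline. The following self-contained sketch — illustrative class names, not LangChain's actual implementation — shows the shape of the idea: a single `invoke` method, `|` composition that raises `TypeError` for unsupported operands, and a dict form whose values are all called on the same input (sequentially here; the real version runs them in parallel).

```python
from typing import Any, Callable, Dict, List, Union


class Runnable:
    """Minimal stand-in for the single-protocol idea: one invoke() method."""

    def invoke(self, input: Any) -> Any:
        raise NotImplementedError

    def __or__(self, other: Union["Runnable", Callable, Dict[str, Any]]) -> "RunnableSequence":
        # `a | b` builds a sequence; unsupported operands raise TypeError,
        # mirroring the "Raise TypeError in operator methods" checklist item.
        return RunnableSequence([self, coerce_to_runnable(other)])


class RunnableLambda(Runnable):
    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn

    def invoke(self, input: Any) -> Any:
        return self.fn(input)


class RunnableMap(Runnable):
    """Calls each value on the same input and returns a dict of results."""

    def __init__(self, steps: Dict[str, Runnable]) -> None:
        self.steps = steps

    def invoke(self, input: Any) -> Dict[str, Any]:
        # The real implementation runs these concurrently; sequential here.
        return {key: step.invoke(input) for key, step in self.steps.items()}


class RunnableSequence(Runnable):
    def __init__(self, steps: List[Runnable]) -> None:
        self.steps = steps

    def invoke(self, input: Any) -> Any:
        for step in self.steps:
            input = step.invoke(input)
        return input


def coerce_to_runnable(thing: Union[Runnable, Callable, Dict[str, Any]]) -> Runnable:
    if isinstance(thing, Runnable):
        return thing
    if callable(thing):
        return RunnableLambda(thing)
    if isinstance(thing, dict):
        return RunnableMap({k: coerce_to_runnable(v) for k, v in thing.items()})
    raise TypeError(f"Cannot compose object of type {type(thing)}")


if __name__ == "__main__":
    chain = RunnableLambda(str.strip) | str.upper | {"shout": RunnableLambda(lambda s: s + "!")}
    print(chain.invoke("  hello  "))  # {'shout': 'HELLO!'}
```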
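The `BaseMessageChunk` item is what makes streamed chat output accumulable with `+`. A minimal, hypothetical sketch of that `__add__` behaviour (field names chosen for illustration, not the library's message classes):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MessageChunk:
    """Illustrative stand-in for a streamed message piece."""

    content: str
    additional_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __add__(self, other: "MessageChunk") -> "MessageChunk":
        # String fields concatenate; extra kwargs (e.g. partial function-call
        # arguments) are merged so the final chunk holds the whole message.
        merged = {**self.additional_kwargs}
        for key, value in other.additional_kwargs.items():
            if key in merged and isinstance(merged[key], str):
                merged[key] += value
            else:
                merged[key] = value
        return MessageChunk(content=self.content + other.content, additional_kwargs=merged)


total = None
for piece in [MessageChunk("Hel"), MessageChunk("lo"), MessageChunk(" world")]:
    total = piece if total is None else total + piece
print(total.content)  # "Hello world"
```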
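The stream/astream item hinges on detecting whether a subclass actually overrides the optional `_stream` hook, following the linked Stack Overflow approach of comparing the class attribute against the base class's. A small sketch with made-up class names:

```python
from typing import Iterator


class BaseModel:
    def _stream(self, prompt: str) -> Iterator[str]:
        # Optional hook: the base class only raises, subclasses may override.
        raise NotImplementedError()

    def stream(self, prompt: str) -> Iterator[str]:
        if type(self)._stream == BaseModel._stream:
            # Subclass did not override _stream: fall back to one big chunk.
            yield self.invoke(prompt)
        else:
            yield from self._stream(prompt)

    def invoke(self, prompt: str) -> str:
        return prompt.upper()


class StreamingModel(BaseModel):
    def _stream(self, prompt: str) -> Iterator[str]:
        for token in prompt.split():
            yield token.upper()


print(list(BaseModel().stream("hello world")))       # ['HELLO WORLD']
print(list(StreamingModel().stream("hello world")))  # ['HELLO', 'WORLD']
```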
@@ -9,10 +9,7 @@ from typing import (
     Tuple,
 )
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
     ChatGeneration,
@@ -116,15 +113,6 @@ class ChatLlamaAPI(BaseChatModel):
             generations.append(gen)
         return ChatResult(generations=generations)
 
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        raise NotImplementedError
-
     @property
     def _client_params(self) -> Mapping[str, Any]:
         """Get the parameters used for the client."""

@@ -1,13 +1,14 @@
 """Base callback handler that can be used to handle callbacks in langchain."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from uuid import UUID
 
-from langchain.schema.agent import AgentAction, AgentFinish
-from langchain.schema.document import Document
-from langchain.schema.messages import BaseMessage
-from langchain.schema.output import LLMResult
+if TYPE_CHECKING:
+    from langchain.schema.agent import AgentAction, AgentFinish
+    from langchain.schema.document import Document
+    from langchain.schema.messages import BaseMessage
+    from langchain.schema.output import LLMResult
 
 
 class RetrieverManagerMixin:
@@ -543,3 +544,6 @@ class BaseCallbackManager(CallbackManagerMixin):
         for key in keys:
             self.metadata.pop(key)
             self.inheritable_metadata.pop(key)
+
+
+Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

@@ -4,6 +4,7 @@ import asyncio
 import functools
 import logging
 import os
+import uuid
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from typing import (
@@ -20,12 +21,13 @@ from typing import (
     Union,
     cast,
 )
-from uuid import UUID, uuid4
+from uuid import UUID
 
 import langchain
 from langchain.callbacks.base import (
     BaseCallbackHandler,
     BaseCallbackManager,
+    Callbacks,
     ChainManagerMixin,
     LLMManagerMixin,
     RetrieverManagerMixin,
@@ -50,7 +52,6 @@ if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
 
 logger = logging.getLogger(__name__)
-Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
 
 openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
     "openai_callback", default=None
@@ -437,7 +438,7 @@ class BaseRunManager(RunManagerMixin):
             BaseRunManager: The noop manager.
         """
         return cls(
-            run_id=uuid4(),
+            run_id=uuid.uuid4(),
             handlers=[],
             inheritable_handlers=[],
             tags=[],
@@ -1024,7 +1025,7 @@ class CallbackManager(BaseCallbackManager):
         """
         managers = []
         for prompt in prompts:
-            run_id_ = uuid4()
+            run_id_ = uuid.uuid4()
             _handle_event(
                 self.handlers,
                 "on_llm_start",
@@ -1073,7 +1074,7 @@ class CallbackManager(BaseCallbackManager):
 
         managers = []
         for message_list in messages:
-            run_id_ = uuid4()
+            run_id_ = uuid.uuid4()
             _handle_event(
                 self.handlers,
                 "on_chat_model_start",
@@ -1120,7 +1121,7 @@ class CallbackManager(BaseCallbackManager):
             CallbackManagerForChainRun: The callback manager for the chain run.
         """
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         _handle_event(
             self.handlers,
@@ -1166,7 +1167,7 @@ class CallbackManager(BaseCallbackManager):
             CallbackManagerForToolRun: The callback manager for the tool run.
         """
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         _handle_event(
             self.handlers,
@@ -1202,7 +1203,7 @@ class CallbackManager(BaseCallbackManager):
     ) -> CallbackManagerForRetrieverRun:
         """Run when retriever starts running."""
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         _handle_event(
             self.handlers,
@@ -1302,7 +1303,7 @@ class AsyncCallbackManager(BaseCallbackManager):
         managers = []
 
         for prompt in prompts:
-            run_id_ = uuid4()
+            run_id_ = uuid.uuid4()
 
             tasks.append(
                 _ahandle_event(
@@ -1341,7 +1342,7 @@ class AsyncCallbackManager(BaseCallbackManager):
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
         **kwargs: Any,
-    ) -> Any:
+    ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running.
 
         Args:
@@ -1358,7 +1359,7 @@ class AsyncCallbackManager(BaseCallbackManager):
         managers = []
 
         for message_list in messages:
-            run_id_ = uuid4()
+            run_id_ = uuid.uuid4()
 
             tasks.append(
                 _ahandle_event(
@@ -1410,7 +1411,7 @@ class AsyncCallbackManager(BaseCallbackManager):
                 for the chain run.
         """
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         await _ahandle_event(
             self.handlers,
@@ -1458,7 +1459,7 @@ class AsyncCallbackManager(BaseCallbackManager):
                 for the tool run.
         """
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         await _ahandle_event(
             self.handlers,
@@ -1494,7 +1495,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     ) -> AsyncCallbackManagerForRetrieverRun:
         """Run when retriever starts running."""
         if run_id is None:
-            run_id = uuid4()
+            run_id = uuid.uuid4()
 
         await _ahandle_event(
             self.handlers,

@@ -4,7 +4,7 @@ import asyncio
 from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
 
 from langchain.callbacks.base import AsyncCallbackHandler
-from langchain.schema import LLMResult
+from langchain.schema.output import LLMResult
 
 # TODO If used by two LLM runs in parallel this won't work as expected
 

@@ -22,6 +22,7 @@ from langchain.callbacks.manager import (
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
 from langchain.schema import RUN_KEY, BaseMemory, RunInfo
+from langchain.schema.runnable import Runnable, RunnableConfig
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +31,7 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-class Chain(Serializable, ABC):
+class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     """Abstract base class for creating structured sequences of calls to components.
 
     Chains should be used to encode a sequence of calls to components like
@@ -53,6 +54,20 @@ class Chain(Serializable, ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
+    def invoke(
+        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+    ) -> Dict[str, Any]:
+        return self(input, **(config or {}))
+
+    async def ainvoke(
+        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+    ) -> Dict[str, Any]:
+        if type(self)._acall == Chain._acall:
+            # If the chain does not implement async, fall back to default implementation
+            return await super().ainvoke(input, config)
+
+        return await self.acall(input, **(config or {}))
+
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
     Memory is a class that gets called at the start
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -12,11 +12,13 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     ChatMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.output import ChatGenerationChunk
 
 
 class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -94,6 +96,44 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             text.rstrip()
         )  # trim off the trailing ' ' that might come from the "Assistant: "
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        prompt = self._convert_messages_to_prompt(messages)
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
+        if stop:
+            params["stop_sequences"] = stop
+
+        stream_resp = self.client.completions.create(**params, stream=True)
+        for data in stream_resp:
+            delta = data.completion
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                run_manager.on_llm_new_token(delta)
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        prompt = self._convert_messages_to_prompt(messages)
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
+        if stop:
+            params["stop_sequences"] = stop
+
+        stream_resp = await self.async_client.completions.create(**params, stream=True)
+        async for data in stream_resp:
+            delta = data.completion
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                await run_manager.on_llm_new_token(delta)
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -101,22 +141,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
-        if stop:
-            params["stop_sequences"] = stop
-
         if self.streaming:
             completion = ""
-            stream_resp = self.client.completions.create(**params, stream=True)
-            for data in stream_resp:
-                delta = data.completion
-                completion += delta
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        delta,
-                    )
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
         else:
+            prompt = self._convert_messages_to_prompt(messages)
+            params: Dict[str, Any] = {
+                "prompt": prompt,
+                **self._default_params,
+                **kwargs,
+            }
+            if stop:
+                params["stop_sequences"] = stop
             response = self.client.completions.create(**params)
             completion = response.completion
         message = AIMessage(content=completion)
@@ -129,24 +166,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
-        if stop:
-            params["stop_sequences"] = stop
-
         if self.streaming:
             completion = ""
-            stream_resp = await self.async_client.completions.create(
-                **params, stream=True
-            )
-            async for data in stream_resp:
-                delta = data.completion
-                completion += delta
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        delta,
-                    )
+            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
         else:
+            prompt = self._convert_messages_to_prompt(messages)
+            params: Dict[str, Any] = {
+                "prompt": prompt,
+                **self._default_params,
+                **kwargs,
+            }
+            if stop:
+                params["stop_sequences"] = stop
             response = await self.async_client.completions.create(**params)
             completion = response.completion
         message = AIMessage(content=completion)

@@ -3,7 +3,16 @@ import inspect
 import warnings
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    cast,
+)
 
 from pydantic import Field, root_validator
 
@@ -17,6 +26,8 @@ from langchain.callbacks.manager import (
     Callbacks,
 )
 from langchain.load.dump import dumpd, dumps
+from langchain.prompts.base import StringPromptValue
+from langchain.prompts.chat import ChatPromptValue
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
@@ -24,17 +35,22 @@ from langchain.schema import (
     PromptValue,
     RunInfo,
 )
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+)
+from langchain.schema.output import ChatGenerationChunk
+from langchain.schema.runnable import RunnableConfig
 
 
 def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-class BaseChatModel(BaseLanguageModel, ABC):
-    """Base class for chat models."""
-
+class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     cache: Optional[bool] = None
     """Whether to cache the response."""
     verbose: bool = Field(default_factory=_get_verbosity)
@@ -64,6 +80,154 @@ class BaseChatModel(BaseLanguageModel, ABC):
 
         arbitrary_types_allowed = True
 
+    # --- Runnable methods ---
+
+    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
+        if isinstance(input, PromptValue):
+            return input
+        elif isinstance(input, str):
+            return StringPromptValue(text=input)
+        elif isinstance(input, list):
+            return ChatPromptValue(messages=input)
+        else:
+            raise ValueError(
+                f"Invalid input type {type(input)}. "
+                "Must be a PromptValue, str, or list of BaseMessages."
+            )
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> BaseMessageChunk:
+        return cast(
+            BaseMessageChunk,
+            cast(
+                ChatGeneration,
+                self.generate_prompt(
+                    [self._convert_input(input)], stop=stop, **(config or {})
+                ).generations[0][0],
+            ).message,
+        )
+
+    async def ainvoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> BaseMessageChunk:
+        if type(self)._agenerate == BaseChatModel._agenerate:
+            # model doesn't implement async generation, so use default implementation
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self.invoke, input, config, stop=stop)
+            )
+
+        llm_result = await self.agenerate_prompt(
+            [self._convert_input(input)], stop=stop, **(config or {})
+        )
+        return cast(
+            BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
+        )
+
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[BaseMessageChunk]:
+        if type(self)._stream == BaseChatModel._stream:
+            # model doesn't implement streaming, so use default implementation
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
+        else:
+            config = config or {}
+            messages = self._convert_input(input).to_messages()
+            params = self._get_invocation_params(stop=stop, **kwargs)
+            options = {"stop": stop, **kwargs}
+            callback_manager = CallbackManager.configure(
+                config.get("callbacks"),
+                self.callbacks,
+                self.verbose,
+                config.get("tags"),
+                self.tags,
+                config.get("metadata"),
+                self.metadata,
+            )
+            (run_manager,) = callback_manager.on_chat_model_start(
+                dumpd(self), [messages], invocation_params=params, options=options
+            )
+            try:
+                message: Optional[BaseMessageChunk] = None
+                for chunk in self._stream(
+                    messages, stop=stop, run_manager=run_manager, **kwargs
+                ):
+                    yield chunk.message
+                    if message is None:
+                        message = chunk.message
+                    else:
+                        message += chunk.message
+                assert message is not None
+            except (KeyboardInterrupt, Exception) as e:
+                run_manager.on_llm_error(e)
+                raise e
+            else:
+                run_manager.on_llm_end(
+                    LLMResult(generations=[[ChatGeneration(message=message)]]),
+                )
+
+    async def astream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[BaseMessageChunk]:
+        if type(self)._astream == BaseChatModel._astream:
+            # model doesn't implement streaming, so use default implementation
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
+        else:
+            config = config or {}
+            messages = self._convert_input(input).to_messages()
+            params = self._get_invocation_params(stop=stop, **kwargs)
+            options = {"stop": stop, **kwargs}
+            callback_manager = AsyncCallbackManager.configure(
+                config.get("callbacks"),
+                self.callbacks,
+                self.verbose,
+                config.get("tags"),
+                self.tags,
+                config.get("metadata"),
+                self.metadata,
+            )
+            (run_manager,) = await callback_manager.on_chat_model_start(
+                dumpd(self), [messages], invocation_params=params, options=options
+            )
+            try:
+                message: Optional[BaseMessageChunk] = None
+                async for chunk in self._astream(
+                    messages, stop=stop, run_manager=run_manager, **kwargs
+                ):
+                    yield chunk.message
+                    if message is None:
+                        message = chunk.message
+                    else:
+                        message += chunk.message
+                assert message is not None
+            except (KeyboardInterrupt, Exception) as e:
+                await run_manager.on_llm_error(e)
+                raise e
+            else:
+                await run_manager.on_llm_end(
+                    LLMResult(generations=[[ChatGeneration(message=message)]]),
+                )
+
+    # --- Custom methods ---
+
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return {}
 
@@ -334,7 +498,6 @@ class BaseChatModel(BaseLanguageModel, ABC):
     ) -> ChatResult:
         """Top Level call"""
 
-    @abstractmethod
     async def _agenerate(
         self,
         messages: List[BaseMessage],
@@ -343,6 +506,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
+        raise NotImplementedError()
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        raise NotImplementedError()
+
+    def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        raise NotImplementedError()
 
     def __call__(
         self,

@@ -25,7 +25,10 @@ class FakeListChatModel(SimpleChatModel):
     ) -> str:
         """First try to lookup in queries, else return 'foo' or 'bar'."""
         response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
             self.i += 1
+        else:
+            self.i = 0
         return response
 
     @property

@@ -4,8 +4,10 @@ from __future__ import annotations
 import logging
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -36,6 +38,14 @@ from langchain.schema import (
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.messages import (
+    AIMessageChunk,
+    BaseMessageChunk,
+    ChatMessageChunk,
+    HumanMessageChunk,
+    SystemMessageChunk,
+)
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 logger = logging.getLogger(__name__)
@@ -75,6 +85,24 @@ async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
     return await _completion_with_retry(**kwargs)
 
 
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = _dict.get("role")
+    content = _dict.get("content") or ""
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
@@ -258,6 +286,25 @@ class JinaChat(BaseChatModel):
                     overall_token_usage[k] = v
         return {"token_usage": overall_token_usage}
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+            delta = chunk["choices"][0]["delta"]
+            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = chunk.__class__
+            yield ChatGenerationChunk(message=chunk)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.content)
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -265,27 +312,20 @@ class JinaChat(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            for chunk in self._stream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation])
+
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            inner_completion = ""
-            role = "assistant"
-            params["stream"] = True
-            for stream_resp in self.completion_with_retry(
-                messages=message_dicts, **params
-            ):
-                role = stream_resp["choices"][0]["delta"].get("role", role)
-                token = stream_resp["choices"][0]["delta"].get("content") or ""
-                inner_completion += token
-                if run_manager:
-                    run_manager.on_llm_new_token(token)
-            message = _convert_dict_to_message(
-                {
-                    "content": inner_completion,
-                    "role": role,
-                }
-            )
-            return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
         return self._create_chat_result(response)
 
@@ -309,6 +349,27 @@ class JinaChat(BaseChatModel):
         llm_output = {"token_usage": response["usage"]}
         return ChatResult(generations=generations, llm_output=llm_output)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await acompletion_with_retry(
+            self, messages=message_dicts, **params
+        ):
+            delta = chunk["choices"][0]["delta"]
+            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = chunk.__class__
+            yield ChatGenerationChunk(message=chunk)
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.content)
+
     async def _agenerate(
         self,
         messages: List[BaseMessage],
@@ -316,31 +377,21 @@ class JinaChat(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            async for chunk in self._astream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation])
+
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            inner_completion = ""
-            role = "assistant"
-            params["stream"] = True
-            async for stream_resp in await acompletion_with_retry(
-                self, messages=message_dicts, **params
-            ):
-                role = stream_resp["choices"][0]["delta"].get("role", role)
-                token = stream_resp["choices"][0]["delta"].get("content", "")
-                inner_completion += token or ""
-                if run_manager:
-                    await run_manager.on_llm_new_token(token)
-            message = _convert_dict_to_message(
-                {
-                    "content": inner_completion,
-                    "role": role,
-                }
-            )
-            return ChatResult(generations=[ChatGeneration(message=message)])
-        else:
-            response = await acompletion_with_retry(
-                self, messages=message_dicts, **params
-            )
+        response = await acompletion_with_retry(self, messages=message_dicts, **params)
         return self._create_chat_result(response)
 
     @property

@@ -6,8 +6,10 @@ import sys
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Callable,
     Dict,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -35,12 +37,19 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
     FunctionMessage,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
 )
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 if TYPE_CHECKING:
@@ -95,6 +104,30 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
     return await _completion_with_retry(**kwargs)
 
 
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = _dict.get("role")
+    content = _dict.get("content") or ""
+    if _dict.get("function_call"):
+        additional_kwargs = {"function_call": dict(_dict["function_call"])}
+    else:
+        additional_kwargs = {}
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
@@ -313,6 +346,27 @@ class ChatOpenAI(BaseChatModel):
                     overall_token_usage[k] = v
         return {"token_usage": overall_token_usage, "model_name": self.model_name}
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+            if len(chunk["choices"]) == 0:
+                continue
+            delta = chunk["choices"][0]["delta"]
+            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = chunk.__class__
+            yield ChatGenerationChunk(message=chunk)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.content)
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -320,40 +374,20 @@ class ChatOpenAI(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            for chunk in self._stream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation])
+
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            inner_completion = ""
-            role = "assistant"
-            params["stream"] = True
-            function_call: Optional[dict] = None
-            for stream_resp in self.completion_with_retry(
-                messages=message_dicts, **params
-            ):
-                if len(stream_resp["choices"]) > 0:
-                    role = stream_resp["choices"][0]["delta"].get("role", role)
-                    token = stream_resp["choices"][0]["delta"].get("content") or ""
-                    inner_completion += token
-                    _function_call = stream_resp["choices"][0]["delta"].get(
-                        "function_call"
-                    )
-                    if _function_call:
-                        if function_call is None:
-                            function_call = _function_call
-                        elif "arguments" in function_call:
-                            function_call["arguments"] += _function_call["arguments"]
-                        else:
-                            function_call["arguments"] = _function_call["arguments"]
-                    if run_manager:
-                        run_manager.on_llm_new_token(token)
-            message = _convert_dict_to_message(
-                {
-                    "content": inner_completion,
-                    "role": role,
-                    "function_call": function_call,
-                }
-            )
-            return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
         return self._create_chat_result(response)
 
@@ -381,6 +415,29 @@ class ChatOpenAI(BaseChatModel):
         llm_output = {"token_usage": token_usage, "model_name": self.model_name}
         return ChatResult(generations=generations, llm_output=llm_output)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await acompletion_with_retry(
+            self, messages=message_dicts, **params
+        ):
+            if len(chunk["choices"]) == 0:
+                continue
+            delta = chunk["choices"][0]["delta"]
+            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = chunk.__class__
+            yield ChatGenerationChunk(message=chunk)
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.content)
+
     async def _agenerate(
         self,
         messages: List[BaseMessage],
@@ -388,44 +445,21 @@ class ChatOpenAI(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            async for chunk in self._astream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation])
+
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            inner_completion = ""
-            role = "assistant"
-            params["stream"] = True
-            function_call: Optional[dict] = None
-            async for stream_resp in await acompletion_with_retry(
-                self, messages=message_dicts, **params
-            ):
-                if len(stream_resp["choices"]) > 0:
-                    role = stream_resp["choices"][0]["delta"].get("role", role)
-                    token = stream_resp["choices"][0]["delta"].get("content", "")
-                    inner_completion += token or ""
-                    _function_call = stream_resp["choices"][0]["delta"].get(
-                        "function_call"
-                    )
-                    if _function_call:
-                        if function_call is None:
-                            function_call = _function_call
-                        elif "arguments" in function_call:
-                            function_call["arguments"] += _function_call["arguments"]
-                        else:
-                            function_call["arguments"] = _function_call["arguments"]
-                    if run_manager:
-                        await run_manager.on_llm_new_token(token)
-            message = _convert_dict_to_message(
-                {
-                    "content": inner_completion,
-                    "role": role,
-                    "function_call": function_call,
-                }
-            )
-            return ChatResult(generations=[ChatGeneration(message=message)])
-        else:
-            response = await acompletion_with_retry(
-                self, messages=message_dicts, **params
-            )
+        response = await acompletion_with_retry(self, messages=message_dicts, **params)
         return self._create_chat_result(response)
 
     @property

@@ -4,10 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from pydantic import root_validator
 
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
@@ -162,14 +159,3 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
-
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        raise NotImplementedError(
-            """Vertex AI doesn't support async requests at the moment."""
-        )

@@ -6,7 +6,8 @@ from pydantic import BaseModel
 from langchain.chains.llm import LLMChain
 from langchain.chains.openai_functions import create_tagging_chain
 from langchain.prompts import ChatPromptTemplate
-from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
+from langchain.schema import BaseDocumentTransformer, Document
+from langchain.schema.language_model import BaseLanguageModel
 
 
 class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):

@ -1,18 +1,20 @@
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
|
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.schema.output import GenerationChunk
|
||||||
from langchain.utils import check_package_version, get_from_dict_or_env
|
from langchain.utils import check_package_version, get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class _AnthropicCommon(BaseModel):
|
class _AnthropicCommon(BaseLanguageModel):
|
||||||
client: Any = None #: :meta private:
|
client: Any = None #: :meta private:
|
||||||
async_client: Any = None #: :meta private:
|
async_client: Any = None #: :meta private:
|
||||||
model: str = "claude-2"
|
model: str = "claude-2"
|
||||||
@@ -193,24 +195,16 @@ class Anthropic(LLM, _AnthropicCommon):
                 response = model(prompt)

         """
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion
+
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
-        if self.streaming:
-            stream_resp = self.client.completions.create(
-                prompt=self._wrap_prompt(prompt),
-                stop_sequences=stop,
-                stream=True,
-                **params,
-            )
-            current_completion = ""
-            for data in stream_resp:
-                delta = data.completion
-                current_completion += delta
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        delta,
-                    )
-            return current_completion
         response = self.client.completions.create(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
@@ -226,22 +220,17 @@ class Anthropic(LLM, _AnthropicCommon):
         **kwargs: Any,
     ) -> str:
         """Call out to Anthropic's completion endpoint asynchronously."""
+        if self.streaming:
+            completion = ""
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion
+
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
-        if self.streaming:
-            stream_resp = await self.async_client.completions.create(
-                prompt=self._wrap_prompt(prompt),
-                stop_sequences=stop,
-                stream=True,
-                **params,
-            )
-            current_completion = ""
-            async for data in stream_resp:
-                delta = data.completion
-                current_completion += delta
-                if run_manager:
-                    await run_manager.on_llm_new_token(delta)
-            return current_completion
         response = await self.async_client.completions.create(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
@@ -249,23 +238,23 @@ class Anthropic(LLM, _AnthropicCommon):
         )
         return response.completion

-    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
         r"""Call Anthropic completion_stream and return the resulting generator.

-        BETA: this is a beta feature while we figure out the right abstraction.
-        Once that happens, this interface could change.
-
         Args:
             prompt: The prompt to pass into the model.
             stop: Optional list of stop words to use when generating.

         Returns:
             A generator representing the stream of tokens from Anthropic.

         Example:
             .. code-block:: python

                 prompt = "Write a poem about a stream."
                 prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                 generator = anthropic.stream(prompt)
@@ -273,12 +262,49 @@ class Anthropic(LLM, _AnthropicCommon):
                     yield token
         """
         stop = self._get_anthropic_stop(stop)
-        return self.client.completions.create(
+        params = {**self._default_params, **kwargs}
+
+        for token in self.client.completions.create(
+            prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
+        ):
+            yield GenerationChunk(text=token.completion)
+            if run_manager:
+                run_manager.on_llm_new_token(token.completion)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        r"""Call Anthropic completion_stream and return the resulting generator.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+        Returns:
+            A generator representing the stream of tokens from Anthropic.
+        Example:
+            .. code-block:: python
+                prompt = "Write a poem about a stream."
+                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
+                generator = anthropic.stream(prompt)
+                for token in generator:
+                    yield token
+        """
+        stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
+
+        async for token in await self.async_client.completions.create(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
             stream=True,
-            **self._default_params,
-        )
+            **params,
+        ):
+            yield GenerationChunk(text=token.completion)
+            if run_manager:
+                await run_manager.on_llm_new_token(token.completion)

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
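The hunks above move Anthropic streaming behind the private `_stream`/`_astream` generators, which yield `GenerationChunk` objects and report each token to the run manager. A minimal sketch of the same pattern for a custom LLM follows; `EchoLLM` and its word-by-word "tokenizer" are illustrative only and are not part of this commit:

```python
from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class EchoLLM(LLM):
    """Toy LLM that 'streams' the prompt back word by word."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Non-streaming path: concatenate the streamed chunks.
        return "".join(
            chunk.text for chunk in self._stream(prompt, stop, run_manager, **kwargs)
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for word in prompt.split():
            chunk = GenerationChunk(text=word + " ")
            yield chunk
            if run_manager:
                # Same callback hook the Anthropic hunk uses per token.
                run_manager.on_llm_new_token(chunk.text)


print(EchoLLM().invoke("hello runnable world"))  # invoke() is added by this PR
```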
@@ -7,11 +7,14 @@ import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -19,6 +22,7 @@ from typing import (
     Tuple,
     Type,
     Union,
+    cast,
 )

 import yaml
@@ -42,14 +46,18 @@ from langchain.callbacks.manager import (
     Callbacks,
 )
 from langchain.load.dump import dumpd
+from langchain.prompts.base import StringPromptValue
+from langchain.prompts.chat import ChatPromptValue
 from langchain.schema import (
     Generation,
     LLMResult,
     PromptValue,
     RunInfo,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain.schema.output import GenerationChunk
+from langchain.schema.runnable import RunnableConfig

 logger = logging.getLogger(__name__)

@@ -115,7 +123,7 @@ def update_cache(
     return llm_output


-class BaseLLM(BaseLanguageModel, ABC):
+class BaseLLM(BaseLanguageModel[str], ABC):
     """Base LLM abstract interface.

     It should take in a prompt and return a string."""
@@ -157,6 +165,204 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose

+    # --- Runnable methods ---
+
+    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
+        if isinstance(input, PromptValue):
+            return input
+        elif isinstance(input, str):
+            return StringPromptValue(text=input)
+        elif isinstance(input, list):
+            return ChatPromptValue(messages=input)
+        else:
+            raise ValueError(
+                f"Invalid input type {type(input)}. "
+                "Must be a PromptValue, str, or list of BaseMessages."
+            )
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> str:
+        return (
+            self.generate_prompt(
+                [self._convert_input(input)], stop=stop, **(config or {})
+            )
+            .generations[0][0]
+            .text
+        )
+
+    async def ainvoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> str:
+        if type(self)._agenerate == BaseLLM._agenerate:
+            # model doesn't implement async invoke, so use default implementation
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self.invoke, input, config, stop=stop)
+            )
+
+        llm_result = await self.agenerate_prompt(
+            [self._convert_input(input)], stop=stop, **(config or {})
+        )
+        return llm_result.generations[0][0].text
+
+    def batch(
+        self,
+        inputs: List[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        max_concurrency: Optional[int] = None,
+    ) -> List[str]:
+        config = self._get_config_list(config, len(inputs))
+
+        if max_concurrency is None:
+            llm_result = self.generate_prompt(
+                [self._convert_input(input) for input in inputs],
+                callbacks=[c.get("callbacks") for c in config],
+                tags=[c.get("tags") for c in config],
+                metadata=[c.get("metadata") for c in config],
+            )
+            return [g[0].text for g in llm_result.generations]
+        else:
+            batches = [
+                inputs[i : i + max_concurrency]
+                for i in range(0, len(inputs), max_concurrency)
+            ]
+            return [
+                output
+                for batch in batches
+                for output in self.batch(batch, config=config)
+            ]
+
+    async def abatch(
+        self,
+        inputs: List[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        max_concurrency: Optional[int] = None,
+    ) -> List[str]:
+        if type(self)._agenerate == BaseLLM._agenerate:
+            # model doesn't implement async batch, so use default implementation
+            return await asyncio.get_running_loop().run_in_executor(
+                None, self.batch, inputs, config, max_concurrency
+            )
+
+        config = self._get_config_list(config, len(inputs))
+
+        if max_concurrency is None:
+            llm_result = await self.agenerate_prompt(
+                [self._convert_input(input) for input in inputs],
+                callbacks=[c.get("callbacks") for c in config],
+                tags=[c.get("tags") for c in config],
+                metadata=[c.get("metadata") for c in config],
+            )
+            return [g[0].text for g in llm_result.generations]
+        else:
+            batches = [
+                inputs[i : i + max_concurrency]
+                for i in range(0, len(inputs), max_concurrency)
+            ]
+            return [
+                output
+                for batch in batches
+                for output in await self.abatch(batch, config=config)
+            ]
+
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> Iterator[str]:
+        if type(self)._stream == BaseLLM._stream:
+            # model doesn't implement streaming, so use default implementation
+            yield self.invoke(input, config=config, stop=stop)
+        else:
+            prompt = self._convert_input(input).to_string()
+            config = config or {}
+            params = self.dict()
+            params["stop"] = stop
+            options = {"stop": stop}
+            callback_manager = CallbackManager.configure(
+                config.get("callbacks"),
+                self.callbacks,
+                self.verbose,
+                config.get("tags"),
+                self.tags,
+                config.get("metadata"),
+                self.metadata,
+            )
+            (run_manager,) = callback_manager.on_llm_start(
+                dumpd(self), [prompt], invocation_params=params, options=options
+            )
+            try:
+                generation: Optional[GenerationChunk] = None
+                for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
+                    yield chunk.text
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+            except (KeyboardInterrupt, Exception) as e:
+                run_manager.on_llm_error(e)
+                raise e
+            else:
+                run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+
+    async def astream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncIterator[str]:
+        if type(self)._astream == BaseLLM._astream:
+            # model doesn't implement streaming, so use default implementation
+            yield await self.ainvoke(input, config=config, stop=stop)
+        else:
+            prompt = self._convert_input(input).to_string()
+            config = config or {}
+            params = self.dict()
+            params["stop"] = stop
+            options = {"stop": stop}
+            callback_manager = AsyncCallbackManager.configure(
+                config.get("callbacks"),
+                self.callbacks,
+                self.verbose,
+                config.get("tags"),
+                self.tags,
+                config.get("metadata"),
+                self.metadata,
+            )
+            (run_manager,) = await callback_manager.on_llm_start(
+                dumpd(self), [prompt], invocation_params=params, options=options
+            )
+            try:
+                generation: Optional[GenerationChunk] = None
+                async for chunk in self._astream(
+                    prompt, stop=stop, run_manager=run_manager
+                ):
+                    yield chunk.text
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+            except (KeyboardInterrupt, Exception) as e:
+                await run_manager.on_llm_error(e)
+                raise e
+            else:
+                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+
+    # --- Custom methods ---
+
     @abstractmethod
     def _generate(
         self,
@@ -167,7 +373,6 @@ class BaseLLM(BaseLanguageModel, ABC):
     ) -> LLMResult:
         """Run the LLM on the given prompts."""

-    @abstractmethod
     async def _agenerate(
         self,
         prompts: List[str],
@@ -176,12 +381,31 @@ class BaseLLM(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
+        raise NotImplementedError()
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        raise NotImplementedError()
+
+    def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        raise NotImplementedError()

     def generate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
-        callbacks: Callbacks = None,
+        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -191,7 +415,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
-        callbacks: Callbacks = None,
+        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -236,10 +460,10 @@ class BaseLLM(BaseLanguageModel, ABC):
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
-        callbacks: Callbacks = None,
+        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         *,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[Union[List[str], List[List[str]]]] = None,
+        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -248,6 +472,50 @@ class BaseLLM(BaseLanguageModel, ABC):
                 "Argument 'prompts' is expected to be of type List[str], received"
                 f" argument of type {type(prompts)}."
             )
+        # Create callback managers
+        if isinstance(callbacks, list) and (
+            isinstance(callbacks[0], (list, BaseCallbackManager))
+            or callbacks[0] is None
+        ):
+            # We've received a list of callbacks args to apply to each input
+            assert len(callbacks) == len(prompts)
+            assert tags is None or (
+                isinstance(tags, list) and len(tags) == len(prompts)
+            )
+            assert metadata is None or (
+                isinstance(metadata, list) and len(metadata) == len(prompts)
+            )
+            callbacks = cast(List[Callbacks], callbacks)
+            tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
+            metadata_list = cast(
+                List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
+            )
+            callback_managers = [
+                CallbackManager.configure(
+                    callback,
+                    self.callbacks,
+                    self.verbose,
+                    tag,
+                    self.tags,
+                    meta,
+                    self.metadata,
+                )
+                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
+            ]
+        else:
+            # We've received a single callbacks arg to apply to all inputs
+            callback_managers = [
+                CallbackManager.configure(
+                    cast(Callbacks, callbacks),
+                    self.callbacks,
+                    self.verbose,
+                    cast(List[str], tags),
+                    self.tags,
+                    cast(Dict[str, Any], metadata),
+                    self.metadata,
+                )
+            ] * len(prompts)
+
         params = self.dict()
         params["stop"] = stop
         options = {"stop": stop}
@@ -258,15 +526,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             missing_prompts,
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
-        callback_manager = CallbackManager.configure(
-            callbacks,
-            self.callbacks,
-            self.verbose,
-            tags,
-            self.tags,
-            metadata,
-            self.metadata,
-        )
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
         )
@@ -275,17 +534,26 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_managers = callback_manager.on_llm_start(
-                dumpd(self), prompts, invocation_params=params, options=options
-            )
+            run_managers = [
+                callback_manager.on_llm_start(
+                    dumpd(self), [prompt], invocation_params=params, options=options
+                )[0]
+                for callback_manager, prompt in zip(callback_managers, prompts)
+            ]
             output = self._generate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
             return output
         if len(missing_prompts) > 0:
-            run_managers = callback_manager.on_llm_start(
-                dumpd(self), missing_prompts, invocation_params=params, options=options
-            )
+            run_managers = [
+                callback_managers[idx].on_llm_start(
+                    dumpd(self),
+                    [prompts[idx]],
+                    invocation_params=params,
+                    options=options,
+                )[0]
+                for idx in missing_prompt_idxs
+            ]
             new_results = self._generate_helper(
                 missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
@@ -346,13 +614,57 @@ class BaseLLM(BaseLanguageModel, ABC):
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
-        callbacks: Callbacks = None,
+        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         *,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[Union[List[str], List[List[str]]]] = None,
+        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
+        # Create callback managers
+        if isinstance(callbacks, list) and (
+            isinstance(callbacks[0], (list, BaseCallbackManager))
+            or callbacks[0] is None
+        ):
+            # We've received a list of callbacks args to apply to each input
+            assert len(callbacks) == len(prompts)
+            assert tags is None or (
+                isinstance(tags, list) and len(tags) == len(prompts)
+            )
+            assert metadata is None or (
+                isinstance(metadata, list) and len(metadata) == len(prompts)
+            )
+            callbacks = cast(List[Callbacks], callbacks)
+            tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
+            metadata_list = cast(
+                List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
+            )
+            callback_managers = [
+                AsyncCallbackManager.configure(
+                    callback,
+                    self.callbacks,
+                    self.verbose,
+                    tag,
+                    self.tags,
+                    meta,
+                    self.metadata,
+                )
+                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
+            ]
+        else:
+            # We've received a single callbacks arg to apply to all inputs
+            callback_managers = [
+                AsyncCallbackManager.configure(
+                    cast(Callbacks, callbacks),
+                    self.callbacks,
+                    self.verbose,
+                    cast(List[str], tags),
+                    self.tags,
+                    cast(Dict[str, Any], metadata),
+                    self.metadata,
+                )
+            ] * len(prompts)
+
         params = self.dict()
         params["stop"] = stop
         options = {"stop": stop}
@@ -363,15 +675,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             missing_prompts,
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
-        callback_manager = AsyncCallbackManager.configure(
-            callbacks,
-            self.callbacks,
-            self.verbose,
-            tags,
-            self.tags,
-            metadata,
-            self.metadata,
-        )
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
             "run_manager"
         )
@@ -380,17 +683,32 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_managers = await callback_manager.on_llm_start(
-                dumpd(self), prompts, invocation_params=params, options=options
-            )
+            run_managers = await asyncio.gather(
+                *[
+                    callback_manager.on_llm_start(
+                        dumpd(self), [prompt], invocation_params=params, options=options
+                    )
+                    for callback_manager, prompt in zip(callback_managers, prompts)
+                ]
+            )
+            run_managers = [r[0] for r in run_managers]
             output = await self._agenerate_helper(
                 prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
             return output
         if len(missing_prompts) > 0:
-            run_managers = await callback_manager.on_llm_start(
-                dumpd(self), missing_prompts, invocation_params=params, options=options
-            )
+            run_managers = await asyncio.gather(
+                *[
+                    callback_managers[idx].on_llm_start(
+                        dumpd(self),
+                        [prompts[idx]],
+                        invocation_params=params,
+                        options=options,
+                    )
+                    for idx in missing_prompt_idxs
+                ]
+            )
+            run_managers = [r[0] for r in run_managers]
             new_results = await self._agenerate_helper(
                 missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
@@ -586,7 +904,7 @@ class LLM(BaseLLM):
         **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
-        raise NotImplementedError("Async generation not implemented for this LLM.")
+        raise NotImplementedError()

     def _generate(
         self,
@@ -615,6 +933,12 @@ class LLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        if type(self)._acall == LLM._acall:
+            # model doesn't implement async call, so use default implementation
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self._generate, prompts, stop, run_manager, **kwargs)
+            )
+
         """Run the LLM on the given prompt and input."""
         generations = []
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
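The `BaseLLM` hunks above add the Runnable entry points (`invoke`, `ainvoke`, `batch`, `abatch`, `stream`, `astream`) on top of the existing `generate_prompt` machinery. A hedged sketch of how calling code can exercise them; the model name is a placeholder and running it requires an OpenAI API key in the environment:

```python
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-003")  # placeholder model name

# Single input -> str
print(llm.invoke("Say hello"))

# Several inputs in one call -> List[str]
print(llm.batch(["Say hello", "Say goodbye"]))

# Token-by-token streaming -> Iterator[str], backed by the new _stream hook
for token in llm.stream("Tell me a joke"):
    print(token, end="", flush=True)
```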
@@ -27,7 +27,10 @@ class FakeListLLM(LLM):
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
             self.i += 1
+        else:
+            self.i = 0
         return response

     async def _acall(
@@ -39,7 +42,10 @@ class FakeListLLM(LLM):
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
             self.i += 1
+        else:
+            self.i = 0
         return response

     @property
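The fake LLM used in the tests now wraps its response index instead of running off the end of the list. A quick sketch of the resulting behaviour; the import path is an assumption based on the test fixture this hunk edits:

```python
from tests.unit_tests.llms.fake_llm import FakeListLLM  # assumed fixture location

llm = FakeListLLM(responses=["a", "b"])
assert llm("x") == "a"
assert llm("x") == "b"
assert llm("x") == "a"  # index wraps back to the first response
```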
@@ -12,10 +12,7 @@ from tenacity import (
     wait_exponential,
 )

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -161,15 +158,6 @@ class GooglePalm(BaseLLM, BaseModel):

         return LLMResult(generations=generations)

-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        raise NotImplementedError()
-
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -1,5 +1,4 @@
-from functools import partial
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

 from pydantic import Extra, Field, root_validator

@@ -8,6 +7,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
+from langchain.schema.output import GenerationChunk


 class HuggingFaceTextGenInference(LLM):
@@ -69,7 +69,7 @@ class HuggingFaceTextGenInference(LLM):
             temperature = 0.01,
             repetition_penalty = 1.03,
             callbacks = callbacks,
-            stream = True
+            streaming = True
         )
         print(llm("What is Deep Learning?"))

@@ -87,7 +87,7 @@ class HuggingFaceTextGenInference(LLM):
     inference_server_url: str = ""
     timeout: int = 120
     server_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    stream: bool = False
+    streaming: bool = False
     client: Any
     async_client: Any

@@ -154,8 +154,13 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+
         invocation_params = self._invocation_params(stop, **kwargs)
-        if not self.stream:
         res = self.client.generate(prompt, **invocation_params)
         # remove stop sequences from the end of the generated text
         for stop_seq in invocation_params["stop_sequences"]:
@@ -163,28 +168,7 @@ class HuggingFaceTextGenInference(LLM):
                 res.generated_text = res.generated_text[
                     : res.generated_text.index(stop_seq)
                 ]
-            text = res.generated_text
-        else:
-            text_callback = None
-            if run_manager:
-                text_callback = partial(
-                    run_manager.on_llm_new_token, verbose=self.verbose
-                )
-            text = ""
-            for res in self.client.generate_stream(prompt, **invocation_params):
-                token = res.token
-                is_stop = False
-                for stop_seq in invocation_params["stop_sequences"]:
-                    if stop_seq in token.text:
-                        is_stop = True
-                        break
-                if is_stop:
-                    break
-                if not token.special:
-                    if text_callback:
-                        text_callback(token.text)
-                text += token.text
-        return text
+        return res.generated_text

     async def _acall(
         self,
@@ -193,39 +177,90 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
+        if self.streaming:
+            completion = ""
+            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+
         invocation_params = self._invocation_params(stop, **kwargs)
-        if not self.stream:
-            res = await self.async_client.generate(
-                prompt,
-                **invocation_params,
-            )
-            # remove stop sequences from the end of the generated text
-            for stop_seq in invocation_params["stop_sequences"]:
-                if stop_seq in res.generated_text:
-                    res.generated_text = res.generated_text[
-                        : res.generated_text.index(stop_seq)
-                    ]
-            text: str = res.generated_text
-        else:
-            text_callback = None
-            if run_manager:
-                text_callback = partial(
-                    run_manager.on_llm_new_token, verbose=self.verbose
-                )
-            text = ""
-            async for res in self.async_client.generate_stream(
-                prompt, **invocation_params
-            ):
-                token = res.token
-                is_stop = False
-                for stop_seq in invocation_params["stop_sequences"]:
-                    if stop_seq in token.text:
-                        is_stop = True
-                        break
-                if is_stop:
-                    break
-                if not token.special:
-                    if text_callback:
-                        await text_callback(token.text)
-                text += token.text
-        return text
+        res = await self.async_client.generate(prompt, **invocation_params)
+        # remove stop sequences from the end of the generated text
+        for stop_seq in invocation_params["stop_sequences"]:
+            if stop_seq in res.generated_text:
+                res.generated_text = res.generated_text[
+                    : res.generated_text.index(stop_seq)
+                ]
+        return res.generated_text
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        invocation_params = self._invocation_params(stop, **kwargs)
+
+        for res in self.client.generate_stream(prompt, **invocation_params):
+            # identify stop sequence in generated text, if any
+            stop_seq_found: Optional[str] = None
+            for stop_seq in invocation_params["stop_sequences"]:
+                if stop_seq in res.token.text:
+                    stop_seq_found = stop_seq
+
+            # identify text to yield
+            text: Optional[str] = None
+            if res.token.special:
+                text = None
+            elif stop_seq_found:
+                text = res.token.text[: res.token.text.index(stop_seq_found)]
+            else:
+                text = res.token.text
+
+            # yield text, if any
+            if text:
+                chunk = GenerationChunk(text=text)
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(chunk.text)
+
+            # break if stop sequence found
+            if stop_seq_found:
+                break
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        invocation_params = self._invocation_params(stop, **kwargs)
+
+        async for res in self.async_client.generate_stream(prompt, **invocation_params):
+            # identify stop sequence in generated text, if any
+            stop_seq_found: Optional[str] = None
+            for stop_seq in invocation_params["stop_sequences"]:
+                if stop_seq in res.token.text:
+                    stop_seq_found = stop_seq
+
+            # identify text to yield
+            text: Optional[str] = None
+            if res.token.special:
+                text = None
+            elif stop_seq_found:
+                text = res.token.text[: res.token.text.index(stop_seq_found)]
+            else:
+                text = res.token.text
+
+            # yield text, if any
+            if text:
+                chunk = GenerationChunk(text=text)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(chunk.text)
+
+            # break if stop sequence found
+            if stop_seq_found:
+                break
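The new text-generation-inference `_stream` trims the stop sequence out of the final token and skips special tokens before yielding. A self-contained sketch of just that trimming rule, with made-up token text and stop sequences (no TGI client required):

```python
from typing import List, Optional


def trim_token(token_text: str, stop_sequences: List[str]) -> Optional[str]:
    """Return the part of a streamed token to emit, or None to skip it."""
    stop_seq_found: Optional[str] = None
    for stop_seq in stop_sequences:
        if stop_seq in token_text:
            stop_seq_found = stop_seq
    if stop_seq_found:
        # Keep only the text before the stop sequence; empty means emit nothing.
        return token_text[: token_text.index(stop_seq_found)] or None
    return token_text


assert trim_token("world</s>", ["</s>"]) == "world"
assert trim_token("</s>", ["</s>"]) is None
assert trim_token("plain", ["</s>"]) == "plain"
```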
@@ -1,10 +1,11 @@
 import logging
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional

 from pydantic import Field, root_validator

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
+from langchain.schema.output import GenerationChunk

 logger = logging.getLogger(__name__)

@@ -226,8 +227,10 @@ class LlamaCpp(LLM):
             # method that yields as they are generated
             # and return the combined strings from the first choices's text:
             combined_text_output = ""
-            for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
-                combined_text_output += token["choices"][0]["text"]
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
             return combined_text_output
         else:
             params = self._get_parameters(stop)
@@ -235,17 +238,15 @@ class LlamaCpp(LLM):
             result = self.client(prompt=prompt, **params)
             return result["choices"][0]["text"]

-    def stream(
+    def _stream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> Generator[Dict, None, None]:
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
         """Yields results objects as they are generated in real time.

-        BETA: this is a beta feature while we figure out the right abstraction.
-        Once that happens, this interface could change.
-
         It also calls the callback manager's on_llm_new_token event with
         similar parameters to the OpenAI LLM class method of the same name.

@@ -274,16 +275,19 @@ class LlamaCpp(LLM):
                     print(result["text"], end='', flush=True)

         """
-        params = self._get_parameters(stop)
+        params = {**self._get_parameters(stop), **kwargs}
         result = self.client(prompt=prompt, stream=True, **params)
-        for chunk in result:
-            token = chunk["choices"][0]["text"]
-            log_probs = chunk["choices"][0].get("logprobs", None)
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    token=token, verbose=self.verbose, log_probs=log_probs
-                )
-            yield chunk
+        for part in result:
+            logprobs = part["choices"][0].get("logprobs", None)
+            chunk = GenerationChunk(
+                text=part["choices"][0]["text"],
+                generation_info={"logprobs": logprobs},
+            )
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=chunk.text, verbose=self.verbose, log_probs=logprobs
+                )

     def get_num_tokens(self, text: str) -> int:
         tokenized_text = self.client.tokenize(text.encode("utf-8"))
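The LlamaCpp hunk wraps each raw llama-cpp-python stream element in a `GenerationChunk`, carrying logprobs in `generation_info`, before notifying the callback manager. A minimal sketch of that wrapping step; the `part` dict below is made up to mirror the shape the client returns:

```python
from langchain.schema.output import GenerationChunk

# Shape of one llama-cpp-python stream element (values invented for illustration).
part = {"choices": [{"text": "Hello", "logprobs": None}]}

logprobs = part["choices"][0].get("logprobs", None)
chunk = GenerationChunk(
    text=part["choices"][0]["text"],
    generation_info={"logprobs": logprobs},
)
print(chunk.text)  # -> "Hello"
```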
@@ -6,10 +6,11 @@ import warnings
 from typing import (
     AbstractSet,
     Any,
+    AsyncIterator,
     Callable,
     Collection,
     Dict,
-    Generator,
+    Iterator,
     List,
     Literal,
     Mapping,
@@ -27,6 +28,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

 logger = logging.getLogger(__name__)
@@ -44,6 +46,19 @@ def update_token_usage(
             token_usage[_key] += response["usage"][_key]


+def _stream_response_to_generation_chunk(
+    stream_response: Dict[str, Any],
+) -> GenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    return GenerationChunk(
+        text=stream_response["choices"][0]["text"],
+        generation_info=dict(
+            finish_reason=stream_response["choices"][0].get("finish_reason", None),
+            logprobs=stream_response["choices"][0].get("logprobs", None),
+        ),
+    )
+
+
 def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
     """Update response from the stream response."""
     response["choices"][0]["text"] += stream_response["choices"][0]["text"]
@@ -268,6 +283,50 @@ class BaseOpenAI(BaseLLM):

         return {**normal_params, **self.model_kwargs}

+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        params = {**self._invocation_params, **kwargs, "stream": True}
+        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
+        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
+            chunk = _stream_response_to_generation_chunk(stream_resp)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=self.verbose,
+                    logprobs=chunk.generation_info["logprobs"]
+                    if chunk.generation_info
+                    else None,
+                )
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        params = {**self._invocation_params, **kwargs, "stream": True}
+        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
+        async for stream_resp in await acompletion_with_retry(
+            self, prompt=prompt, **params
+        ):
+            chunk = _stream_response_to_generation_chunk(stream_resp)
+            yield chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=self.verbose,
+                    logprobs=chunk.generation_info["logprobs"]
+                    if chunk.generation_info
+                    else None,
+                )
+
     def _generate(
         self,
         prompts: List[str],
@@ -302,24 +361,28 @@ class BaseOpenAI(BaseLLM):
             if self.streaming:
                 if len(_prompts) > 1:
                     raise ValueError("Cannot stream results with multiple prompts.")
-                params["stream"] = True
-                response = _streaming_response_template()
-                for stream_resp in completion_with_retry(
-                    self, prompt=_prompts, **params
-                ):
-                    if run_manager:
-                        run_manager.on_llm_new_token(
-                            stream_resp["choices"][0]["text"],
-                            verbose=self.verbose,
-                            logprobs=stream_resp["choices"][0]["logprobs"],
-                        )
-                    _update_response(response, stream_resp)
-                choices.extend(response["choices"])
+                generation: Optional[GenerationChunk] = None
+                for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+                choices.append(
+                    {
+                        "text": generation.text,
+                        "finish_reason": generation.generation_info.get("finish_reason")
+                        if generation.generation_info
+                        else None,
+                        "logprobs": generation.generation_info.get("logprobs")
+                        if generation.generation_info
+                        else None,
+                    }
+                )
             else:
                 response = completion_with_retry(self, prompt=_prompts, **params)
                 choices.extend(response["choices"])
-            if not self.streaming:
-                # Can't update token usage if streaming
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)

@@ -343,24 +406,30 @@ class BaseOpenAI(BaseLLM):
             if self.streaming:
                 if len(_prompts) > 1:
                     raise ValueError("Cannot stream results with multiple prompts.")
-                params["stream"] = True
-                response = _streaming_response_template()
-                async for stream_resp in await acompletion_with_retry(
-                    self, prompt=_prompts, **params
+                generation: Optional[GenerationChunk] = None
+                async for chunk in self._astream(
+                    _prompts[0], stop, run_manager, **kwargs
                 ):
-                    if run_manager:
-                        await run_manager.on_llm_new_token(
-                            stream_resp["choices"][0]["text"],
-                            verbose=self.verbose,
-                            logprobs=stream_resp["choices"][0]["logprobs"],
-                        )
-                    _update_response(response, stream_resp)
-                choices.extend(response["choices"])
+                    if generation is None:
+                        generation = chunk
+                    else:
+                        generation += chunk
+                assert generation is not None
+                choices.append(
+                    {
+                        "text": generation.text,
+                        "finish_reason": generation.generation_info.get("finish_reason")
+                        if generation.generation_info
+                        else None,
+                        "logprobs": generation.generation_info.get("logprobs")
+                        if generation.generation_info
+                        else None,
+                    }
+                )
             else:
                 response = await acompletion_with_retry(self, prompt=_prompts, **params)
                 choices.extend(response["choices"])
-            if not self.streaming:
-                # Can't update token usage if streaming
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)

@@ -409,43 +478,6 @@ class BaseOpenAI(BaseLLM):
         llm_output = {"token_usage": token_usage, "model_name": self.model_name}
         return LLMResult(generations=generations, llm_output=llm_output)

-    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
-        """Call OpenAI with streaming flag and return the resulting generator.
-
-        BETA: this is a beta feature while we figure out the right abstraction.
-        Once that happens, this interface could change.
-
-        Args:
-            prompt: The prompts to pass into the model.
-            stop: Optional list of stop words to use when generating.
-
-        Returns:
-            A generator representing the stream of tokens from OpenAI.
-
-        Example:
-            .. code-block:: python
-
-                generator = openai.stream("Tell me a joke.")
-                for token in generator:
-                    yield token
-        """
-        params = self.prep_streaming_params(stop)
-        generator = self.client.create(prompt=prompt, **params)
-
-        return generator
-
-    def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
-        """Prepare the params for streaming."""
-        params = self._invocation_params
-        if "best_of" in params and params["best_of"] != 1:
-            raise ValueError("OpenAI only supports best_of == 1 for streaming")
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        params["stream"] = True
-        return params
-
     @property
     def _invocation_params(self) -> Dict[str, Any]:
         """Get the parameters used to invoke the model."""
@@ -777,6 +809,38 @@ class OpenAIChat(BaseLLM):
             del params["max_tokens"]
         return messages, params

+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        messages, params = self._get_chat_params([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        for stream_resp in completion_with_retry(self, messages=messages, **params):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            yield GenerationChunk(text=token)
+            if run_manager:
+                run_manager.on_llm_new_token(token)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        messages, params = self._get_chat_params([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        async for stream_resp in await acompletion_with_retry(
+            self, messages=messages, **params
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            yield GenerationChunk(text=token)
+            if run_manager:
+                await run_manager.on_llm_new_token(token)
+
     def _generate(
         self,
         prompts: List[str],
@@ -784,22 +848,18 @@ class OpenAIChat(BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        if self.streaming:
+            generation: Optional[GenerationChunk] = None
+            for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return LLMResult(generations=[[generation]])
+
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            response = ""
-            params["stream"] = True
-            for stream_resp in completion_with_retry(self, messages=messages, **params):
-                token = stream_resp["choices"][0]["delta"].get("content", "")
-                response += token
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        token,
-                    )
-            return LLMResult(
-                generations=[[Generation(text=response)]],
-            )
-        else:
-            full_response = completion_with_retry(self, messages=messages, **params)
+        full_response = completion_with_retry(self, messages=messages, **params)
         llm_output = {
             "token_usage": full_response["usage"],
@@ -819,27 +879,19 @@ class OpenAIChat(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        if self.streaming:
+            generation: Optional[GenerationChunk] = None
+            async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return LLMResult(generations=[[generation]])
+
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        if self.streaming:
-            response = ""
-            params["stream"] = True
-            async for stream_resp in await acompletion_with_retry(
-                self, messages=messages, **params
-            ):
-                token = stream_resp["choices"][0]["delta"].get("content", "")
-                response += token
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        token,
-                    )
-            return LLMResult(
-                generations=[[Generation(text=response)]],
-            )
-        else:
-            full_response = await acompletion_with_retry(
-                self, messages=messages, **params
-            )
+        full_response = await acompletion_with_retry(self, messages=messages, **params)
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
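The OpenAI hunks rebuild the final choice from streamed `GenerationChunk`s by adding them together rather than mutating a response template. A small sketch of that accumulation step with hand-written chunk values (not real API output):

```python
from typing import Optional

from langchain.schema.output import GenerationChunk

chunks = [
    GenerationChunk(text="Hello", generation_info={"finish_reason": None, "logprobs": None}),
    GenerationChunk(text=" world", generation_info={"finish_reason": "stop", "logprobs": None}),
]

generation: Optional[GenerationChunk] = None
for chunk in chunks:
    # GenerationChunk.__add__ concatenates text and merges generation_info.
    generation = chunk if generation is None else generation + chunk

assert generation is not None
assert generation.text == "Hello world"
assert generation.generation_info["finish_reason"] == "stop"
```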
@ -13,10 +13,7 @@ from tenacity import (
|
|||||||
wait_exponential,
|
wait_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
|
||||||
)
|
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
@ -250,12 +247,3 @@ class Tongyi(LLM):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
return LLMResult(generations=generations)
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
prompts: List[str],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> LLMResult:
|
|
||||||
raise NotImplementedError()
|
|
||||||
@@ -6,7 +6,7 @@ from typing import List
 from langchain.schema import BaseOutputParser


-class ListOutputParser(BaseOutputParser):
+class ListOutputParser(BaseOutputParser[List[str]]):
     """Parse the output of an LLM call to a list."""

     @property
@@ -4,14 +4,14 @@ from typing import Any, Dict, List, Type, Union
 from pydantic import BaseModel, root_validator

 from langchain.schema import (
-    BaseLLMOutputParser,
     ChatGeneration,
     Generation,
     OutputParserException,
 )
+from langchain.schema.output_parser import BaseGenerationOutputParser


-class OutputFunctionsParser(BaseLLMOutputParser[Any]):
+class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
     """Parse an output that is one of sets of values."""

     args_only: bool = True
@@ -5,10 +5,10 @@ import warnings
 from abc import ABC
 from typing import Any, Callable, Dict, List, Set

-from langchain.schema import BasePromptTemplate
+from langchain.formatting import formatter
 from langchain.schema.messages import BaseMessage, HumanMessage
 from langchain.schema.prompt import PromptValue
-from langchain.utils import formatter
+from langchain.schema.prompt_template import BasePromptTemplate


 def jinja2_formatter(template: str, **kwargs: Any) -> str:
@@ -446,7 +446,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
         for message in messages:
             if isinstance(message, BaseMessagePromptTemplate):
                 input_vars.update(message.input_variables)
-        return cls(input_variables=list(input_vars), messages=messages)
+        return cls(input_variables=sorted(input_vars), messages=messages)

     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.
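Switching from `list(input_vars)` to `sorted(input_vars)` makes the inferred `input_variables` order deterministic, since iteration order over a set is not guaranteed. A tiny illustration:

# Why sorted() is used: set iteration order can vary, sorted() pins it down.
input_vars = {"question", "context", "chat_history"}

print(list(input_vars))    # order may differ between runs or interpreters
print(sorted(input_vars))  # always ['chat_history', 'context', 'question']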
@@ -1,9 +1,6 @@
 from typing import List

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document
 from langchain.utilities.arxiv import ArxivAPIWrapper

@@ -20,8 +17,3 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         return self.load(query=query)
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -7,10 +7,7 @@ from __future__ import annotations

 from typing import Any, Callable, Dict, Iterable, List, Optional

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document


@@ -108,8 +105,3 @@ class BM25Retriever(BaseRetriever):
         processed_query = self.preprocess_func(query)
         return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
         return return_docs
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional, Union

 import numpy as np

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.utils import maximal_marginal_relevance
@@ -208,11 +205,3 @@ class DocArrayRetriever(BaseRetriever):
                 lc_doc.metadata[name] = value

         return lc_doc
-
-    async def _aget_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: AsyncCallbackManagerForRetrieverRun,
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -5,10 +5,7 @@ from __future__ import annotations
 import uuid
 from typing import Any, Iterable, List

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.docstore.document import Document
 from langchain.schema import BaseRetriever

@@ -138,8 +135,3 @@ class ElasticSearchBM25Retriever(BaseRetriever):
         for r in res["hits"]["hits"]:
             docs.append(Document(page_content=r["_source"]["content"]))
         return docs
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

 from pydantic import Extra, Field, root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document
 from langchain.utils import get_from_dict_or_env

@@ -184,8 +181,3 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
         documents = self._convert_search_response(response.results)

         return documents
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -4,10 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Extra, root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.docstore.document import Document
 from langchain.schema import BaseRetriever

@@ -411,11 +408,3 @@ class AmazonKendraRetriever(BaseRetriever):
         """
         docs = self._kendra_query(query, self.top_k, self.attribute_filter)
         return docs
-
-    async def _aget_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: AsyncCallbackManagerForRetrieverRun,
-    ) -> List[Document]:
-        raise NotImplementedError("Async version is not implemented for Kendra yet.")
@@ -9,10 +9,7 @@ from typing import Any, List, Optional

 import numpy as np

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document

@@ -82,8 +79,3 @@ class KNNRetriever(BaseRetriever):
             )
         ]
         return top_k_results
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -2,10 +2,7 @@ from typing import Any, Dict, List, cast

 from pydantic import Field

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document


@@ -42,11 +39,6 @@ class LlamaIndexRetriever(BaseRetriever):
             )
         return docs

-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError("LlamaIndexRetriever does not support async")
-

 class LlamaIndexGraphRetriever(BaseRetriever):
     """Retriever for question-answering with sources over an LlamaIndex
@@ -88,8 +80,3 @@ class LlamaIndexGraphRetriever(BaseRetriever):
                 Document(page_content=source_node.source_text, metadata=metadata)
             )
         return docs
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
@@ -2,10 +2,7 @@ from typing import Any, List, Optional

 from pydantic import root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document


@@ -43,8 +40,3 @@ class MetalRetriever(BaseRetriever):
             metadata = {k: v for k, v in r.items() if k != "text"}
             final_results.append(Document(page_content=r["text"], metadata=metadata))
         return final_results
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.milvus import Milvus
@@ -63,15 +60,6 @@ class MilvusRetriever(BaseRetriever):
             query, run_manager=run_manager.get_child(), **kwargs
         )

-    async def _aget_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: AsyncCallbackManagerForRetrieverRun,
-        **kwargs: Any,
-    ) -> List[Document]:
-        raise NotImplementedError
-

 def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
     """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.
@@ -3,10 +3,7 @@ from typing import List

 from pydantic import BaseModel, Field

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.chains.llm import LLMChain
 from langchain.llms.base import BaseLLM
 from langchain.output_parsers.pydantic import PydanticOutputParser
@@ -101,14 +98,6 @@ class MultiQueryRetriever(BaseRetriever):
         unique_documents = self.unique_union(documents)
         return unique_documents

-    async def _aget_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: AsyncCallbackManagerForRetrieverRun,
-    ) -> List[Document]:
-        raise NotImplementedError
-
     def generate_queries(
         self, question: str, run_manager: CallbackManagerForRetrieverRun
     ) -> List[str]:
@@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import Extra, root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document

@@ -175,8 +172,3 @@ class PineconeHybridSearchRetriever(BaseRetriever):
         )
         # return search results as json
         return final_result
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -1,9 +1,6 @@
 from typing import List

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document
 from langchain.utilities.pupmed import PubMedAPIWrapper

@@ -19,8 +16,3 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         return self.load_docs(query=query)
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional, Type, cast
 from pydantic import BaseModel, Field, root_validator

 from langchain import LLMChain
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.chains.query_constructor.base import load_query_constructor_chain
 from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
 from langchain.chains.query_constructor.schema import AttributeInfo
@@ -119,11 +116,6 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
         docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
         return docs

-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
-
     @classmethod
     def from_llm(
         cls,
@@ -5,10 +5,7 @@ from typing import Any, Iterable, List, Optional

 import numpy as np

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document

@@ -113,8 +110,3 @@ class SVMRetriever(BaseRetriever):
         ):
             top_k_results.append(Document(page_content=self.texts[row - 1]))
         return top_k_results
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -2,10 +2,7 @@ from __future__ import annotations

 from typing import Any, Dict, Iterable, List, Optional

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document


@@ -79,8 +76,3 @@ class TFIDFRetriever(BaseRetriever):
         )  # Op -- (n_docs,1) -- Cosine Sim with each doc
         return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
         return return_docs
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple

 from pydantic import Field

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore

@@ -109,12 +106,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
             result.append(buffered_doc)
         return result

-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        """Return documents that are relevant to the query."""
-        raise NotImplementedError
-
     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
         """Add documents to vectorstore."""
         current_time = kwargs.get("current_time")
@@ -3,10 +3,7 @@ from __future__ import annotations
 import json
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document

 if TYPE_CHECKING:
@@ -57,11 +54,6 @@ class VespaRetriever(BaseRetriever):
         body["query"] = query
         return self._query(body)

-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
-
     def get_relevant_documents_with_filter(
         self, query: str, *, _filter: Optional[str] = None
     ) -> List[Document]:
@@ -5,10 +5,7 @@ from uuid import uuid4

 from pydantic import root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.docstore.document import Document
 from langchain.schema import BaseRetriever

@@ -118,8 +115,3 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
             text = res.pop(self.text_key)
             docs.append(Document(page_content=text, metadata=res))
         return docs
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -1,9 +1,6 @@
 from typing import List

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema import BaseRetriever, Document
 from langchain.utilities.wikipedia import WikipediaAPIWrapper

@@ -19,8 +16,3 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         return self.load(query=query)
-
-    async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        raise NotImplementedError
@@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import root_validator

-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.zilliz import Zilliz
@@ -67,15 +64,6 @@ class ZillizRetriever(BaseRetriever):
             query, run_manager=run_manager.get_child(), **kwargs
         )

-    async def _aget_relevant_documents(
-        self,
-        query: str,
-        *,
-        run_manager: AsyncCallbackManagerForRetrieverRun,
-        **kwargs: Any,
-    ) -> List[Document]:
-        raise NotImplementedError
-

 def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
     """Deprecated ZillizRetreiver.
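Each retriever above now keeps only its synchronous `_get_relevant_documents` and inherits async behaviour from the base class. A sketch of a custom retriever following the same shape, assuming the langchain version from this commit; the class, field, and query below are illustrative only:

# Hypothetical retriever in the same style as the files above: only the sync
# method is implemented, async handling is left to BaseRetriever defaults.
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class TinyKeywordRetriever(BaseRetriever):
    """Hypothetical retriever that returns documents containing the query string."""

    docs: List[Document] = []

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]


# Usage sketch:
# retriever = TinyKeywordRetriever(docs=[Document(page_content="LangChain chains")])
# retriever.get_relevant_documents("chains")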
@@ -1,6 +1,5 @@
 from langchain.schema.agent import AgentAction, AgentFinish
 from langchain.schema.document import BaseDocumentTransformer, Document
-from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
 from langchain.schema.messages import (
     AIMessage,
@@ -67,6 +66,5 @@ __all__ = [
     "BaseOutputParser",
     "BaseLLMOutputParser",
     "BasePromptTemplate",
-    "BaseLanguageModel",
     "format_document",
 ]
@@ -1,12 +1,22 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    TypeVar,
+    Union,
+)

 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage, get_buffer_string
 from langchain.schema.output import LLMResult
 from langchain.schema.prompt import PromptValue
+from langchain.schema.runnable import Runnable
 from langchain.utils import get_pydantic_field_names

 if TYPE_CHECKING:
@@ -32,7 +42,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)


-class BaseLanguageModel(Serializable, ABC):
+LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
+LanguageModelOutput = TypeVar("LanguageModelOutput")
+
+
+class BaseLanguageModel(
+    Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
+):
     """Abstract base class for interfacing with language models.

     All language model wrappers inherit from BaseLanguageModel.
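The `LanguageModelInput` union above means a model, as a Runnable, is meant to accept a prompt value, a plain string, or a list of messages. A standalone sketch (not the library implementation) of normalizing those three shapes; the `PromptValue`, `BaseMessage`, and `normalize` names here are simplified stand-ins:

# Standalone sketch of the three input shapes allowed by LanguageModelInput.
from typing import List, Union


class PromptValue:
    def __init__(self, text: str) -> None:
        self.text = text

    def to_string(self) -> str:
        return self.text


class BaseMessage:
    def __init__(self, content: str) -> None:
        self.content = content


LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]


def normalize(input: LanguageModelInput) -> str:
    # Collapse any accepted input shape into one prompt string.
    if isinstance(input, PromptValue):
        return input.to_string()
    if isinstance(input, str):
        return input
    return "\n".join(m.content for m in input)


print(normalize("hi"))
print(normalize(PromptValue("hi from a prompt value")))
print(normalize([BaseMessage("system: be brief"), BaseMessage("user: hi")]))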
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import List, Sequence
+from typing import Any, Dict, List, Sequence

 from pydantic import Field

@@ -78,6 +78,49 @@ class BaseMessage(Serializable):
         return True


+class BaseMessageChunk(BaseMessage):
+    def _merge_kwargs_dict(
+        self, left: Dict[str, Any], right: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Merge additional_kwargs from another BaseMessageChunk into this one."""
+        merged = left.copy()
+        for k, v in right.items():
+            if k not in merged:
+                merged[k] = v
+            elif type(merged[k]) != type(v):
+                raise ValueError(
+                    f'additional_kwargs["{k}"] already exists in this message,'
+                    " but with a different type."
+                )
+            elif isinstance(merged[k], str):
+                merged[k] += v
+            elif isinstance(merged[k], dict):
+                merged[k] = self._merge_kwargs_dict(merged[k], v)
+            else:
+                raise ValueError(
+                    f"Additional kwargs key {k} already exists in this message."
+                )
+        return merged
+
+    def __add__(self, other: Any) -> BaseMessageChunk:
+        if isinstance(other, BaseMessageChunk):
+            # If both are (subclasses of) BaseMessageChunk,
+            # concat into a single BaseMessageChunk
+
+            return self.__class__(
+                content=self.content + other.content,
+                additional_kwargs=self._merge_kwargs_dict(
+                    self.additional_kwargs, other.additional_kwargs
+                ),
+            )
+        else:
+            raise TypeError(
+                'unsupported operand type(s) for +: "'
+                f"{self.__class__.__name__}"
+                f'" and "{other.__class__.__name__}"'
+            )
+
+
 class HumanMessage(BaseMessage):
     """A Message from a human."""

@@ -92,6 +135,10 @@ class HumanMessage(BaseMessage):
         return "human"


+class HumanMessageChunk(HumanMessage, BaseMessageChunk):
+    pass
+
+
 class AIMessage(BaseMessage):
     """A Message from an AI."""

@@ -106,6 +153,10 @@ class AIMessage(BaseMessage):
         return "ai"


+class AIMessageChunk(AIMessage, BaseMessageChunk):
+    pass
+
+
 class SystemMessage(BaseMessage):
     """A Message for priming AI behavior, usually passed in as the first of a sequence
     of input messages.
@@ -117,6 +168,10 @@ class SystemMessage(BaseMessage):
         return "system"


+class SystemMessageChunk(SystemMessage, BaseMessageChunk):
+    pass
+
+
 class FunctionMessage(BaseMessage):
     """A Message for passing the result of executing a function back to a model."""

@@ -129,6 +184,10 @@ class FunctionMessage(BaseMessage):
         return "function"


+class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
+    pass
+
+
 class ChatMessage(BaseMessage):
     """A Message that can be assigned an arbitrary speaker (i.e. role)."""

@@ -141,6 +200,10 @@ class ChatMessage(BaseMessage):
         return "chat"


+class ChatMessageChunk(ChatMessage, BaseMessageChunk):
+    pass
+
+
 def _message_to_dict(message: BaseMessage) -> dict:
     return {"type": message.type, "data": message.dict()}

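Assuming the langchain version from this commit, message chunks produced while streaming concatenate with `+`: content strings are joined and `additional_kwargs` dictionaries are merged by `_merge_kwargs_dict`. The values below are illustrative:

from langchain.schema.messages import AIMessageChunk

first = AIMessageChunk(content="Hello", additional_kwargs={"tool": {"name": "search"}})
second = AIMessageChunk(content=", world", additional_kwargs={"tool": {"args": "cats"}})

combined = first + second
print(combined.content)            # "Hello, world"
print(combined.additional_kwargs)  # {"tool": {"name": "search", "args": "cats"}}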
@@ -7,7 +7,7 @@ from uuid import UUID
 from pydantic import BaseModel, root_validator

 from langchain.load.serializable import Serializable
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import BaseMessage, BaseMessageChunk


 class Generation(Serializable):
@@ -28,6 +28,24 @@ class Generation(Serializable):
         return True


+class GenerationChunk(Generation):
+    def __add__(self, other: GenerationChunk) -> GenerationChunk:
+        if isinstance(other, GenerationChunk):
+            generation_info = (
+                {**(self.generation_info or {}), **(other.generation_info or {})}
+                if self.generation_info is not None or other.generation_info is not None
+                else None
+            )
+            return GenerationChunk(
+                text=self.text + other.text,
+                generation_info=generation_info,
+            )
+        else:
+            raise TypeError(
+                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+            )
+
+
 class ChatGeneration(Generation):
     """A single chat generation output."""

@@ -43,6 +61,26 @@ class ChatGeneration(Generation):
         return values


+class ChatGenerationChunk(ChatGeneration):
+    message: BaseMessageChunk
+
+    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
+        if isinstance(other, ChatGenerationChunk):
+            generation_info = (
+                {**(self.generation_info or {}), **(other.generation_info or {})}
+                if self.generation_info is not None or other.generation_info is not None
+                else None
+            )
+            return ChatGenerationChunk(
+                message=self.message + other.message,
+                generation_info=generation_info,
+            )
+        else:
+            raise TypeError(
+                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+            )
+
+
 class RunInfo(BaseModel):
     """Class that contains metadata for a single execution of a Chain or model."""

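Under the same assumption, `GenerationChunk` follows the identical pattern for plain-text completions: `text` is concatenated and the optional `generation_info` dicts are merged. The values below are illustrative:

from langchain.schema.output import GenerationChunk

a = GenerationChunk(text="2 + 2 = ", generation_info={"model": "demo"})
b = GenerationChunk(text="4", generation_info={"finish_reason": "stop"})

merged = a + b
print(merged.text)             # "2 + 2 = 4"
print(merged.generation_info)  # {"model": "demo", "finish_reason": "stop"}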
@@ -1,16 +1,18 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, TypeVar
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

 from langchain.load.serializable import Serializable
-from langchain.schema.output import Generation
+from langchain.schema.messages import BaseMessage
+from langchain.schema.output import ChatGeneration, Generation
 from langchain.schema.prompt import PromptValue
+from langchain.schema.runnable import Runnable, RunnableConfig

 T = TypeVar("T")


-class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
+class BaseLLMOutputParser(Serializable, Generic[T], ABC):
     """Abstract base class for parsing the outputs of a model."""

     @abstractmethod
@@ -26,7 +28,19 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
         """


-class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
+class BaseGenerationOutputParser(
+    BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
+):
+    def invoke(
+        self, input: str | BaseMessage, config: RunnableConfig | None = None
+    ) -> T:
+        if isinstance(input, BaseMessage):
+            return self.parse_result([ChatGeneration(message=input)])
+        else:
+            return self.parse_result([Generation(text=input)])
+
+
+class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
     """Base class to parse the output of an LLM call.

     Output parsers help structure language model responses.
@@ -53,6 +67,14 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
             return "boolean_output_parser"
     """  # noqa: E501

+    def invoke(
+        self, input: str | BaseMessage, config: RunnableConfig | None = None
+    ) -> T:
+        if isinstance(input, BaseMessage):
+            return self.parse_result([ChatGeneration(message=input)])
+        else:
+            return self.parse_result([Generation(text=input)])
+
     def parse_result(self, result: List[Generation]) -> T:
         """Parse a list of candidate model Generations into a specific format.

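The `invoke` overrides above let an output parser accept either a raw string or a chat message. A usage sketch, assuming the langchain version from this commit:

from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.schema.messages import AIMessage

parser = CommaSeparatedListOutputParser()

print(parser.invoke("a, b, c"))                     # ['a', 'b', 'c']
print(parser.invoke(AIMessage(content="x, y, z")))  # ['x', 'y', 'z']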
@@ -12,9 +12,10 @@ from langchain.load.serializable import Serializable
 from langchain.schema.document import Document
 from langchain.schema.output_parser import BaseOutputParser
 from langchain.schema.prompt import PromptValue
+from langchain.schema.runnable import Runnable, RunnableConfig


-class BasePromptTemplate(Serializable, ABC):
+class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
     """Base class for all prompt templates, returning a prompt."""

     input_variables: List[str]
@@ -34,6 +35,11 @@ class BasePromptTemplate(Serializable, ABC):

         arbitrary_types_allowed = True

+    def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
+        return self._call_with_config(
+            lambda inner_input: self.format_prompt(**inner_input), input, config
+        )
+
     @abstractmethod
     def format_prompt(self, **kwargs: Any) -> PromptValue:
         """Create Chat Messages."""
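With the `invoke` method above, a prompt template is itself a Runnable: it takes a dict of variables and returns a `PromptValue`. A usage sketch, assuming the langchain version from this commit:

from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
value = prompt.invoke({"topic": "parrots"})

print(value.to_string())  # "Tell me a joke about parrots"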
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
 from langchain.schema.document import Document
+from langchain.schema.runnable import Runnable, RunnableConfig

 if TYPE_CHECKING:
     from langchain.callbacks.manager import (
@@ -17,7 +18,7 @@ if TYPE_CHECKING:
     )


-class BaseRetriever(Serializable, ABC):
+class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     """Abstract base class for a Document retrieval system.

     A retrieval system is defined as something that can take string queries and return
@@ -43,9 +44,6 @@ class BaseRetriever(Serializable, ABC):
                 # Op -- (n_docs,1) -- Cosine Sim with each doc
                 results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                 return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
-
-            async def aget_relevant_documents(self, query: str) -> List[Document]:
-                raise NotImplementedError
     """  # noqa: E501

     class Config:
@@ -106,6 +104,20 @@ class BaseRetriever(Serializable, ABC):
             len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
         )

+    def invoke(
+        self, input: str, config: Optional[RunnableConfig] = None
+    ) -> List[Document]:
+        return self.get_relevant_documents(input, **(config or {}))
+
+    async def ainvoke(
+        self, input: str, config: Optional[RunnableConfig] = None
+    ) -> List[Document]:
+        if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
+            # If the retriever doesn't implement async, use default implementation
+            return await super().ainvoke(input, config)
+
+        return await self.aget_relevant_documents(input, **(config or {}))
+
     @abstractmethod
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@@ -118,7 +130,6 @@ class BaseRetriever(Serializable, ABC):
             List of relevant documents
         """

-    @abstractmethod
     async def _aget_relevant_documents(
         self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
     ) -> List[Document]:
@@ -129,6 +140,7 @@ class BaseRetriever(Serializable, ABC):
         Returns:
             List of relevant documents
         """
+        raise NotImplementedError()

     def get_relevant_documents(
         self,
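The `ainvoke` above decides between the async path and a sync fallback by checking whether the subclass overrides the base method. A standalone sketch of that override-detection trick (the class and method names here are illustrative, not the library's):

# Compare the method found on the subclass with the one defined on the base
# class to decide whether a native async implementation exists.
class Base:
    async def aget(self, query: str) -> str:
        raise NotImplementedError

    def has_native_async(self) -> bool:
        return type(self).aget is not Base.aget


class SyncOnly(Base):
    pass


class WithAsync(Base):
    async def aget(self, query: str) -> str:
        return f"async result for {query!r}"


print(SyncOnly().has_native_async())   # False -> fall back to the sync path
print(WithAsync().has_native_async())  # True  -> use the subclass coroutine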
libs/langchain/langchain/schema/runnable.py (new file, 705 lines)
@@ -0,0 +1,705 @@
+from __future__ import annotations
+
+import asyncio
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    TypedDict,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from pydantic import Field
+
+from langchain.callbacks.base import BaseCallbackManager, Callbacks
+from langchain.load.dump import dumpd
+from langchain.load.serializable import Serializable
+
+
+async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
+    async with semaphore:
+        return await coro
+
+
+async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
+    if n is None:
+        return await asyncio.gather(*coros)
+
+    semaphore = asyncio.Semaphore(n)
+
+    return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
+
+
+class RunnableConfig(TypedDict, total=False):
+    tags: List[str]
+    """
+    Tags for this call and any sub-calls (eg. a Chain calling an LLM).
+    You can use these to filter calls.
+    """
+
+    metadata: Dict[str, Any]
+    """
+    Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
+    Keys should be strings, values should be JSON-serializable.
+    """
+
+    callbacks: Callbacks
+    """
+    Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
+    Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
+    """
+
+
+Input = TypeVar("Input")
+# Output type should implement __concat__, as eg str, list, dict do
+Output = TypeVar("Output")
+Other = TypeVar("Other")
+
+
+class Runnable(Generic[Input, Output], ABC):
+    def __or__(
+        self,
+        other: Union[
+            Runnable[Any, Other],
+            Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
+        ],
+    ) -> RunnableSequence[Input, Other]:
+        return RunnableSequence(first=self, last=_coerce_to_runnable(other))
+
+    def __ror__(
+        self,
+        other: Union[
+            Runnable[Other, Any],
+            Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
+        ],
+    ) -> RunnableSequence[Other, Output]:
+        return RunnableSequence(first=_coerce_to_runnable(other), last=self)
+
+    @abstractmethod
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+        ...
+
+    async def ainvoke(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> Output:
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.invoke, input, config
+        )
+
+    def batch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        configs = self._get_config_list(config, len(inputs))
+
+        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
+            return list(executor.map(self.invoke, inputs, configs))
+
+    async def abatch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        configs = self._get_config_list(config, len(inputs))
+        coros = map(self.ainvoke, inputs, configs)
+
+        return await _gather_with_concurrency(max_concurrency, *coros)
+
+    def stream(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        yield self.invoke(input, config)
+
+    async def astream(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        yield await self.ainvoke(input, config)
+
+    def _get_config_list(
+        self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
+    ) -> List[RunnableConfig]:
+        if isinstance(config, list) and len(config) != length:
+            raise ValueError(
+                f"config must be a list of the same length as inputs, "
+                f"but got {len(config)} configs for {length} inputs"
+            )
+
+        return (
+            config
+            if isinstance(config, list)
+            else [config.copy() if config is not None else {} for _ in range(length)]
+        )
+
+    def _call_with_config(
+        self,
+        func: Callable[[Input], Output],
+        input: Input,
+        config: Optional[RunnableConfig],
+    ) -> Output:
+        from langchain.callbacks.manager import CallbackManager
+
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )
+        try:
+            output = func(input)
+        except Exception as e:
+            run_manager.on_chain_error(e)
+            raise
+        else:
+            run_manager.on_chain_end(
+                output if isinstance(output, dict) else {"output": output}
+            )
+            return output
+
+
+class RunnableSequence(Serializable, Runnable[Input, Output]):
+    first: Runnable[Input, Any]
+    middle: List[Runnable[Any, Any]] = Field(default_factory=list)
+    last: Runnable[Any, Output]
+
+    @property
+    def steps(self) -> List[Runnable[Any, Any]]:
+        return [self.first] + self.middle + [self.last]
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __or__(
+        self,
+        other: Union[
+            Runnable[Any, Other],
+            Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
+        ],
+    ) -> RunnableSequence[Input, Other]:
+        if isinstance(other, RunnableSequence):
+            return RunnableSequence(
+                first=self.first,
+                middle=self.middle + [self.last] + other.middle,
+                last=other.last,
+            )
+        else:
+            return RunnableSequence(
+                first=self.first,
+                middle=self.middle + [self.last],
+                last=_coerce_to_runnable(other),
+            )
+
+    def __ror__(
+        self,
+        other: Union[
+            Runnable[Other, Any],
+            Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
+        ],
+    ) -> RunnableSequence[Other, Output]:
+        if isinstance(other, RunnableSequence):
+            return RunnableSequence(
+                first=other.first,
+                middle=other.middle + [other.last] + self.middle,
+                last=self.last,
+            )
+        else:
+            return RunnableSequence(
+                first=_coerce_to_runnable(other),
+                middle=[self.first] + self.middle,
+                last=self.last,
+            )
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+        from langchain.callbacks.manager import CallbackManager
+
+        # setup callbacks
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            local_callbacks=None,
+            verbose=False,
+            inheritable_tags=config.get("tags"),
+            local_tags=None,
+            inheritable_metadata=config.get("metadata"),
+            local_metadata=None,
+        )
+        # start the root run
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )
+
+        # invoke all steps in sequence
+        try:
+            for step in self.steps:
+                input = step.invoke(
+                    input,
+                    # mark each step as a child run
+                    _patch_config(config, run_manager.get_child()),
+                )
+        # finish the root run
+        except (KeyboardInterrupt, Exception) as e:
+            run_manager.on_chain_error(e)
+            raise
+        else:
+            run_manager.on_chain_end(
+                input if isinstance(input, dict) else {"output": input}
+            )
+            return cast(Output, input)
+
+    async def ainvoke(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> Output:
+        from langchain.callbacks.manager import AsyncCallbackManager
+
+        # setup callbacks
+        config = config or {}
+        callback_manager = AsyncCallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            local_callbacks=None,
+            verbose=False,
+            inheritable_tags=config.get("tags"),
+            local_tags=None,
+            inheritable_metadata=config.get("metadata"),
+            local_metadata=None,
+        )
+        # start the root run
+        run_manager = await callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )
+
+        # invoke all steps in sequence
+        try:
+            for step in self.steps:
+                input = await step.ainvoke(
+                    input,
+                    # mark each step as a child run
+                    _patch_config(config, run_manager.get_child()),
+                )
+        # finish the root run
+        except (KeyboardInterrupt, Exception) as e:
+            await run_manager.on_chain_error(e)
+            raise
+        else:
+            await run_manager.on_chain_end(
+                input if isinstance(input, dict) else {"output": input}
+            )
+            return cast(Output, input)
+
+    def batch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        from langchain.callbacks.manager import CallbackManager
+
+        # setup callbacks
+        configs = self._get_config_list(config, len(inputs))
+        callback_managers = [
|
||||||
|
CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers = [
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
|
)
|
||||||
|
for cm, input in zip(callback_managers, inputs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# invoke
|
||||||
|
try:
|
||||||
|
for step in self.steps:
|
||||||
|
inputs = step.batch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
_patch_config(config, rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
max_concurrency=max_concurrency,
|
||||||
|
)
|
||||||
|
# finish the root runs
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
for rm in run_managers:
|
||||||
|
rm.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
for rm, input in zip(run_managers, inputs):
|
||||||
|
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
|
||||||
|
return cast(List[Output], inputs)
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManager,
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
configs = self._get_config_list(config, len(inputs))
|
||||||
|
callback_managers = [
|
||||||
|
AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
|
)
|
||||||
|
for cm, input in zip(callback_managers, inputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# invoke .batch() on each step
|
||||||
|
# this uses batching optimizations in Runnable subclasses, like LLM
|
||||||
|
try:
|
||||||
|
for step in self.steps:
|
||||||
|
inputs = await step.abatch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
_patch_config(config, rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
max_concurrency=max_concurrency,
|
||||||
|
)
|
||||||
|
# finish the root runs
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
rm.on_chain_end(
|
||||||
|
input if isinstance(input, dict) else {"output": input}
|
||||||
|
)
|
||||||
|
for rm, input in zip(run_managers, inputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return cast(List[Output], inputs)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
config = config or {}
|
||||||
|
callback_manager = CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
# start the root run
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
|
)
|
||||||
|
|
||||||
|
# invoke the first steps
|
||||||
|
try:
|
||||||
|
for step in [self.first] + self.middle:
|
||||||
|
input = step.invoke(
|
||||||
|
input,
|
||||||
|
# mark each step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
)
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# stream the last step
|
||||||
|
final: Union[Output, None] = None
|
||||||
|
final_supported = True
|
||||||
|
try:
|
||||||
|
for output in self.last.stream(
|
||||||
|
input,
|
||||||
|
# mark the last step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
):
|
||||||
|
yield output
|
||||||
|
# Accumulate output if possible, otherwise disable accumulation
|
||||||
|
if final_supported:
|
||||||
|
if final is None:
|
||||||
|
final = output
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final += output # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
final = None
|
||||||
|
final_supported = False
|
||||||
|
pass
|
||||||
|
# finish the root run
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
run_manager.on_chain_end(
|
||||||
|
final if isinstance(final, dict) else {"output": final}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
config = config or {}
|
||||||
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
# start the root run
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
|
)
|
||||||
|
|
||||||
|
# invoke the first steps
|
||||||
|
try:
|
||||||
|
for step in [self.first] + self.middle:
|
||||||
|
input = await step.ainvoke(
|
||||||
|
input,
|
||||||
|
# mark each step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
)
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# stream the last step
|
||||||
|
final: Union[Output, None] = None
|
||||||
|
final_supported = True
|
||||||
|
try:
|
||||||
|
async for output in self.last.astream(
|
||||||
|
input,
|
||||||
|
# mark the last step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
):
|
||||||
|
yield output
|
||||||
|
# Accumulate output if possible, otherwise disable accumulation
|
||||||
|
if final_supported:
|
||||||
|
if final is None:
|
||||||
|
final = output
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final += output # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
final = None
|
||||||
|
final_supported = False
|
||||||
|
pass
|
||||||
|
# finish the root run
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await run_manager.on_chain_end(
|
||||||
|
final if isinstance(final, dict) else {"output": final}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||||
|
steps: Dict[str, Runnable[Input, Any]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
config = config or {}
|
||||||
|
callback_manager = CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
# start the root run
|
||||||
|
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
|
||||||
|
|
||||||
|
# gather results from all steps
|
||||||
|
try:
|
||||||
|
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||||
|
steps = self.steps.copy()
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
step.invoke,
|
||||||
|
input,
|
||||||
|
# mark each step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
)
|
||||||
|
for step in steps.values()
|
||||||
|
]
|
||||||
|
output = {key: future.result() for key, future in zip(steps, futures)}
|
||||||
|
# finish the root run
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
config = config or {}
|
||||||
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
# start the root run
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), {"input": input}
|
||||||
|
)
|
||||||
|
|
||||||
|
# gather results from all steps
|
||||||
|
try:
|
||||||
|
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||||
|
steps = self.steps.copy()
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
step.ainvoke(
|
||||||
|
input,
|
||||||
|
# mark each step as a child run
|
||||||
|
_patch_config(config, run_manager.get_child()),
|
||||||
|
)
|
||||||
|
for step in steps.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output = {key: value for key, value in zip(steps, results)}
|
||||||
|
# finish the root run
|
||||||
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableLambda(Runnable[Input, Output]):
|
||||||
|
def __init__(self, func: Callable[[Input], Output]) -> None:
|
||||||
|
if callable(func):
|
||||||
|
self.func = func
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Expected a callable type for `func`."
|
||||||
|
f"Instead got an unsupported type: {type(func)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if isinstance(other, RunnableLambda):
|
||||||
|
return self.func == other.func
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
|
return self._call_with_config(self.func, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
|
return self._call_with_config(lambda x: x, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_config(
|
||||||
|
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||||
|
) -> RunnableConfig:
|
||||||
|
config = config.copy()
|
||||||
|
config["callbacks"] = callback_manager
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_to_runnable(
|
||||||
|
thing: Union[
|
||||||
|
Runnable[Input, Output],
|
||||||
|
Callable[[Input], Output],
|
||||||
|
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
||||||
|
]
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
if isinstance(thing, Runnable):
|
||||||
|
return thing
|
||||||
|
elif callable(thing):
|
||||||
|
return RunnableLambda(thing)
|
||||||
|
elif isinstance(thing, dict):
|
||||||
|
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
|
||||||
|
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected a Runnable, callable or dict."
|
||||||
|
f"Instead got an unsupported type: {type(thing)}"
|
||||||
|
)
|
@ -177,3 +177,68 @@ def test_chat_openai_extra_kwargs() -> None:
|
|||||||
# Test that "model" cannot be specified in kwargs
|
# Test that "model" cannot be specified in kwargs
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
|
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_streaming() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
for token in llm.stream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_astream() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
async for token in llm.astream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_abatch() -> None:
|
||||||
|
"""Test streaming tokens from ChatOpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_abatch_tags() -> None:
|
||||||
|
"""Test batch tokens from ChatOpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.abatch(
|
||||||
|
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||||
|
)
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_batch() -> None:
|
||||||
|
"""Test batch tokens from ChatOpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_ainvoke() -> None:
|
||||||
|
"""Test invoke tokens from ChatOpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_invoke() -> None:
|
||||||
|
"""Test invoke tokens from ChatOpenAI."""
|
||||||
|
llm = ChatOpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
@ -93,7 +93,64 @@ def test_openai_streaming() -> None:
|
|||||||
assert isinstance(generator, Generator)
|
assert isinstance(generator, Generator)
|
||||||
|
|
||||||
for token in generator:
|
for token in generator:
|
||||||
assert isinstance(token["choices"][0]["text"], str)
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_astream() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
async for token in llm.astream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_abatch() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_abatch_tags() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.abatch(
|
||||||
|
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||||
|
)
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_batch() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_ainvoke() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_invoke() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = OpenAI(max_tokens=10)
|
||||||
|
|
||||||
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
def test_openai_multiple_prompts() -> None:
|
def test_openai_multiple_prompts() -> None:
|
||||||
@ -105,13 +162,6 @@ def test_openai_multiple_prompts() -> None:
|
|||||||
assert len(output.generations) == 2
|
assert len(output.generations) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_openai_streaming_error() -> None:
|
|
||||||
"""Test error handling in stream."""
|
|
||||||
llm = OpenAI(best_of=2)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
llm.stream("I'm Pickle Rick")
|
|
||||||
|
|
||||||
|
|
||||||
def test_openai_streaming_best_of_error() -> None:
|
def test_openai_streaming_best_of_error() -> None:
|
||||||
"""Test validation for streaming fails if best_of is not 1."""
|
"""Test validation for streaming fails if best_of is not 1."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
@ -67,10 +67,3 @@ def test_promptlayer_openai_streaming() -> None:
|
|||||||
|
|
||||||
for token in generator:
|
for token in generator:
|
||||||
assert isinstance(token["choices"][0]["text"], str)
|
assert isinstance(token["choices"][0]["text"], str)
|
||||||
|
|
||||||
|
|
||||||
def test_promptlayer_openai_streaming_error() -> None:
|
|
||||||
"""Test error handling in stream."""
|
|
||||||
llm = PromptLayerOpenAI(best_of=2)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
llm.stream("I'm Pickle Rick")
|
|
||||||
|
File diff suppressed because one or more lines are too long
38
libs/langchain/tests/unit_tests/schema/test_messages.py
Normal file
38
libs/langchain/tests/unit_tests/schema/test_messages.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_chunks() -> None:
|
||||||
|
assert AIMessageChunk(content="I am") + AIMessageChunk(
|
||||||
|
content=" indeed."
|
||||||
|
) == AIMessageChunk(
|
||||||
|
content="I am indeed."
|
||||||
|
), "MessageChunk + MessageChunk should be a MessageChunk"
|
||||||
|
|
||||||
|
assert AIMessageChunk(content="I am") + HumanMessageChunk(
|
||||||
|
content=" indeed."
|
||||||
|
) == AIMessageChunk(
|
||||||
|
content="I am indeed."
|
||||||
|
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
|
||||||
|
|
||||||
|
assert AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"foo": "bar"}
|
||||||
|
) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"foo": "bar", "baz": "foo"}
|
||||||
|
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||||
|
|
||||||
|
assert AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"function_call": {"name": "web_search"}}
|
||||||
|
) + AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
|
||||||
|
) + AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}},
|
||||||
|
) == AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {
|
||||||
|
"name": "web_search",
|
||||||
|
"arguments": '{\n "query": "turtles"\n}',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
52
libs/langchain/tests/unit_tests/schema/test_output.py
Normal file
52
libs/langchain/tests/unit_tests/schema/test_output.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from langchain.schema.messages import HumanMessageChunk
|
||||||
|
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_generation_chunk() -> None:
|
||||||
|
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||||
|
text="world!"
|
||||||
|
) == GenerationChunk(
|
||||||
|
text="Hello, world!"
|
||||||
|
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
|
||||||
|
|
||||||
|
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||||
|
text="world!", generation_info={"foo": "bar"}
|
||||||
|
) == GenerationChunk(
|
||||||
|
text="Hello, world!", generation_info={"foo": "bar"}
|
||||||
|
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||||
|
|
||||||
|
assert GenerationChunk(text="Hello, ") + GenerationChunk(
|
||||||
|
text="world!", generation_info={"foo": "bar"}
|
||||||
|
) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk(
|
||||||
|
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
|
||||||
|
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_generation_chunk() -> None:
|
||||||
|
assert ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, ")
|
||||||
|
) + ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="world!")
|
||||||
|
) == ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, world!")
|
||||||
|
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
|
||||||
|
|
||||||
|
assert ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, ")
|
||||||
|
) + ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||||
|
) == ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, world!"),
|
||||||
|
generation_info={"foo": "bar"},
|
||||||
|
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
||||||
|
|
||||||
|
assert ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, ")
|
||||||
|
) + ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
|
||||||
|
) + ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
|
||||||
|
) == ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="Hello, world!!"),
|
||||||
|
generation_info={"foo": "bar", "baz": "foo"},
|
||||||
|
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
|
547
libs/langchain/tests/unit_tests/schema/test_runnable.py
Normal file
547
libs/langchain/tests/unit_tests/schema/test_runnable.py
Normal file
@ -0,0 +1,547 @@
|
|||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
|
from langchain.llms.fake import FakeListLLM
|
||||||
|
from langchain.load.dump import dumps
|
||||||
|
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
ChatPromptValue,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
)
|
||||||
|
from langchain.schema.document import Document
|
||||||
|
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain.schema.retriever import BaseRetriever
|
||||||
|
from langchain.schema.runnable import (
|
||||||
|
Runnable,
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableLambda,
|
||||||
|
RunnableMap,
|
||||||
|
RunnablePassthrough,
|
||||||
|
RunnableSequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTracer(BaseTracer):
|
||||||
|
"""Fake tracer that records LangChain execution."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the tracer."""
|
||||||
|
super().__init__()
|
||||||
|
self.runs: List[Run] = []
|
||||||
|
|
||||||
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
self.runs.append(run)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRunnable(Runnable[str, int]):
|
||||||
|
def invoke(
|
||||||
|
self,
|
||||||
|
input: str,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
) -> int:
|
||||||
|
return len(input)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRetriever(BaseRetriever):
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
|
||||||
|
"""Note this mock only works with `import uuid; uuid.uuid4()`,
|
||||||
|
it does not work with `from uuid import uuid4; uuid4()`."""
|
||||||
|
|
||||||
|
# Disable tracing to avoid fixed UUIDs causing tracing errors.
|
||||||
|
mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})
|
||||||
|
|
||||||
|
side_effect = (
|
||||||
|
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
|
||||||
|
)
|
||||||
|
return mocker.patch("uuid.uuid4", side_effect=side_effect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||||
|
fake = FakeRunnable()
|
||||||
|
spy = mocker.spy(fake, "invoke")
|
||||||
|
|
||||||
|
assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"])),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(metadata={"key": "value"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.batch(
|
||||||
|
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||||
|
) == [5, 7]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"])),
|
||||||
|
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"])),
|
||||||
|
mocker.call("wooorld", dict(tags=["a-tag"])),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(callbacks=[])),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert [part async for part in fake.astream("hello")] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", None),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
|
||||||
|
5,
|
||||||
|
7,
|
||||||
|
]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(metadata={"key": "value"})),
|
||||||
|
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt() -> None:
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessagePromptTemplate.from_template("{question}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected = ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt.invoke({"question": "What is your name?"}) == expected
|
||||||
|
|
||||||
|
assert prompt.batch(
|
||||||
|
[
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
]
|
||||||
|
) == [
|
||||||
|
expected,
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your favorite color?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert [*prompt.stream({"question": "What is your name?"})] == [expected]
|
||||||
|
|
||||||
|
assert await prompt.ainvoke({"question": "What is your name?"}) == expected
|
||||||
|
|
||||||
|
assert await prompt.abatch(
|
||||||
|
[
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
]
|
||||||
|
) == [
|
||||||
|
expected,
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your favorite color?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert [
|
||||||
|
part async for part in prompt.astream({"question": "What is your name?"})
|
||||||
|
] == [expected]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_prompt_with_chat_model(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat = FakeListChatModel(responses=["foo", "bar"])
|
||||||
|
|
||||||
|
chain = prompt | chat
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == []
|
||||||
|
assert chain.last == chat
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert chain.invoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
) == AIMessage(content="foo")
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(chat_spy)
|
||||||
|
|
||||||
|
# Test batch
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "batch")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "batch")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert chain.batch(
|
||||||
|
[
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
],
|
||||||
|
dict(callbacks=[tracer]),
|
||||||
|
) == [
|
||||||
|
AIMessage(content="bar"),
|
||||||
|
AIMessage(content="foo"),
|
||||||
|
]
|
||||||
|
assert prompt_spy.call_args.args[1] == [
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
]
|
||||||
|
assert chat_spy.call_args.args[1] == [
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your favorite color?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(chat_spy)
|
||||||
|
|
||||||
|
# Test stream
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "stream")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [
|
||||||
|
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||||
|
] == [AIMessage(content="bar")]
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_prompt_with_llm(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
llm = FakeListLLM(responses=["foo", "bar"])
|
||||||
|
|
||||||
|
chain = prompt | llm
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == []
|
||||||
|
assert chain.last == llm
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
|
llm_spy = mocker.spy(llm.__class__, "ainvoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
await chain.ainvoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
== "foo"
|
||||||
|
)
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(llm_spy)
|
||||||
|
|
||||||
|
# Test batch
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "abatch")
|
||||||
|
llm_spy = mocker.spy(llm.__class__, "abatch")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await chain.abatch(
|
||||||
|
[
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
],
|
||||||
|
dict(callbacks=[tracer]),
|
||||||
|
) == ["bar", "foo"]
|
||||||
|
assert prompt_spy.call_args.args[1] == [
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
]
|
||||||
|
assert llm_spy.call_args.args[1] == [
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your favorite color?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(llm_spy)
|
||||||
|
|
||||||
|
# Test stream
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
|
llm_spy = mocker.spy(llm.__class__, "astream")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [
|
||||||
|
token
|
||||||
|
async for token in chain.astream(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
] == ["bar"]
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_prompt_with_chat_model_and_parser(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat = FakeListChatModel(responses=["foo, bar"])
|
||||||
|
parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
|
chain = prompt | chat | parser
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == [chat]
|
||||||
|
assert chain.last == parser
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||||
|
parser_spy = mocker.spy(parser.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert chain.invoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
) == ["foo", "bar"]
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_seq_dict_prompt_llm(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||||
|
) -> None:
|
||||||
|
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||||
|
|
||||||
|
retriever = FakeRetriever()
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ """Context:
|
||||||
|
{documents}
|
||||||
|
|
||||||
|
Question:
|
||||||
|
{question}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
chat = FakeListChatModel(responses=["foo, bar"])
|
||||||
|
|
||||||
|
parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
|
chain = (
|
||||||
|
{
|
||||||
|
"question": RunnablePassthrough[str]() | passthrough,
|
||||||
|
"documents": passthrough | retriever,
|
||||||
|
"just_to_test_lambda": passthrough,
|
||||||
|
}
|
||||||
|
| prompt
|
||||||
|
| chat
|
||||||
|
| parser
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert isinstance(chain.first, RunnableMap)
|
||||||
|
assert chain.middle == [prompt, chat]
|
||||||
|
assert chain.last == parser
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||||
|
parser_spy = mocker.spy(parser.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
|
||||||
|
"foo",
|
||||||
|
"bar",
|
||||||
|
]
|
||||||
|
assert prompt_spy.call_args.args[1] == {
|
||||||
|
"documents": [Document(page_content="foo"), Document(page_content="bar")],
|
||||||
|
"question": "What is your name?",
|
||||||
|
"just_to_test_lambda": "What is your name?",
|
||||||
|
}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(
|
||||||
|
content="""Context:
|
||||||
|
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
|
||||||
|
|
||||||
|
Question:
|
||||||
|
What is your name?"""
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_seq_prompt_dict(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||||
|
) -> None:
|
||||||
|
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat = FakeListChatModel(responses=["i'm a chatbot"])
|
||||||
|
|
||||||
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
chain = (
|
||||||
|
prompt
|
||||||
|
| passthrough
|
||||||
|
| { # type: ignore
|
||||||
|
"chat": chat,
|
||||||
|
"llm": llm,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == [RunnableLambda(passthrough)]
|
||||||
|
assert isinstance(chain.last, RunnableMap)
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||||
|
llm_spy = mocker.spy(llm.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert chain.invoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
) == {
|
||||||
|
"chat": AIMessage(content="i'm a chatbot"),
|
||||||
|
"llm": "i'm a textbot",
|
||||||
|
}
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert tracer.runs == snapshot
|
Loading…
Reference in New Issue
Block a user