mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-12 04:01:05 +00:00
Compare commits
3 Commits
langchain=
...
vwp/retrie
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
774c405707 | ||
|
|
215ab8a62e | ||
|
|
198955575d |
@@ -29,7 +29,7 @@ from langchain.input import get_color_mapping
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
|
||||
@@ -20,7 +20,7 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema.base import AgentAction
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ConvoOutputParser(AgentOutputParser):
|
||||
|
||||
@@ -23,7 +23,7 @@ from langchain.prompts.chat import (
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AgentAction,
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ConvoOutputParser(AgentOutputParser):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class ReActOutputParser(AgentOutputParser):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema.base import AgentAction
|
||||
|
||||
|
||||
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Sequence, Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class SelfAskOutputParser(AgentOutputParser):
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema.base import AgentAction
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
|
||||
@@ -11,7 +11,7 @@ from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.output_parsers import OutputFixingParser
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import List, Optional, Sequence, Set
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
|
||||
from langchain.schema.base import BaseMessage, LLMResult, PromptValue, get_buffer_string
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
|
||||
@@ -31,7 +31,7 @@ except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Generation
|
||||
from langchain.schema.base import Generation
|
||||
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -2,7 +2,7 @@ from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_aim() -> Any:
|
||||
@@ -29,6 +29,7 @@ class BaseMetadataCallbackHandler:
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
@@ -51,6 +52,7 @@ class BaseMetadataCallbackHandler:
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.ignore_retriever_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
@@ -85,6 +87,11 @@ class BaseMetadataCallbackHandler:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
def get_custom_callback_meta(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step": self.step,
|
||||
|
||||
@@ -3,7 +3,7 @@ import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import pandas as pd
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -1,15 +1,45 @@
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema.base import AgentAction, AgentFinish, BaseMessage, LLMResult
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class RetrieverManagerMixin:
|
||||
"""Mixin for Retriever callbacks."""
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever errors."""
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever ends running."""
|
||||
|
||||
|
||||
class LLMManagerMixin:
|
||||
@@ -183,6 +213,7 @@ class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
@@ -361,6 +392,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever start."""
|
||||
|
||||
async def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever end."""
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error."""
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain.callbacks.utils import (
|
||||
import_textstat,
|
||||
load_json,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_clearml() -> Any:
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain.callbacks.utils import (
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, Generation, LLMResult
|
||||
|
||||
LANGCHAIN_MODEL_NAME = "langchain-model"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, TextIO, cast
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from langchain.schema.base import AgentAction, AgentFinish
|
||||
|
||||
|
||||
class FileCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import (
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -27,6 +28,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
ChainManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
)
|
||||
@@ -36,13 +38,14 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
LLMResult,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain.schema.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
@@ -624,6 +627,91 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
|
||||
"""Callback manager for retriever run."""
|
||||
|
||||
def get_child(self) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
"ignore_retriever",
|
||||
documents,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
"ignore_retriever",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForRetrieverRun(
|
||||
AsyncRunManager,
|
||||
RetrieverManagerMixin,
|
||||
):
|
||||
"""Async callback manager for retriever run."""
|
||||
|
||||
def get_child(self) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
async def on_retriever_end(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when retriever ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
"ignore_retriever",
|
||||
documents,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when retriever errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
"ignore_retriever",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
@@ -733,6 +821,29 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
) -> CallbackManagerForRetrieverRun:
|
||||
"""Run when retriever starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_start",
|
||||
"ignore_retriever",
|
||||
query,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
)
|
||||
|
||||
return CallbackManagerForRetrieverRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
@@ -856,6 +967,29 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
) -> AsyncCallbackManagerForRetrieverRun:
|
||||
"""Run when retriever starts running."""
|
||||
if run_id is None:
|
||||
run_id = uuid4()
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_start",
|
||||
"ignore_retriever",
|
||||
query,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
)
|
||||
|
||||
return AsyncCallbackManagerForRetrieverRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
|
||||
@@ -15,7 +15,7 @@ from langchain.callbacks.utils import (
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
"gpt-4": 0.03,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.base import LLMResult
|
||||
|
||||
# TODO If used by two LLM runs in parallel this won't work as expected
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import streamlit as st
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
@@ -3,12 +3,13 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.base import LLMResult
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
@@ -262,6 +263,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._end_trace(tool_run)
|
||||
self._on_tool_error(tool_run)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when Retriever starts running."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
retrieval_run = Run(
|
||||
id=run_id,
|
||||
name="Retriever",
|
||||
parent_run_id=parent_run_id,
|
||||
inputs={"query": query},
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.retriever,
|
||||
)
|
||||
self._start_trace(retrieval_run)
|
||||
self._on_retriever_start(retrieval_run)
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when Retriever errors."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
|
||||
retrieval_run.error = repr(error)
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_error(retrieval_run)
|
||||
|
||||
def on_retriever_end(
|
||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when Retriever ends running."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_end(retrieval_run)
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
@@ -299,3 +359,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> None:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon start."""
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> None:
|
||||
"""Process the Retriever Run."""
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchainplus_sdk import LangChainPlusClient
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
||||
from langchain.env import get_runtime_environment
|
||||
from langchain.schema import BaseMessage, messages_to_dict
|
||||
from langchain.schema.base import BaseMessage, messages_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -122,3 +122,15 @@ class LangChainTracer(BaseTracer):
|
||||
def _on_tool_error(self, run: Run) -> None:
|
||||
"""Process the Tool Run upon error."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon start."""
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> None:
|
||||
"""Process the Retriever Run."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
|
||||
TracerSessionV1,
|
||||
TracerSessionV1Base,
|
||||
)
|
||||
from langchain.schema import get_buffer_string
|
||||
from langchain.schema.base import get_buffer_string
|
||||
from langchain.utils import raise_for_status_with_text
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.env import get_runtime_environment
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.base import LLMResult
|
||||
|
||||
|
||||
class TracerSessionV1Base(BaseModel):
|
||||
@@ -94,6 +94,7 @@ class RunTypeEnum(str, Enum):
|
||||
tool = "tool"
|
||||
chain = "chain"
|
||||
llm = "llm"
|
||||
retriever = "retriever"
|
||||
|
||||
|
||||
class RunBase(BaseModel):
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain.callbacks.utils import (
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
|
||||
from langchain.schema.base import AgentAction, AgentFinish, Generation, LLMResult
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.base import BaseOutputParser
|
||||
|
||||
|
||||
class APIRequesterOutputParser(BaseOutputParser):
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.base import BaseOutputParser
|
||||
|
||||
|
||||
class APIResponderOutputParser(BaseOutputParser):
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.schema.base import RUN_KEY, BaseMemory, RunInfo
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.chains.conversation.prompt import PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseMemory
|
||||
from langchain.schema.base import BaseMemory
|
||||
|
||||
|
||||
class ConversationChain(LLMChain):
|
||||
|
||||
@@ -21,7 +21,9 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseMessage, BaseRetriever, Document
|
||||
from langchain.schema.base import BaseMessage
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
# Depending on the memory type and configuration, the chat history format may differ.
|
||||
@@ -87,7 +89,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
return _output_keys
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
def _call(
|
||||
@@ -107,7 +115,7 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
docs = self._get_docs(new_question, inputs)
|
||||
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -122,7 +130,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
async def _acall(
|
||||
@@ -141,7 +155,7 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
docs = await self._aget_docs(new_question, inputs)
|
||||
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@@ -187,12 +201,28 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
docs = await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager_.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@classmethod
|
||||
@@ -253,14 +283,26 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
vectordbkwargs = inputs.get("vectordbkwargs", {})
|
||||
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
|
||||
return self.vectorstore.similarity_search(
|
||||
question, k=self.top_k_docs_for_context, **full_kwargs
|
||||
)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("ChatVectorDBChain does not support async")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,9 +8,7 @@ import numpy as np
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.flare.prompts import (
|
||||
PROMPT,
|
||||
@@ -20,7 +18,8 @@ from langchain.chains.flare.prompts import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseRetriever, Generation
|
||||
from langchain.schema.base import Generation
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
|
||||
class _ResponseChain(LLMChain):
|
||||
@@ -124,7 +123,7 @@ class FlareChain(Chain):
|
||||
callbacks = _run_manager.get_child()
|
||||
docs = []
|
||||
for question in questions:
|
||||
docs.extend(self.retriever.get_relevant_documents(question))
|
||||
docs.extend(self.retriever.retrieve(question, callbacks=callbacks))
|
||||
context = "\n\n".join(d.page_content for d in docs)
|
||||
result = self.response_chain.predict(
|
||||
user_input=user_input,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.base import BaseOutputParser
|
||||
|
||||
|
||||
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from langchain.schema.base import LLMResult, PromptValue
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import OutputParserException
|
||||
from langchain.schema.base import OutputParserException
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,7 +5,7 @@ import re
|
||||
from typing import List
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema.base import BaseOutputParser, OutputParserException
|
||||
|
||||
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
|
||||
|
||||
|
||||
@@ -115,7 +115,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
def _call(
|
||||
@@ -124,7 +129,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
docs = self._get_docs(inputs)
|
||||
docs = self._get_docs(inputs, run_manager=_run_manager)
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
@@ -141,7 +146,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
async def _acall(
|
||||
@@ -150,7 +160,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
docs = await self._aget_docs(inputs)
|
||||
docs = await self._aget_docs(inputs, run_manager=_run_manager)
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
@@ -180,10 +190,20 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
"""
|
||||
return [self.input_docs_key, self.question_key]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Question-answering with sources over an index."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
|
||||
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
@@ -40,12 +44,26 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None
|
||||
) -> List[Document]:
|
||||
run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.question_key]
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None
|
||||
) -> List[Document]:
|
||||
run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.question_key]
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
docs = await self.retriever.aretrieve(
|
||||
question, callbacks=run_manager_.get_child()
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Question-answering with sources over a vector database."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -45,14 +49,24 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None
|
||||
) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = self.vectorstore.similarity_search(
|
||||
question, k=self.k, **self.search_kwargs
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
|
||||
|
||||
@root_validator()
|
||||
|
||||
@@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import (
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema.base import BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
|
||||
@@ -19,7 +19,8 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@@ -94,7 +95,9 @@ class BaseRetrievalQA(Chain):
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
def _get_docs(
|
||||
self, question: str, *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
def _call(
|
||||
@@ -116,7 +119,7 @@ class BaseRetrievalQA(Chain):
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = self._get_docs(question)
|
||||
docs = self._get_docs(question, run_manager=_run_manager)
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
@@ -127,7 +130,9 @@ class BaseRetrievalQA(Chain):
|
||||
return {self.output_key: answer}
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
async def _acall(
|
||||
@@ -149,7 +154,7 @@ class BaseRetrievalQA(Chain):
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = await self._aget_docs(question)
|
||||
docs = await self._aget_docs(question, run_manager=_run_manager)
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
@@ -177,11 +182,22 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
|
||||
retriever: BaseRetriever = Field(exclude=True)
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(question)
|
||||
def _get_docs(
|
||||
self, question: str, *, run_manager: Optional[CallbackManagerForChainRun] = None
|
||||
) -> List[Document]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
return self.retriever.retrieve(question, run_manager=_run_manager.get_child())
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
return await self.retriever.aget_relevant_documents(question)
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> List[Document]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
return await self.retriever.aretrieve(
|
||||
question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@@ -218,7 +234,9 @@ class VectorDBQA(BaseRetrievalQA):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
def _get_docs(
|
||||
self, question: str, *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(
|
||||
question, k=self.k, **self.search_kwargs
|
||||
@@ -231,7 +249,9 @@ class VectorDBQA(BaseRetrievalQA):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
async def _aget_docs(
|
||||
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("VectorDBQA does not support async")
|
||||
|
||||
@property
|
||||
|
||||
@@ -14,7 +14,7 @@ from langchain.chains import LLMChain
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema.base import BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
class LLMRouterChain(RouterChain):
|
||||
|
||||
@@ -15,7 +15,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
|
||||
)
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
|
||||
class MultiRetrievalQAChain(MultiRouteChain):
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.anthropic import _AnthropicCommon
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, Mapping
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import ChatResult
|
||||
from langchain.schema.base import ChatResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
|
||||
@@ -29,7 +29,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import BaseMessage, ChatResult
|
||||
from langchain.schema.base import BaseMessage, ChatResult
|
||||
|
||||
|
||||
class PromptLayerChatOpenAI(ChatOpenAI):
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.vertexai import _VertexAICommon
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseMessage,
|
||||
ChatResult,
|
||||
HumanMessage,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Callable, Union
|
||||
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class DocstoreFn(Docstore):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
__all__ = ["Document"]
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterator, List, Literal, Optional, Sequence, Union
|
||||
from langchain.document_loaders.base import BaseBlobParser, BaseLoader
|
||||
from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader
|
||||
from langchain.document_loaders.parsers.registry import get_parser
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
_PathLike = Union[str, Path]
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}"
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Iterator
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class OpenAIWhisperParser(BaseBlobParser):
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterator, Mapping, Optional
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders.schema import Blob
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class MimeTypeBasedParser(BaseBlobParser):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Iterator, Mapping, Optional
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class PyPDFParser(BaseBlobParser):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Iterator
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
class TextParser(BaseBlobParser):
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
from typing import Any, Callable, Generator, Iterable, List, Optional
|
||||
|
||||
from langchain.document_loaders.web_base import WebBaseLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
|
||||
|
||||
def _default_parsing_function(content: Any) -> str:
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.schema import BaseDocumentTransformer, Document
|
||||
from langchain.schema.document import BaseDocumentTransformer, Document
|
||||
|
||||
|
||||
class _DocumentWithState(Document):
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.evaluation.agents.trajectory_eval_prompt import EVAL_CHAT_PROMPT
|
||||
from langchain.schema import AgentAction, BaseOutputParser, OutputParserException
|
||||
from langchain.schema.base import AgentAction, BaseOutputParser, OutputParserException
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Prompt for trajectory evaluation chain."""
|
||||
# flake8: noqa
|
||||
from langchain.schema import AIMessage
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain.schema import SystemMessage
|
||||
from langchain.schema.base import AIMessage
|
||||
from langchain.schema.base import HumanMessage
|
||||
from langchain.schema.base import SystemMessage
|
||||
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import RUN_KEY, BaseOutputParser
|
||||
from langchain.schema.base import RUN_KEY, BaseOutputParser
|
||||
|
||||
|
||||
class RunEvaluatorInputMapper:
|
||||
|
||||
@@ -14,13 +14,8 @@ from langchain.experimental.autonomous_agents.autogpt.prompt import AutoGPTPromp
|
||||
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
|
||||
FINISH_NAME,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
Document,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.base import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema.document import Document
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.human.tool import HumanInputRun
|
||||
from langchain.vectorstores.base import VectorStoreRetriever
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.base import BaseOutputParser
|
||||
|
||||
|
||||
class AutoGPTAction(NamedTuple):
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import ge
|
||||
from langchain.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
)
|
||||
from langchain.schema import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema.base import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.vectorstores.base import VectorStoreRetriever
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||
from langchain.schema import BaseMemory, Document
|
||||
from langchain.schema.base import BaseMemory
|
||||
from langchain.schema.document import Document
|
||||
from langchain.utils import mock_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.experimental.plan_and_execute.schema import (
|
||||
Step,
|
||||
)
|
||||
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.schema import SystemMessage
|
||||
from langchain.schema.base import SystemMessage
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"Let's first understand the problem and devise a plan to solve the problem."
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.base import BaseOutputParser
|
||||
|
||||
|
||||
class Step(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.document import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.chroma import Chroma
|
||||
|
||||
@@ -19,7 +19,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
Generation,
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.base import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +34,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.base import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms import OpenAI, OpenAIChat
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.base import LLMResult
|
||||
|
||||
|
||||
class PromptLayerOpenAI(OpenAI):
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import root_validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.schema import get_buffer_string
|
||||
from langchain.schema.base import get_buffer_string
|
||||
|
||||
|
||||
class ConversationBufferMemory(BaseChatMemory):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import BaseMessage, get_buffer_string
|
||||
from langchain.schema.base import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.schema import BaseChatMessageHistory, BaseMemory
|
||||
from langchain.schema.base import BaseChatMessageHistory, BaseMemory
|
||||
|
||||
|
||||
class BaseChatMemory(BaseMemory, ABC):
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -153,7 +153,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
|
||||
self.session.execute(
|
||||
"""INSERT INTO message_store
|
||||
(id, session_id, history) VALUES (%s, %s, %s);""",
|
||||
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
|
||||
(uuid.uuid4(), self.session_id, json.dumps(message_to_dict(message))),
|
||||
)
|
||||
except (Unavailable, WriteTimeout, WriteFailure) as error:
|
||||
logger.error("Unable to write chat history messages to cassandra")
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
@@ -65,7 +65,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
messages = messages_to_dict(self.messages)
|
||||
_message = _message_to_dict(message)
|
||||
_message = message_to_dict(message)
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
)
|
||||
|
||||
@@ -4,10 +4,10 @@ import json
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
from langchain.utils import get_from_env
|
||||
@@ -153,7 +153,7 @@ class MomentoChatMessageHistory(BaseChatMessageHistory):
|
||||
"""
|
||||
from momento.responses import CacheListPushBack
|
||||
|
||||
item = json.dumps(_message_to_dict(message))
|
||||
item = json.dumps(message_to_dict(message))
|
||||
push_response = self.cache_client.list_push_back(
|
||||
self.cache_name, self.key, item, ttl=self.ttl
|
||||
)
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -75,7 +75,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"SessionId": self.session_id,
|
||||
"History": json.dumps(_message_to_dict(message)),
|
||||
"History": json.dumps(message_to_dict(message)),
|
||||
}
|
||||
)
|
||||
except errors.WriteError as err:
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -61,7 +61,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
sql.Identifier(self.table_name)
|
||||
)
|
||||
self.cursor.execute(
|
||||
query, (self.session_id, json.dumps(_message_to_dict(message)))
|
||||
query, (self.session_id, json.dumps(message_to_dict(message)))
|
||||
)
|
||||
self.connection.commit()
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Redis"""
|
||||
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
|
||||
if self.ttl:
|
||||
self.redis_client.expire(self.key, self.ttl)
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
_message_to_dict,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
jsonstr = json.dumps(_message_to_dict(message))
|
||||
jsonstr = json.dumps(message_to_dict(message))
|
||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
from langchain.schema.base import (
|
||||
AIMessage,
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Set
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from langchain.schema.base import BaseMemory
|
||||
|
||||
|
||||
class CombinedMemory(BaseMemory):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user