Compare commits

...

3 Commits

Author       SHA1        Message                   Date
vowelparrot  774c405707  Another crazy             2023-06-09 12:48:14 -07:00
vowelparrot  215ab8a62e  Update internal imports   2023-06-09 12:24:19 -07:00
vowelparrot  198955575d  Move schema to directory  2023-06-09 12:12:26 -07:00
234 changed files with 1506 additions and 424 deletions
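
Note on rendering: the +/- gutter was lost when this diff was captured, so within each changed hunk the removed line appears first, immediately followed by its replacement. The bulk of the 234 files is the same mechanical change: langchain.schema moves from a single module to a package, and imports are updated to the new submodules. A representative before/after of the import pattern (a sketch of the pattern, not any one file):

```python
# Before: everything was imported from the single langchain.schema module.
from langchain.schema import AgentAction, BaseRetriever, Document

# After: schema is a package, split by concern.
from langchain.schema.base import AgentAction
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
```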

View File

@@ -29,7 +29,7 @@ from langchain.input import get_color_mapping
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
from langchain.schema.base import (
AgentAction,
AgentFinish,
BaseMessage,

View File

@@ -20,7 +20,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.schema.base import AgentAction
from langchain.tools.base import BaseTool

View File

@@ -3,7 +3,7 @@ from typing import Union
from langchain.agents.agent import AgentOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
FINAL_ANSWER_ACTION = "Final Answer:"

View File

@@ -3,7 +3,7 @@ from typing import Union
from langchain.agents.agent import AgentOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
class ConvoOutputParser(AgentOutputParser):

View File

@@ -23,7 +23,7 @@ from langchain.prompts.chat import (
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.schema import (
from langchain.schema.base import (
AgentAction,
AIMessage,
BaseMessage,

View File

@@ -5,7 +5,7 @@ from typing import Union
from langchain.agents import AgentOutputParser
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
from langchain.output_parsers.json import parse_json_markdown
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
class ConvoOutputParser(AgentOutputParser):

View File

@@ -3,7 +3,7 @@ from typing import Union
from langchain.agents.agent import AgentOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
FINAL_ANSWER_ACTION = "Final Answer:"

View File

@@ -2,7 +2,7 @@ import re
from typing import Union
from langchain.agents.agent import AgentOutputParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
class ReActOutputParser(AgentOutputParser):

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Tuple
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import AgentAction
from langchain.schema.base import AgentAction
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):

View File

@@ -1,7 +1,7 @@
from typing import Sequence, Union
from langchain.agents.agent import AgentOutputParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
class SelfAskOutputParser(AgentOutputParser):

View File

@@ -17,7 +17,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.schema.base import AgentAction
from langchain.tools import BaseTool
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"

View File

@@ -11,7 +11,7 @@ from langchain.agents.agent import AgentOutputParser
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.output_parsers import OutputFixingParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.base import AgentAction, AgentFinish, OutputParserException
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,7 @@ from typing import List, Optional, Sequence, Set
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
from langchain.schema.base import BaseMessage, LLMResult, PromptValue, get_buffer_string
def _get_token_ids_default_method(text: str) -> List[int]:

View File

@@ -31,7 +31,7 @@ except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation
from langchain.schema.base import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
if TYPE_CHECKING:

View File

@@ -2,7 +2,7 @@ from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
def import_aim() -> Any:
@@ -29,6 +29,7 @@ class BaseMetadataCallbackHandler:
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
@@ -51,6 +52,7 @@ class BaseMetadataCallbackHandler:
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
@@ -85,6 +87,11 @@ class BaseMetadataCallbackHandler:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,

View File

@@ -3,7 +3,7 @@ import warnings
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
class ArgillaCallbackHandler(BaseCallbackHandler):

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union
import pandas as pd
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
class ArizeCallbackHandler(BaseCallbackHandler):

View File

@@ -1,15 +1,45 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
)
from langchain.schema.base import AgentAction, AgentFinish, BaseMessage, LLMResult
from langchain.schema.document import Document
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
class LLMManagerMixin:
@@ -183,6 +213,7 @@ class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
@@ -361,6 +392,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that can be used to handle callbacks from LangChain."""

View File

@@ -13,7 +13,7 @@ from langchain.callbacks.utils import (
import_textstat,
load_json,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
def import_clearml() -> Any:

View File

@@ -12,7 +12,7 @@ from langchain.callbacks.utils import (
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, Generation, LLMResult
LANGCHAIN_MODEL_NAME = "langchain-model"

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, TextIO, cast
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
from langchain.schema.base import AgentAction, AgentFinish
class FileCallbackHandler(BaseCallbackHandler):

View File

@@ -14,6 +14,7 @@ from typing import (
Generator,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
@@ -27,6 +28,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
@@ -36,13 +38,14 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema import (
from langchain.schema.base import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
get_buffer_string,
)
from langchain.schema.document import Document
logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@@ -624,6 +627,91 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
)
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
"""Callback manager for retriever run."""
def get_child(self) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
def on_retriever_end(
self,
documents: Sequence[Document],
**kwargs: Any,
) -> None:
"""Run when retriever ends running."""
_handle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
documents,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
_handle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForRetrieverRun(
AsyncRunManager,
RetrieverManagerMixin,
):
"""Async callback manager for retriever run."""
def get_child(self) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
"""Run when retriever ends running."""
await _ahandle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
documents,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
await _ahandle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
@@ -733,6 +821,29 @@ class CallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_retriever_start(
self,
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",
query,
run_id=run_id,
parent_run_id=self.parent_run_id,
)
return CallbackManagerForRetrieverRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
@@ -856,6 +967,29 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_retriever_start(
self,
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",
query,
run_id=run_id,
parent_run_id=self.parent_run_id,
)
return AsyncCallbackManagerForRetrieverRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
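
A minimal usage sketch of the new manager entry points, walking the retriever lifecycle by hand (assumes the hypothetical PrintRetrieverHandler from the earlier sketch; the synchronous path only):

```python
from langchain.callbacks.manager import CallbackManager

# Wire the toy handler in and drive the new retriever callbacks manually.
manager = CallbackManager(handlers=[PrintRetrieverHandler()])
run_manager = manager.on_retriever_start("what moved in langchain.schema?")
try:
    documents = []  # ...call the actual retriever here...
    run_manager.on_retriever_end(documents)
except Exception as error:
    run_manager.on_retriever_error(error)
    raise
```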

View File

@@ -15,7 +15,7 @@ from langchain.callbacks.utils import (
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
from langchain.utils import get_from_dict_or_env

View File

@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
MODEL_COST_PER_1K_TOKENS = {
"gpt-4": 0.03,

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -4,7 +4,7 @@ import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult
from langchain.schema.base import LLMResult
# TODO If used by two LLM runs in parallel this won't work as expected

View File

@@ -3,7 +3,7 @@ import sys
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
class StreamingStdOutCallbackHandler(BaseCallbackHandler):

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Union
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema import LLMResult
from langchain.schema.base import LLMResult
from langchain.schema.document import Document
class TracerException(Exception):
@@ -262,6 +263,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace(tool_run)
self._on_tool_error(tool_run)
def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when Retriever starts running."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
retrieval_run = Run(
id=run_id,
name="Retriever",
parent_run_id=parent_run_id,
inputs={"query": query},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.retriever,
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
**kwargs: Any,
) -> None:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
raise TracerException("No retriever Run found to be traced")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
raise TracerException("No retriever Run found to be traced")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
return self
@@ -299,3 +359,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _on_chat_model_start(self, run: Run) -> None:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""

View File

@@ -13,7 +13,7 @@ from langchainplus_sdk import LangChainPlusClient
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.env import get_runtime_environment
from langchain.schema import BaseMessage, messages_to_dict
from langchain.schema.base import BaseMessage, messages_to_dict
logger = logging.getLogger(__name__)
@@ -122,3 +122,15 @@ class LangChainTracer(BaseTracer):
def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
self.executor.submit(self._update_run_single, run.copy(deep=True))
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
self.executor.submit(self._persist_run_single, run.copy(deep=True))
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
self.executor.submit(self._update_run_single, run.copy(deep=True))
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""
self.executor.submit(self._update_run_single, run.copy(deep=True))

View File

@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSessionV1,
TracerSessionV1Base,
)
from langchain.schema import get_buffer_string
from langchain.schema.base import get_buffer_string
from langchain.utils import raise_for_status_with_text

View File

@@ -9,7 +9,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, root_validator
from langchain.env import get_runtime_environment
from langchain.schema import LLMResult
from langchain.schema.base import LLMResult
class TracerSessionV1Base(BaseModel):
@@ -94,6 +94,7 @@ class RunTypeEnum(str, Enum):
tool = "tool"
chain = "chain"
llm = "llm"
retriever = "retriever"
class RunBase(BaseModel):

View File

@@ -13,7 +13,7 @@ from langchain.callbacks.utils import (
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, LLMResult
def import_wandb() -> Any:

View File

@@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.schema.base import AgentAction, AgentFinish, Generation, LLMResult
from langchain.utils import get_from_env
if TYPE_CHECKING:

View File

@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema.base import BaseOutputParser
class APIRequesterOutputParser(BaseOutputParser):

View File

@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema.base import BaseOutputParser
class APIResponderOutputParser(BaseOutputParser):

View File

@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.base import RUN_KEY, BaseMemory, RunInfo
def _get_verbosity() -> bool:

View File

@@ -7,7 +7,7 @@ from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.memory.buffer import ConversationBufferMemory
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseMemory
from langchain.schema.base import BaseMemory
class ConversationChain(LLMChain):

View File

@@ -21,7 +21,9 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.schema.base import BaseMessage
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.
@@ -87,7 +89,13 @@ class BaseConversationalRetrievalChain(Chain):
return _output_keys
@abstractmethod
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs."""
def _call(
@@ -107,7 +115,7 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = self._get_docs(new_question, inputs)
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -122,7 +130,13 @@ class BaseConversationalRetrievalChain(Chain):
return output
@abstractmethod
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs."""
async def _acall(
@@ -141,7 +155,7 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = await self._aget_docs(new_question, inputs)
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -187,12 +201,28 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
return docs[:num_docs]
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = self.retriever.get_relevant_documents(question)
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = await self.retriever.aget_relevant_documents(
question, callbacks=run_manager_.get_child()
)
return self._reduce_tokens_below_limit(docs)
@classmethod
@@ -253,14 +283,26 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
)
return values
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
return self.vectorstore.similarity_search(
question, k=self.top_k_docs_for_context, **full_kwargs
)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
raise NotImplementedError("ChatVectorDBChain does not support async")
@classmethod
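
Because _get_docs and _aget_docs are abstract and now take a keyword-only run_manager, downstream subclasses have to adopt the widened signature. A minimal conforming sketch (FixedDocsChain is hypothetical):

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.conversational_retrieval.base import (
    BaseConversationalRetrievalChain,
)
from langchain.schema.document import Document


class FixedDocsChain(BaseConversationalRetrievalChain):
    """Hypothetical chain that answers over a canned document list."""

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        return [Document(page_content=f"stub context for: {question}")]

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        return [Document(page_content=f"stub context for: {question}")]
```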

View File

@@ -8,9 +8,7 @@ import numpy as np
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
@@ -20,7 +18,8 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseRetriever, Generation
from langchain.schema.base import Generation
from langchain.schema.retriever import BaseRetriever
class _ResponseChain(LLMChain):
@@ -124,7 +123,7 @@ class FlareChain(Chain):
callbacks = _run_manager.get_child()
docs = []
for question in questions:
docs.extend(self.retriever.get_relevant_documents(question))
docs.extend(self.retriever.retrieve(question, callbacks=callbacks))
context = "\n\n".join(d.page_content for d in docs)
result = self.response_chain.predict(
user_input=user_input,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema.base import BaseOutputParser
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):

View File

@@ -17,7 +17,7 @@ from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from langchain.schema.base import LLMResult, PromptValue
class LLMChain(Chain):

View File

@@ -13,7 +13,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import OutputParserException
from langchain.schema.base import OutputParserException
from langchain.utilities.bash import BashProcess
logger = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ import re
from typing import List
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.base import BaseOutputParser, OutputParserException
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:

View File

@@ -115,7 +115,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return values
@abstractmethod
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs to run questioning over."""
def _call(
@@ -124,7 +129,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = self._get_docs(inputs)
docs = self._get_docs(inputs, run_manager=_run_manager)
answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@@ -141,7 +146,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return result
@abstractmethod
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs to run questioning over."""
async def _acall(
@@ -150,7 +160,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = await self._aget_docs(inputs)
docs = await self._aget_docs(inputs, run_manager=_run_manager)
answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@@ -180,10 +190,20 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""
return [self.input_docs_key, self.question_key]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
return inputs.pop(self.input_docs_key)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
return inputs.pop(self.input_docs_key)
@property

View File

@@ -1,13 +1,17 @@
"""Question-answering with sources over an index."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
from langchain.schema.retriever import BaseRetriever
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
@@ -40,12 +44,26 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
docs = self.retriever.get_relevant_documents(question)
docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None
) -> List[Document]:
run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
docs = await self.retriever.aget_relevant_documents(question)
docs = await self.retriever.aretrieve(
question, callbacks=run_manager_.get_child()
)
return self._reduce_tokens_below_limit(docs)

View File

@@ -1,10 +1,14 @@
"""Question-answering with sources over a vector database."""
import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@@ -45,14 +49,24 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
question = inputs[self.question_key]
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None
) -> List[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@root_validator()

View File

@@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import (
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.base import BaseOutputParser, OutputParserException
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):

View File

@@ -19,7 +19,8 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
@@ -94,7 +95,9 @@ class BaseRetrievalQA(Chain):
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@abstractmethod
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self, question: str, *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
"""Get documents to do question answering over."""
def _call(
@@ -116,7 +119,7 @@ class BaseRetrievalQA(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = self._get_docs(question)
docs = self._get_docs(question, run_manager=_run_manager)
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@@ -127,7 +130,9 @@ class BaseRetrievalQA(Chain):
return {self.output_key: answer}
@abstractmethod
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
"""Get documents to do question answering over."""
async def _acall(
@@ -149,7 +154,7 @@ class BaseRetrievalQA(Chain):
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = await self._aget_docs(question)
docs = await self._aget_docs(question, run_manager=_run_manager)
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@@ -177,11 +182,22 @@ class RetrievalQA(BaseRetrievalQA):
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(self, question: str) -> List[Document]:
return self.retriever.get_relevant_documents(question)
def _get_docs(
self, question: str, *, run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.retriever.retrieve(question, run_manager=_run_manager.get_child())
async def _aget_docs(self, question: str) -> List[Document]:
return await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
self,
question: str,
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
return await self.retriever.aretrieve(
question, callbacks=_run_manager.get_child()
)
@property
def _chain_type(self) -> str:
@@ -218,7 +234,9 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self, question: str, *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@@ -231,7 +249,9 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
raise NotImplementedError("VectorDBQA does not support async")
@property

View File

@@ -14,7 +14,7 @@ from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.base import BaseOutputParser, OutputParserException
class LLMRouterChain(RouterChain):

View File

@@ -15,7 +15,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
)
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever
from langchain.schema.retriever import BaseRetriever
class MultiRetrievalQAChain(MultiRouteChain):

View File

@@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
ChatGeneration,

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict, Mapping
from pydantic import root_validator
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import ChatResult
from langchain.schema.base import ChatResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)

View File

@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
ChatGeneration,

View File

@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
ChatGeneration,

View File

@@ -29,7 +29,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
ChatGeneration,

View File

@@ -7,7 +7,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult
from langchain.schema.base import BaseMessage, ChatResult
class PromptLayerChatOpenAI(ChatOpenAI):

View File

@@ -10,7 +10,7 @@ from langchain.callbacks.manager import (
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
ChatGeneration,

View File

@@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.schema import (
from langchain.schema.base import (
BaseMessage,
ChatResult,
HumanMessage,

View File

@@ -1,7 +1,7 @@
from typing import Callable, Union
from langchain.docstore.base import Docstore
from langchain.schema import Document
from langchain.schema.document import Document
class DocstoreFn(Docstore):

View File

@@ -1,3 +1,3 @@
from langchain.schema import Document
from langchain.schema.document import Document
__all__ = ["Document"]

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Iterator, List, Optional
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

View File

@@ -6,7 +6,7 @@ from typing import Iterator, List, Literal, Optional, Sequence, Union
from langchain.document_loaders.base import BaseBlobParser, BaseLoader
from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader
from langchain.document_loaders.parsers.registry import get_parser
from langchain.schema import Document
from langchain.schema.document import Document
from langchain.text_splitter import TextSplitter
_PathLike = Union[str, Path]

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Iterator, List, Optional
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.schema.document import Document
from langchain.utils import get_from_env
LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}"

View File

@@ -2,7 +2,7 @@ from typing import Iterator
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
from langchain.schema.document import Document
class OpenAIWhisperParser(BaseBlobParser):

View File

@@ -6,7 +6,7 @@ from typing import Iterator, Mapping, Optional
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders.schema import Blob
from langchain.schema import Document
from langchain.schema.document import Document
class MimeTypeBasedParser(BaseBlobParser):

View File

@@ -3,7 +3,7 @@ from typing import Any, Iterator, Mapping, Optional
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
from langchain.schema.document import Document
class PyPDFParser(BaseBlobParser):

View File

@@ -3,7 +3,7 @@ from typing import Iterator
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
from langchain.schema.document import Document
class TextParser(BaseBlobParser):

View File

@@ -4,7 +4,7 @@ import re
from typing import Any, Callable, Generator, Iterable, List, Optional
from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema import Document
from langchain.schema.document import Document
def _default_parsing_function(content: Any) -> str:

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.schema import BaseDocumentTransformer, Document
from langchain.schema.document import BaseDocumentTransformer, Document
class _DocumentWithState(Document):

View File

@@ -6,7 +6,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.agents.trajectory_eval_prompt import EVAL_CHAT_PROMPT
from langchain.schema import AgentAction, BaseOutputParser, OutputParserException
from langchain.schema.base import AgentAction, BaseOutputParser, OutputParserException
from langchain.tools.base import BaseTool

View File

@@ -1,8 +1,8 @@
"""Prompt for trajectory evaluation chain."""
# flake8: noqa
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain.schema.base import AIMessage
from langchain.schema.base import HumanMessage
from langchain.schema.base import SystemMessage
from langchain.prompts.chat import (
ChatPromptTemplate,

View File

@@ -12,7 +12,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import RUN_KEY, BaseOutputParser
from langchain.schema.base import RUN_KEY, BaseOutputParser
class RunEvaluatorInputMapper:

View File

@@ -14,13 +14,8 @@ from langchain.experimental.autonomous_agents.autogpt.prompt import AutoGPTPromp
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
FINISH_NAME,
)
from langchain.schema import (
AIMessage,
BaseMessage,
Document,
HumanMessage,
SystemMessage,
)
from langchain.schema.base import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever

View File

@@ -3,7 +3,7 @@ import re
from abc import abstractmethod
from typing import Dict, NamedTuple
from langchain.schema import BaseOutputParser
from langchain.schema.base import BaseOutputParser
class AutoGPTAction(NamedTuple):

View File

@@ -7,7 +7,7 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import ge
from langchain.prompts.chat import (
BaseChatPromptTemplate,
)
from langchain.schema import BaseMessage, HumanMessage, SystemMessage
from langchain.schema.base import BaseMessage, HumanMessage, SystemMessage
from langchain.tools.base import BaseTool
from langchain.vectorstores.base import VectorStoreRetriever

View File

@@ -7,7 +7,8 @@ from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document
from langchain.schema.base import BaseMemory
from langchain.schema.document import Document
from langchain.utils import mock_now
logger = logging.getLogger(__name__)

View File

@@ -9,7 +9,7 @@ from langchain.experimental.plan_and_execute.schema import (
Step,
)
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import SystemMessage
from langchain.schema.base import SystemMessage
SYSTEM_PROMPT = (
"Let's first understand the problem and devise a plan to solve the problem."

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple
from pydantic import BaseModel, Field
from langchain.schema import BaseOutputParser
from langchain.schema.base import BaseOutputParser
class Step(BaseModel):

View File

@@ -9,7 +9,7 @@ from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma

View File

@@ -19,7 +19,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseMessage,
Generation,

View File

@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.schema.base import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)

View File

@@ -34,7 +34,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.schema.base import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms import OpenAI, OpenAIChat
from langchain.schema import LLMResult
from langchain.schema.base import LLMResult
class PromptLayerOpenAI(OpenAI):

View File

@@ -4,7 +4,7 @@ from pydantic import root_validator
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema import get_buffer_string
from langchain.schema.base import get_buffer_string
class ConversationBufferMemory(BaseChatMemory):

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage, get_buffer_string
from langchain.schema.base import BaseMessage, get_buffer_string
class ConversationBufferWindowMemory(BaseChatMemory):

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema import BaseChatMessageHistory, BaseMemory
from langchain.schema.base import BaseChatMessageHistory, BaseMemory
class BaseChatMemory(BaseMemory, ABC):

View File

@@ -2,10 +2,10 @@ import json
import logging
from typing import List
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
@@ -153,7 +153,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
self.session.execute(
"""INSERT INTO message_store
(id, session_id, history) VALUES (%s, %s, %s);""",
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
(uuid.uuid4(), self.session_id, json.dumps(message_to_dict(message))),
)
except (Unavailable, WriteTimeout, WriteFailure) as error:
logger.error("Unable to write chat history messages to cassandra")
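
The rename of the private _message_to_dict to the public message_to_dict recurs in every message-history backend below. A round-trip sketch of the renamed helper, using the import path shown in this hunk:

```python
import json

from langchain.schema.base import HumanMessage, message_to_dict, messages_from_dict

message = HumanMessage(content="hello")
payload = json.dumps(message_to_dict(message))   # what the stores persist
restored = messages_from_dict([json.loads(payload)])
assert restored[0].content == "hello"
```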

View File

@@ -5,7 +5,7 @@ import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
messages_from_dict,

View File

@@ -1,10 +1,10 @@
import logging
from typing import List, Optional
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
@@ -65,7 +65,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError
messages = messages_to_dict(self.messages)
_message = _message_to_dict(message)
_message = message_to_dict(message)
messages.append(_message)
try:

View File

@@ -3,7 +3,7 @@ import logging
from pathlib import Path
from typing import List
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
messages_from_dict,

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
messages_from_dict,

View File

@@ -2,7 +2,7 @@ from typing import List
from pydantic import BaseModel
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
)

View File

@@ -4,10 +4,10 @@ import json
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
from langchain.utils import get_from_env
@@ -153,7 +153,7 @@ class MomentoChatMessageHistory(BaseChatMessageHistory):
"""
from momento.responses import CacheListPushBack
item = json.dumps(_message_to_dict(message))
item = json.dumps(message_to_dict(message))
push_response = self.cache_client.list_push_back(
self.cache_name, self.key, item, ttl=self.ttl
)

View File

@@ -2,10 +2,10 @@ import json
import logging
from typing import List
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
@@ -75,7 +75,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.insert_one(
{
"SessionId": self.session_id,
"History": json.dumps(_message_to_dict(message)),
"History": json.dumps(message_to_dict(message)),
}
)
except errors.WriteError as err:

View File

@@ -2,10 +2,10 @@ import json
import logging
from typing import List
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
@@ -61,7 +61,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
sql.Identifier(self.table_name)
)
self.cursor.execute(
query, (self.session_id, json.dumps(_message_to_dict(message)))
query, (self.session_id, json.dumps(message_to_dict(message)))
)
self.connection.commit()

View File

@@ -2,10 +2,10 @@ import json
import logging
from typing import List, Optional
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
@@ -52,7 +52,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(self.key, json.dumps(message_to_dict(message)))
if self.ttl:
self.redis_client.expire(self.key, self.ttl)

View File

@@ -10,10 +10,10 @@ except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from langchain.schema import (
from langchain.schema.base import (
BaseChatMessageHistory,
BaseMessage,
_message_to_dict,
message_to_dict,
messages_from_dict,
)
@@ -66,7 +66,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
jsonstr = json.dumps(message_to_dict(message))
session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.commit()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import (
from langchain.schema.base import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Set
from pydantic import validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMemory
from langchain.schema.base import BaseMemory
class CombinedMemory(BaseMemory):

Some files were not shown because too many files have changed in this diff.