Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-09 10:41:52 +00:00

Compare commits: wfh/log_er ... bagatur/pr (14 commits)

| SHA1 |
|---|
| 5f88060a08 |
| ccb0821dc7 |
| 972a2735d1 |
| 4f2c85a7ae |
| 0d3ae97c11 |
| 883ff006bb |
| f7d7ed3a47 |
| 86d346ce89 |
| a9af504f9c |
| 520be95168 |
| 1a869d0ef2 |
| 7dda1bf45a |
| 26b66a59fa |
| b17b87ae04 |
@@ -9,9 +9,14 @@ build:
os: ubuntu-22.04
tools:
python: "3.11"
jobs:
pre_build:
commands:
- python -mvirtualenv $READTHEDOCS_VIRTUALENV_PATH
- python -m pip install --upgrade --no-cache-dir pip setuptools
- python -m pip install --upgrade --no-cache-dir sphinx readthedocs-sphinx-ext
- python -m pip install --exists-action=w --no-cache-dir -r docs/api_reference/requirements.txt
- python docs/api_reference/create_api_rst.py
- cat docs/api_reference/conf.py
- python -m sphinx -T -E -b html -d _build/doctrees -c docs/api_reference docs/api_reference $READTHEDOCS_OUTPUT/html -j auto

# Build documentation in the docs/ directory with Sphinx
sphinx:
@@ -28,7 +28,7 @@ from langchain.schema.messages import (
)


-async def aenumerate(
+async def _aenumerate(
    iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
    """Async version of enumerate."""

@@ -38,7 +38,7 @@ async def aenumerate(
        i += 1


-def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])

@@ -59,7 +59,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        return ChatMessage(content=_dict["content"], role=role)


-def convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}

@@ -87,7 +87,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
    return message_dict


-def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
+def _convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
    """Convert dictionaries representing OpenAI messages to LangChain format.

    Args:

@@ -96,7 +96,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess
    Returns:
        List of LangChain BaseMessage objects.
    """
-    return [convert_dict_to_message(m) for m in messages]
+    return [_convert_dict_to_message(m) for m in messages]


def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:

@@ -155,10 +155,10 @@ class ChatCompletion:
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
-        converted_messages = convert_openai_messages(messages)
+        converted_messages = _convert_openai_messages(messages)
        if not stream:
            result = model_config.invoke(converted_messages)
-            return {"choices": [{"message": convert_message_to_dict(result)}]}
+            return {"choices": [{"message": _convert_message_to_dict(result)}]}
        else:
            return (
                _convert_message_chunk_to_delta(c, i)

@@ -198,14 +198,14 @@ class ChatCompletion:
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
-        converted_messages = convert_openai_messages(messages)
+        converted_messages = _convert_openai_messages(messages)
        if not stream:
            result = await model_config.ainvoke(converted_messages)
-            return {"choices": [{"message": convert_message_to_dict(result)}]}
+            return {"choices": [{"message": _convert_message_to_dict(result)}]}
        else:
            return (
                _convert_message_chunk_to_delta(c, i)
-                async for i, c in aenumerate(model_config.astream(converted_messages))
+                async for i, c in _aenumerate(model_config.astream(converted_messages))
            )

@@ -214,12 +214,12 @@ def _has_assistant_message(session: ChatSession) -> bool:
    return any([isinstance(m, AIMessage) for m in session["messages"]])


-def convert_messages_for_finetuning(
+def _convert_messages_for_finetuning(
    sessions: Iterable[ChatSession],
) -> List[List[dict]]:
    """Convert messages to a list of lists of dictionaries for fine-tuning."""
    return [
-        [convert_message_to_dict(s) for s in session["messages"]]
+        [_convert_message_to_dict(s) for s in session["messages"]]
        for session in sessions
        if _has_assistant_message(session)
    ]
@@ -5,7 +5,7 @@ from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


-def import_aim() -> Any:
+def _import_aim() -> Any:
    """Import the aim python package and raise an error if it is not installed."""
    try:
        import aim

@@ -169,7 +169,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

        super().__init__()

-        aim = import_aim()
+        aim = _import_aim()
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval

@@ -184,7 +184,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self.action_records: list = []

    def setup(self, **kwargs: Any) -> None:
-        aim = import_aim()
+        aim = _import_aim()

        if not self._run:
            if self._run_hash:

@@ -210,7 +210,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
-        aim = import_aim()
+        aim = _import_aim()

        self.step += 1
        self.llm_starts += 1

@@ -229,7 +229,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

@@ -264,7 +264,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

@@ -280,7 +280,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

@@ -303,7 +303,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

@@ -315,7 +315,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

@@ -339,7 +339,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

@@ -356,7 +356,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
-        aim = import_aim()
+        aim = _import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1
@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.utils import import_pandas
+from langchain.callbacks.utils import _import_pandas
from langchain.schema import AgentAction, AgentFinish, LLMResult

@@ -60,7 +60,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pd = import_pandas()
+        pd = _import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
@@ -8,11 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
    BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
    flatten_dict,
    hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
    load_json,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult

@@ -21,7 +21,7 @@ if TYPE_CHECKING:
    import pandas as pd


-def import_clearml() -> Any:
+def _import_clearml() -> Any:
    """Import the clearml python package and raise an error if it is not installed."""
    try:
        import clearml  # noqa: F401

@@ -63,8 +63,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    ) -> None:
        """Initialize callback handler."""

-        clearml = import_clearml()
-        spacy = import_spacy()
+        clearml = _import_clearml()
+        spacy = _import_spacy()
        super().__init__()

        self.task_type = task_type

@@ -329,8 +329,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
            (dict): A dictionary containing the complexity metrics.
        """
        resp = {}
-        textstat = import_textstat()
-        spacy = import_spacy()
+        textstat = _import_textstat()
+        spacy = _import_spacy()
        if self.complexity_metrics:
            text_complexity_metrics = {
                "flesch_reading_ease": textstat.flesch_reading_ease(text),

@@ -399,7 +399,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(

@@ -465,8 +465,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        Returns:
            None
        """
-        pd = import_pandas()
-        clearml = import_clearml()
+        pd = _import_pandas()
+        clearml = _import_clearml()

        # Log the action records
        self.logger.report_table(
@@ -7,17 +7,17 @@ import langchain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
    BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
    flatten_dict,
-    import_pandas,
-    import_spacy,
-    import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult

LANGCHAIN_MODEL_NAME = "langchain-model"


-def import_comet_ml() -> Any:
+def _import_comet_ml() -> Any:
    """Import comet_ml and raise an error if it is not installed."""
    try:
        import comet_ml  # noqa: F401

@@ -33,7 +33,7 @@ def import_comet_ml() -> Any:
def _get_experiment(
    workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
-    comet_ml = import_comet_ml()
+    comet_ml = _import_comet_ml()

    experiment = comet_ml.Experiment(  # type: ignore
        workspace=workspace,

@@ -44,7 +44,7 @@ def _get_experiment(


def _fetch_text_complexity_metrics(text: str) -> dict:
-    textstat = import_textstat()
+    textstat = _import_textstat()
    text_complexity_metrics = {
        "flesch_reading_ease": textstat.flesch_reading_ease(text),
        "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),

@@ -67,7 +67,7 @@ def _fetch_text_complexity_metrics(text: str) -> dict:


def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
-    pd = import_pandas()
+    pd = _import_pandas()
    metrics_df = pd.DataFrame(metrics)
    metrics_summary = metrics_df.describe()

@@ -107,7 +107,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    ) -> None:
        """Initialize callback handler."""

-        self.comet_ml = import_comet_ml()
+        self.comet_ml = _import_comet_ml()
        super().__init__()

        self.task_type = task_type

@@ -140,7 +140,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self.action_records: list = []
        self.complexity_metrics = complexity_metrics
        if self.visualizations:
-            spacy = import_spacy()
+            spacy = _import_spacy()
            self.nlp = spacy.load("en_core_web_sm")
        else:
            self.nlp = None

@@ -535,7 +535,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        if not (self.visualizations and self.nlp):
            return

-        spacy = import_spacy()
+        spacy = _import_spacy()

        prompts = session_df["prompts"].tolist()
        outputs = session_df["text"].tolist()

@@ -603,7 +603,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self.temp_dir = tempfile.TemporaryDirectory()

    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
-        pd = import_pandas()
+        pd = _import_pandas()

        llm_parameters = self._get_llm_parameters(langchain_asset)
        num_generations_per_prompt = llm_parameters.get("n", 1)
@@ -10,7 +10,7 @@ from langchain.schema import (
)


-def import_context() -> Any:
+def _import_context() -> Any:
    """Import the `getcontext` package."""
    try:
        import getcontext  # noqa: F401

@@ -98,7 +98,7 @@ class ContextCallbackHandler(BaseCallbackHandler):
            self.message_model,
            self.message_role_model,
            self.rating_model,
-        ) = import_context()
+        ) = _import_context()

        token = token or os.environ.get("CONTEXT_TOKEN") or ""
@@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
    BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
    flatten_dict,
-    import_pandas,
-    import_spacy,
-    import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


-def import_flytekit() -> Tuple[flytekit, renderer]:
+def _import_flytekit() -> Tuple[flytekit, renderer]:
    """Import flytekit and flytekitplugins-deck-standard."""
    try:
        import flytekit  # noqa: F401

@@ -75,7 +75,7 @@ def analyze_text(
        resp.update(text_complexity_metrics)

    if nlp is not None:
-        spacy = import_spacy()
+        spacy = _import_spacy()
        doc = nlp(text)
        dep_out = spacy.displacy.render(  # type: ignore
            doc, style="dep", jupyter=False, page=True

@@ -97,12 +97,12 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def __init__(self) -> None:
        """Initialize callback handler."""
-        flytekit, renderer = import_flytekit()
-        self.pandas = import_pandas()
+        flytekit, renderer = _import_flytekit()
+        self.pandas = _import_pandas()

        self.textstat = None
        try:
-            self.textstat = import_textstat()
+            self.textstat = _import_textstat()
        except ImportError:
            logger.warning(
                "Textstat library is not installed. \

@@ -112,7 +112,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

        spacy = None
        try:
-            spacy = import_spacy()
+            spacy = _import_spacy()
        except ImportError:
            logger.warning(
                "Spacy library is not installed. \
@@ -6,7 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.messages import BaseMessage


-def import_infino() -> Any:
+def _import_infino() -> Any:
    """Import the infino client."""
    try:
        from infinopy import InfinoClient

@@ -19,7 +19,7 @@ def import_infino() -> Any:
    return InfinoClient()


-def import_tiktoken() -> Any:
+def _import_tiktoken() -> Any:
    """Import tiktoken for counting tokens for OpenAI models."""
    try:
        import tiktoken

@@ -38,7 +38,7 @@ def get_num_tokens(string: str, openai_model_name: str) -> int:
    Official documentation: https://github.com/openai/openai-cookbook/blob/main
    /examples/How_to_count_tokens_with_tiktoken.ipynb
    """
-    tiktoken = import_tiktoken()
+    tiktoken = _import_tiktoken()

    encoding = tiktoken.encoding_for_model(openai_model_name)
    num_tokens = len(encoding.encode(string))

@@ -55,7 +55,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
        verbose: bool = False,
    ) -> None:
        # Set Infino client
-        self.client = import_infino()
+        self.client = _import_infino()
        self.model_id = model_id
        self.model_version = model_version
        self.verbose = verbose
@@ -23,7 +23,7 @@ class LabelStudioMode(Enum):
    CHAT = "chat"


-def get_default_label_configs(
+def _get_default_label_configs(
    mode: Union[str, LabelStudioMode]
) -> Tuple[str, LabelStudioMode]:
    """Get default Label Studio configs for the given mode.

@@ -173,7 +173,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
            self.project_config = project_config
            self.mode = None
        else:
-            self.project_config, self.mode = get_default_label_configs(mode)
+            self.project_config, self.mode = _get_default_label_configs(mode)

        self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
        if self.project_id is not None:
@@ -10,17 +10,17 @@ from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
    BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
    flatten_dict,
    hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils import get_from_dict_or_env


-def import_mlflow() -> Any:
+def _import_mlflow() -> Any:
    """Import the mlflow python package and raise an error if it is not installed."""
    try:
        import mlflow

@@ -47,8 +47,8 @@ def analyze_text(
        files serialized to HTML string.
    """
    resp: Dict[str, Any] = {}
-    textstat = import_textstat()
-    spacy = import_spacy()
+    textstat = _import_textstat()
+    spacy = _import_spacy()
    text_complexity_metrics = {
        "flesch_reading_ease": textstat.flesch_reading_ease(text),
        "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),

@@ -127,7 +127,7 @@ class MlflowLogger:
    """

    def __init__(self, **kwargs: Any):
-        self.mlflow = import_mlflow()
+        self.mlflow = _import_mlflow()
        if "DATABRICKS_RUNTIME_VERSION" in os.environ:
            self.mlflow.set_tracking_uri("databricks")
            self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()

@@ -246,10 +246,10 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        tracking_uri: Optional[str] = None,
    ) -> None:
        """Initialize callback handler."""
-        import_pandas()
-        import_textstat()
-        import_mlflow()
-        spacy = import_spacy()
+        _import_pandas()
+        _import_textstat()
+        _import_mlflow()
+        spacy = _import_spacy()
        super().__init__()

        self.name = name

@@ -547,7 +547,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
        on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])

@@ -617,7 +617,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        return session_analysis_df

    def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
-        pd = import_pandas()
+        pd = _import_pandas()
        self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
        session_analysis_df = self._create_session_analysis_df()
        chat_html = session_analysis_df.pop("chat_html")
@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, Union


-def import_spacy() -> Any:
+def _import_spacy() -> Any:
    """Import the spacy python package and raise an error if it is not installed."""
    try:
        import spacy

@@ -15,7 +15,7 @@ def import_spacy() -> Any:
    return spacy


-def import_pandas() -> Any:
+def _import_pandas() -> Any:
    """Import the pandas python package and raise an error if it is not installed."""
    try:
        import pandas

@@ -27,7 +27,7 @@ def import_pandas() -> Any:
    return pandas


-def import_textstat() -> Any:
+def _import_textstat() -> Any:
    """Import the textstat python package and raise an error if it is not installed."""
    try:
        import textstat
@@ -7,16 +7,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
    BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
    flatten_dict,
    hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult


-def import_wandb() -> Any:
+def _import_wandb() -> Any:
    """Import the wandb python package and raise an error if it is not installed."""
    try:
        import wandb  # noqa: F401

@@ -28,7 +28,7 @@ def import_wandb() -> Any:
    return wandb


-def load_json_to_dict(json_path: Union[str, Path]) -> dict:
+def _load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load json file to a dictionary.

    Parameters:

@@ -42,7 +42,7 @@ def load_json_to_dict(json_path: Union[str, Path]) -> dict:
    return data


-def analyze_text(
+def _analyze_text(
    text: str,
    complexity_metrics: bool = True,
    visualize: bool = True,

@@ -63,9 +63,9 @@ def analyze_text(
        files serialized in a wandb.Html element.
    """
    resp = {}
-    textstat = import_textstat()
-    wandb = import_wandb()
-    spacy = import_spacy()
+    textstat = _import_textstat()
+    wandb = _import_wandb()
+    spacy = _import_spacy()
    if complexity_metrics:
        text_complexity_metrics = {
            "flesch_reading_ease": textstat.flesch_reading_ease(text),

@@ -120,7 +120,7 @@ def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> A

    Returns:
        (wandb.Html): The html element."""
-    wandb = import_wandb()
+    wandb = _import_wandb()
    formatted_prompt = prompt.replace("\n", "<br>")
    formatted_generation = generation.replace("\n", "<br>")

@@ -173,10 +173,10 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    ) -> None:
        """Initialize callback handler."""

-        wandb = import_wandb()
-        import_pandas()
-        import_textstat()
-        spacy = import_spacy()
+        wandb = _import_wandb()
+        _import_pandas()
+        _import_textstat()
+        spacy = _import_spacy()
        super().__init__()

        self.job_type = job_type

@@ -269,7 +269,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
            generation_resp = deepcopy(resp)
            generation_resp.update(flatten_dict(generation.dict()))
            generation_resp.update(
-                analyze_text(
+                _analyze_text(
                    generation.text,
                    complexity_metrics=self.complexity_metrics,
                    visualize=self.visualize,

@@ -438,7 +438,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

@@ -533,8 +533,8 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        Returns:
            None
        """
-        pd = import_pandas()
-        wandb = import_wandb()
+        pd = _import_pandas()
+        wandb = _import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_table = wandb.Table(
            dataframe=self._create_session_analysis_df()

@@ -554,11 +554,11 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
            try:
                langchain_asset.save(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
-                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+                model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
            except ValueError:
                langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
-                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+                model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                print("Could not save model.")
                print(repr(e))
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
diagnostic_logger = logging.getLogger(__name__)


-def import_langkit(
+def _import_langkit(
    sentiment: bool = False,
    toxicity: bool = False,
    themes: bool = False,

@@ -159,7 +159,7 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
            WhyLabs writer.
        """
        # langkit library will import necessary whylogs libraries
-        import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
+        _import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)

        import whylogs as why
        from langkit.callback_handler import get_callback_instance
@@ -149,7 +149,7 @@ class LangSmithDatasetChatLoader(BaseChatLoader):
        for data_point in data:
            yield ChatSession(
                messages=[
-                    oai_adapter.convert_dict_to_message(m)
+                    oai_adapter._convert_dict_to_message(m)
                    for m in data_point.get("messages", [])
                ],
                functions=data_point.get("functions"),
@@ -40,7 +40,7 @@ def _convert_one_message_to_text(
    return message_text


-def convert_messages_to_prompt_anthropic(
+def _convert_messages_to_prompt_anthropic(
    messages: List[BaseMessage],
    *,
    human_prompt: str = "\n\nHuman:",

@@ -115,7 +115,7 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
            prompt_params["human_prompt"] = self.HUMAN_PROMPT
        if self.AI_PROMPT:
            prompt_params["ai_prompt"] = self.AI_PROMPT
-        return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
+        return _convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)

    def convert_prompt(self, prompt: PromptValue) -> str:
        return self._convert_messages_to_prompt(prompt.to_messages())
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set

import requests

-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
from langchain.chat_models.openai import (
    ChatOpenAI,
    _import_tiktoken,

@@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
@@ -41,7 +41,7 @@ def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
    )


-def convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a message to a dictionary that can be passed to the API."""
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):

@@ -194,7 +194,7 @@ class QianfanChatEndpoint(BaseChatModel):
        """
        messages_dict: Dict[str, Any] = {
            "messages": [
-                convert_message_to_dict(m)
+                _convert_message_to_dict(m)
                for m in messages
                if not isinstance(m, SystemMessage)
            ]
@@ -3,7 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
-from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
+from langchain.chat_models.anthropic import _convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra

@@ -25,7 +25,7 @@ class ChatPromptAdapter:
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
-            prompt = convert_messages_to_prompt_anthropic(messages=messages)
+            prompt = _convert_messages_to_prompt_anthropic(messages=messages)
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
@@ -5,7 +5,7 @@ import logging
import sys
from typing import TYPE_CHECKING, Dict, Optional, Set

-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
from langchain.chat_models.openai import (
    ChatOpenAI,
    _import_tiktoken,

@@ -140,7 +140,7 @@ class ChatEverlyAI(ChatOpenAI):
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
@@ -10,7 +10,7 @@ from typing import (
    Union,
)

-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,

@@ -58,7 +58,7 @@ def _convert_delta_to_message_chunk(
    return default_class(content=content)


-def convert_dict_to_message(_dict: Any) -> BaseMessage:
+def _convert_dict_to_message(_dict: Any) -> BaseMessage:
    """Convert a dict response to a message."""
    role = _dict.role
    content = _dict.content or ""

@@ -125,7 +125,7 @@ class ChatFireworks(BaseChatModel):
            "messages": message_dicts,
            **self.model_kwargs,
        }
-        response = completion_with_retry(
+        response = _completion_with_retry(
            self, run_manager=run_manager, stop=stop, **params
        )
        return self._create_chat_result(response)

@@ -143,7 +143,7 @@ class ChatFireworks(BaseChatModel):
            "messages": message_dicts,
            **self.model_kwargs,
        }
-        response = await acompletion_with_retry(
+        response = await _acompletion_with_retry(
            self, run_manager=run_manager, stop=stop, **params
        )
        return self._create_chat_result(response)

@@ -156,7 +156,7 @@ class ChatFireworks(BaseChatModel):
    def _create_chat_result(self, response: Any) -> ChatResult:
        generations = []
        for res in response.choices:
-            message = convert_dict_to_message(res.message)
+            message = _convert_dict_to_message(res.message)
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),

@@ -168,7 +168,7 @@ class ChatFireworks(BaseChatModel):
    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _stream(

@@ -186,7 +186,7 @@ class ChatFireworks(BaseChatModel):
            "stream": True,
            **self.model_kwargs,
        }
-        for chunk in completion_with_retry(
+        for chunk in _completion_with_retry(
            self, run_manager=run_manager, stop=stop, **params
        ):
            choice = chunk.choices[0]

@@ -215,7 +215,7 @@ class ChatFireworks(BaseChatModel):
            "stream": True,
            **self.model_kwargs,
        }
-        async for chunk in await acompletion_with_retry_streaming(
+        async for chunk in await _acompletion_with_retry_streaming(
            self, run_manager=run_manager, stop=stop, **params
        ):
            choice = chunk.choices[0]

@@ -230,7 +230,7 @@ class ChatFireworks(BaseChatModel):
            await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)


-def completion_with_retry(
+def _completion_with_retry(
    llm: ChatFireworks,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,

@@ -242,15 +242,15 @@ def completion_with_retry(
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
-    def _completion_with_retry(**kwargs: Any) -> Any:
+    def __completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.create(
            **kwargs,
        )

-    return _completion_with_retry(**kwargs)
+    return __completion_with_retry(**kwargs)


-async def acompletion_with_retry(
+async def _acompletion_with_retry(
    llm: ChatFireworks,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,

@@ -270,7 +270,7 @@ async def acompletion_with_retry(
    return await _completion_with_retry(**kwargs)


-async def acompletion_with_retry_streaming(
+async def _acompletion_with_retry_streaming(
    llm: ChatFireworks,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -190,27 +190,27 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
    )


-def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+def _chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
-    def _chat_with_retry(**kwargs: Any) -> Any:
+    def __chat_with_retry(**kwargs: Any) -> Any:
        return llm.client.chat(**kwargs)

-    return _chat_with_retry(**kwargs)
+    return __chat_with_retry(**kwargs)


-async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+async def _achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
-    async def _achat_with_retry(**kwargs: Any) -> Any:
+    async def __achat_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.client.chat_async(**kwargs)

-    return await _achat_with_retry(**kwargs)
+    return await __achat_with_retry(**kwargs)


class ChatGooglePalm(BaseChatModel, BaseModel):

@@ -294,7 +294,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

-        response: genai.types.ChatResponse = chat_with_retry(
+        response: genai.types.ChatResponse = _chat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,

@@ -316,7 +316,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

-        response: genai.types.ChatResponse = await achat_with_retry(
+        response: genai.types.ChatResponse = await _achat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
@@ -79,7 +79,7 @@ def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
    )


-async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
+async def _acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

@@ -274,15 +274,15 @@ class JinaChat(BaseChatModel):
            before_sleep=before_sleep_log(logger, logging.WARNING),
        )

-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def _completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = self._create_retry_decorator()

        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
+        def __completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

-        return _completion_with_retry(**kwargs)
+        return __completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}

@@ -309,7 +309,7 @@ class JinaChat(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+        for chunk in self._completion_with_retry(messages=message_dicts, **params):
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__

@@ -332,7 +332,7 @@ class JinaChat(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
+        response = self._completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(

@@ -366,7 +366,7 @@ class JinaChat(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        async for chunk in await acompletion_with_retry(
+        async for chunk in await _acompletion_with_retry(
            self, messages=message_dicts, **params
        ):
            delta = chunk["choices"][0]["delta"]

@@ -391,7 +391,7 @@ class JinaChat(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = await acompletion_with_retry(self, messages=message_dicts, **params)
+        response = await _acompletion_with_retry(self, messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
@@ -17,7 +17,7 @@ from typing import (

import requests

-from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
+from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)

@@ -165,13 +165,13 @@ class ChatKonko(ChatOpenAI):

        return {model["id"] for model in models_response.json()["data"]}

-    def completion_with_retry(
+    def _completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
-        def _completion_with_retry(**kwargs: Any) -> Any:
+        def __completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

-        return _completion_with_retry(**kwargs)
+        return __completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}

@@ -198,7 +198,7 @@ class ChatKonko(ChatOpenAI):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
+        for chunk in self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:

@@ -233,7 +233,7 @@ class ChatKonko(ChatOpenAI):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = self.completion_with_retry(
+        response = self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

@@ -246,13 +246,13 @@ class ChatKonko(ChatOpenAI):
        if "stop" in params:
            raise ValueError("`stop` found in both the input and default params.")
        params["stop"] = stop
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
-            message = convert_dict_to_message(res["message"])
+            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
@@ -97,7 +97,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        return ChatMessage(content=_dict["content"], role=role)


-async def acompletion_with_retry(
+async def _acompletion_with_retry(
    llm: ChatLiteLLM,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,

@@ -225,17 +225,17 @@ class ChatLiteLLM(BaseChatModel):
        }
        return {**self._default_params, **creds}

-    def completion_with_retry(
+    def _completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
+        def __completion_with_retry(**kwargs: Any) -> Any:
            return self.client.completion(**kwargs)

-        return _completion_with_retry(**kwargs)
+        return __completion_with_retry(**kwargs)

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

@@ -302,7 +302,7 @@ class ChatLiteLLM(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = self.completion_with_retry(
+        response = self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

@@ -345,7 +345,7 @@ class ChatLiteLLM(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
+        for chunk in self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:

@@ -368,7 +368,7 @@ class ChatLiteLLM(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        async for chunk in await acompletion_with_retry(
+        async for chunk in await _acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:

@@ -397,7 +397,7 @@ class ChatLiteLLM(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = await acompletion_with_retry(
+        response = await _acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
@@ -18,7 +18,7 @@ from typing import (
    Union,
)

-from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
+from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,

@@ -81,7 +81,7 @@ def _create_retry_decorator(
    )


-async def acompletion_with_retry(
+async def _acompletion_with_retry(
    llm: ChatOpenAI,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,

@@ -286,17 +286,17 @@ class ChatOpenAI(BaseChatModel):
            **self.model_kwargs,
        }

-    def completion_with_retry(
+    def _completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
+        def __completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

-        return _completion_with_retry(**kwargs)
+        return __completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}

@@ -323,7 +323,7 @@ class ChatOpenAI(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
+        for chunk in self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:

@@ -357,7 +357,7 @@ class ChatOpenAI(BaseChatModel):
            return _generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = self.completion_with_retry(
+        response = self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

@@ -370,13 +370,13 @@ class ChatOpenAI(BaseChatModel):
        if "stop" in params:
            raise ValueError("`stop` found in both the input and default params.")
        params["stop"] = stop
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
-            message = convert_dict_to_message(res["message"])
+            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),

@@ -397,7 +397,7 @@ class ChatOpenAI(BaseChatModel):
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
-        async for chunk in await acompletion_with_retry(
+        async for chunk in await _acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:

@@ -432,7 +432,7 @@ class ChatOpenAI(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = await acompletion_with_retry(
+        response = await _acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

@@ -528,7 +528,7 @@ class ChatOpenAI(BaseChatModel):
            "information on how messages are converted to tokens."
        )
        num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
@@ -67,7 +67,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        return ChatMessage(content=_dict["content"], role=role)


-def convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}

@@ -266,14 +266,14 @@ class ChatTongyi(BaseChatModel):
            **self.model_kwargs,
        }

-    def completion_with_retry(
+    def _completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
-        def _completion_with_retry(**_kwargs: Any) -> Any:
+        def __completion_with_retry(**_kwargs: Any) -> Any:
            resp = self.client.call(**_kwargs)
            if resp.status_code == 200:
                return resp

@@ -289,19 +289,19 @@ class ChatTongyi(BaseChatModel):
                response=resp,
            )

-        return _completion_with_retry(**kwargs)
+        return __completion_with_retry(**kwargs)

-    def stream_completion_with_retry(
+    def _stream_completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
-        def _stream_completion_with_retry(**_kwargs: Any) -> Any:
+        def __stream_completion_with_retry(**_kwargs: Any) -> Any:
            return self.client.call(**_kwargs)

-        return _stream_completion_with_retry(**kwargs)
+        return __stream_completion_with_retry(**kwargs)

    def _generate(
        self,

@@ -320,7 +320,7 @@ class ChatTongyi(BaseChatModel):

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
-        response = self.completion_with_retry(
+        response = self._completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

@@ -337,7 +337,7 @@ class ChatTongyi(BaseChatModel):
        # Mark current chunk total length
        length = 0
        default_chunk_class = AIMessageChunk
-        for chunk in self.stream_completion_with_retry(
+        for chunk in self._stream_completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["output"]["choices"]) == 0:

@@ -368,7 +368,7 @@ class ChatTongyi(BaseChatModel):
            raise ValueError("`stop` found in both the input and default params.")
        params["stop"] = stop

-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _client_params(self) -> Dict[str, Any]:
@@ -43,7 +43,7 @@ class _FileType(str, Enum):
    PDF = "pdf"


-def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
+def _fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
    mime_types_mapping = {}
    for file_type in file_types:
        if file_type.value == "doc":

@@ -73,7 +73,7 @@ class O365BaseLoader(BaseLoader, BaseModel):
    @property
    def _fetch_mime_types(self) -> Dict[str, str]:
        """Return a dict of supported file types to corresponding mime types."""
-        return fetch_mime_types(self._file_types)
+        return _fetch_mime_types(self._file_types)

    @property
    @abstractmethod
@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


-def concatenate_rows(message: dict, title: str) -> str:
+def _concatenate_rows(message: dict, title: str) -> str:
    """
    Combine message information in a readable format ready to be used.
    Args:

@@ -50,7 +50,7 @@ class ChatGPTLoader(BaseLoader):
            messages = d["mapping"]
            text = "".join(
                [
-                    concatenate_rows(messages[key]["message"], title)
+                    _concatenate_rows(messages[key]["message"], title)
                    for idx, key in enumerate(messages)
                    if not (
                        idx == 0
@@ -7,7 +7,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def concatenate_rows(row: dict) -> str:
|
||||
def _concatenate_rows(row: dict) -> str:
|
||||
"""Combine message information in a readable format ready to be used.
|
||||
|
||||
Args:
|
||||
@@ -36,7 +36,7 @@ class FacebookChatLoader(BaseLoader):
|
||||
d = json.load(f)
|
||||
|
||||
text = "".join(
|
||||
concatenate_rows(message)
|
||||
_concatenate_rows(message)
|
||||
for message in d["messages"]
|
||||
if message.get("content") and isinstance(message["content"], str)
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def concatenate_cells(
|
||||
def _concatenate_cells(
|
||||
cell: dict, include_outputs: bool, max_output_length: int, traceback: bool
|
||||
) -> str:
|
||||
"""Combine cells information in a readable format ready to be used.
|
||||
@@ -55,16 +55,16 @@ def concatenate_cells(
|
||||
return ""
|
||||
|
||||
|
||||
def remove_newlines(x: Any) -> Any:
|
||||
def _remove_newlines(x: Any) -> Any:
|
||||
"""Recursively remove newlines, no matter the data structure they are stored in."""
|
||||
import pandas as pd
|
||||
|
||||
if isinstance(x, str):
|
||||
return x.replace("\n", "")
|
||||
elif isinstance(x, list):
|
||||
return [remove_newlines(elem) for elem in x]
|
||||
return [_remove_newlines(elem) for elem in x]
|
||||
elif isinstance(x, pd.DataFrame):
|
||||
return x.applymap(remove_newlines)
|
||||
return x.applymap(_remove_newlines)
|
||||
else:
|
||||
return x
|
||||
|
||||
@@ -118,10 +118,10 @@ class NotebookLoader(BaseLoader):
|
||||
data = pd.json_normalize(d["cells"])
|
||||
filtered_data = data[["cell_type", "source", "outputs"]]
|
||||
if self.remove_newline:
|
||||
filtered_data = filtered_data.applymap(remove_newlines)
|
||||
filtered_data = filtered_data.applymap(_remove_newlines)
|
||||
|
||||
text = filtered_data.apply(
|
||||
lambda x: concatenate_cells(
|
||||
lambda x: _concatenate_cells(
|
||||
x, self.include_outputs, self.max_output_length, self.traceback
|
||||
),
|
||||
axis=1,
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
def default_joiner(docs: List[Tuple[str, Any]]) -> str:
|
||||
def _default_joiner(docs: List[Tuple[str, Any]]) -> str:
|
||||
"""Default joiner for content columns."""
|
||||
return "\n".join([doc[1] for doc in docs])
|
||||
|
||||
@@ -47,7 +47,9 @@ class RocksetLoader(BaseLoader):
|
||||
query: Any,
|
||||
content_keys: List[str],
|
||||
metadata_keys: Optional[List[str]] = None,
|
||||
content_columns_joiner: Callable[[List[Tuple[str, Any]]], str] = default_joiner,
|
||||
content_columns_joiner: Callable[
|
||||
[List[Tuple[str, Any]]], str
|
||||
] = _default_joiner,
|
||||
):
|
||||
"""Initialize with Rockset client.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
from telethon.hints import EntityLike
|
||||
|
||||
|
||||
def concatenate_rows(row: dict) -> str:
|
||||
def _concatenate_rows(row: dict) -> str:
|
||||
"""Combine message information in a readable format ready to be used."""
|
||||
date = row["date"]
|
||||
sender = row["from"]
|
||||
@@ -37,7 +37,7 @@ class TelegramChatFileLoader(BaseLoader):
|
||||
d = json.load(f)
|
||||
|
||||
text = "".join(
|
||||
concatenate_rows(message)
|
||||
_concatenate_rows(message)
|
||||
for message in d["messages"]
|
||||
if message["type"] == "message" and isinstance(message["text"], str)
|
||||
)
|
||||
@@ -46,7 +46,7 @@ class TelegramChatFileLoader(BaseLoader):
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
|
||||
|
||||
def text_to_docs(text: Union[str, List[str]]) -> List[Document]:
|
||||
def _text_to_docs(text: Union[str, List[str]]) -> List[Document]:
|
||||
"""Convert a string or list of strings to a list of Documents with metadata."""
|
||||
if isinstance(text, str):
|
||||
# Take a single string as one page
|
||||
@@ -258,4 +258,4 @@ class TelegramChatApiLoader(BaseLoader):
|
||||
message_threads = self._get_message_threads(df)
|
||||
combined_texts = self._combine_message_texts(message_threads, df)
|
||||
|
||||
return text_to_docs(combined_texts)
|
||||
return _text_to_docs(combined_texts)
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def concatenate_rows(date: str, sender: str, text: str) -> str:
|
||||
def _concatenate_rows(date: str, sender: str, text: str) -> str:
|
||||
"""Combine message information in a readable format ready to be used."""
|
||||
return f"{sender} on {date}: {text}\n\n"
|
||||
|
||||
@@ -57,7 +57,7 @@ class WhatsAppChatLoader(BaseLoader):
|
||||
if result:
|
||||
date, sender, text = result.groups()
|
||||
if text not in ignore_lines:
|
||||
text_content += concatenate_rows(date, sender, text)
|
||||
text_content += _concatenate_rows(date, sender, text)
|
||||
|
||||
metadata = {"source": str(p)}
|
||||
|
||||
|
||||
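All of the loader hunks above make the same mechanical change: module-level helpers such as `concatenate_rows`, `concatenate_cells`, `remove_newlines`, `default_joiner`, and `text_to_docs` gain a leading underscore, marking them as private to their modules. A short sketch of the convention with hypothetical names follows; wildcard imports and API-reference tooling skip underscore-prefixed names, which is the point of the rename.

from typing import Dict, List


def _concatenate_rows(date: str, sender: str, text: str) -> str:
    # Private helper: a formatting detail, not part of the module's public API.
    return f"{sender} on {date}: {text}\n\n"


class ExampleChatLoader:
    """Public class: the only name downstream code is expected to import."""

    def load(self, rows: List[Dict[str, str]]) -> str:
        # The loader is the supported entry point; the helper stays internal.
        return "".join(
            _concatenate_rows(row["date"], row["from"], row["text"]) for row in rows
        )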
@@ -40,12 +40,12 @@ def _create_retry_decorator(embeddings: DashScopeEmbeddings) -> Callable[[Any],
)


def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)

@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
resp = embeddings.client.call(**kwargs)
if resp.status_code == 200:
return resp.output["embeddings"]
@@ -61,7 +61,7 @@ def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
response=resp,
)

return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)


class DashScopeEmbeddings(BaseModel, Embeddings):
@@ -135,7 +135,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, input=texts, text_type="document", model=self.model
)
embedding_list = [item["embedding"] for item in embeddings]
@@ -150,7 +150,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = embed_with_retry(
embedding = _embed_with_retry(
self, input=text, text_type="query", model=self.model
)[0]["embedding"]
return embedding

@@ -40,17 +40,17 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)


def embed_with_retry(
def _embed_with_retry(
embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()

@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.client.generate_embeddings(*args, **kwargs)

return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)


class GooglePalmEmbeddings(BaseModel, Embeddings):
@@ -83,5 +83,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):

def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
embedding = embed_with_retry(self, self.model_name, text)
embedding = _embed_with_retry(self, self.model_name, text)
return embedding["embedding"]

@@ -94,27 +94,27 @@ def _check_response(response: dict) -> dict:
return response


def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)

@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)

return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)


async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""

@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response)

return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)


class LocalAIEmbeddings(BaseModel, Embeddings):
@@ -265,13 +265,13 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
return _embed_with_retry(
self,
input=[text],
**self._invocation_params,
)["data"][
0
]["embedding"]
)[
"data"
][0]["embedding"]

async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
@@ -281,7 +281,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input=[text],
**self._invocation_params,

@@ -34,15 +34,15 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)


def embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()

@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.embed(*args, **kwargs)

return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)


class MiniMaxEmbeddings(BaseModel, Embeddings):
@@ -144,7 +144,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
embeddings = _embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
return embeddings

def embed_query(self, text: str) -> List[float]:
@@ -156,7 +156,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, texts=[text], embed_type=self.embed_type_query
)
return embeddings[0]

@@ -95,27 +95,27 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
return response


def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)

@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)

return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)


async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""

@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)

return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)


class OpenAIEmbeddings(BaseModel, Embeddings):
@@ -371,7 +371,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_iter = range(0, len(tokens), _chunk_size)

for i in _iter:
response = embed_with_retry(
response = _embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -389,7 +389,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
average = embed_with_retry(
average = _embed_with_retry(
self,
input="",
**self._invocation_params,
@@ -443,7 +443,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
batched_embeddings: List[List[float]] = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size):
response = await async_embed_with_retry(
response = await _async_embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -460,7 +460,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_result = results[i]
if len(_result) == 0:
average = (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input="",
**self._invocation_params,
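The embedding hunks repeat the rename for both the synchronous and asynchronous retry helpers. The async versions in the diff rely on LangChain's own `_async_retry_decorator`; as a rough equivalent, plain tenacity offers `AsyncRetrying`, sketched below under the assumption of a generic async `client.acreate(...)` call.

from typing import Any

from tenacity import AsyncRetrying, stop_after_attempt, wait_exponential


async def _async_embed_with_retry(client: Any, **kwargs: Any) -> Any:
    """Await a flaky async embedding call, retrying with exponential backoff."""
    async for attempt in AsyncRetrying(
        reraise=True,
        stop=stop_after_attempt(6),
        wait=wait_exponential(multiplier=1, min=1, max=10),
    ):
        with attempt:
            # Success returns immediately; an exception triggers the next attempt.
            return await client.acreate(**kwargs)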
@@ -17,8 +17,8 @@ from langchain.callbacks.manager import (
)
from langchain.llms.openai import (
BaseOpenAI,
acompletion_with_retry,
completion_with_retry,
_acompletion_with_retry,
_completion_with_retry,
)
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
@@ -162,7 +162,7 @@ class Anyscale(BaseOpenAI):
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -180,7 +180,7 @@ class Anyscale(BaseOpenAI):
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -223,7 +223,7 @@ class Anyscale(BaseOpenAI):
else:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs}
response = completion_with_retry(
response = _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -264,7 +264,7 @@ class Anyscale(BaseOpenAI):
else:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
choices.extend(response["choices"])

@@ -40,18 +40,18 @@ def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
)


def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
def _completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return llm.client.generate(**kwargs)

return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)


def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
def _acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)

@@ -206,7 +206,7 @@ class Cohere(LLM, BaseCohere):
response = cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = completion_with_retry(
response = _completion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")
@@ -234,7 +234,7 @@ class Cohere(LLM, BaseCohere):
response = await cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")

@@ -73,7 +73,7 @@ class Fireworks(LLM):
"prompt": prompt,
**self.model_kwargs,
}
response = completion_with_retry(
response = _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)

@@ -92,7 +92,7 @@ class Fireworks(LLM):
"prompt": prompt,
**self.model_kwargs,
}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)

@@ -111,7 +111,7 @@ class Fireworks(LLM):
"stream": True,
**self.model_kwargs,
}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -132,7 +132,7 @@ class Fireworks(LLM):
"stream": True,
**self.model_kwargs,
}
async for stream_resp in await acompletion_with_retry_streaming(
async for stream_resp in await _acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -177,7 +177,7 @@ class Fireworks(LLM):
assert generation is not None


def completion_with_retry(
def _completion_with_retry(
llm: Fireworks,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -189,15 +189,15 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.create(
**kwargs,
)

return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)


async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: Fireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -217,7 +217,7 @@ async def acompletion_with_retry(
return await _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
async def _acompletion_with_retry_streaming(
llm: Fireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,

@@ -100,7 +100,7 @@ def _create_retry_decorator(
)


def completion_with_retry(
def _completion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -109,13 +109,13 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return llm.client.create(**kwargs)

return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)


async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -305,7 +305,7 @@ class BaseOpenAI(BaseLLM):
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -329,7 +329,7 @@ class BaseOpenAI(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutate params
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -398,7 +398,7 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = completion_with_retry(
response = _completion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -447,7 +447,7 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -847,7 +847,7 @@ class OpenAIChat(BaseLLM):
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -865,7 +865,7 @@ class OpenAIChat(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -893,7 +893,7 @@ class OpenAIChat(BaseLLM):

messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(
full_response = _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {
@@ -926,7 +926,7 @@ class OpenAIChat(BaseLLM):

messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(
full_response = await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {

@@ -167,7 +167,7 @@ class Nebula(LLM):
else:
raise ValueError("Prompt must contain instruction and conversation.")

response = completion_with_retry(
response = _completion_with_retry(
self,
instruction=instruction,
conversation=conversation,
@@ -232,12 +232,12 @@ def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]:
)


def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
def _completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)

@retry_decorator
def _completion_with_retry(**_kwargs: Any) -> Any:
def __completion_with_retry(**_kwargs: Any) -> Any:
return make_request(llm, **_kwargs)

return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)

@@ -86,7 +86,7 @@ def _create_retry_decorator(
return decorator


def completion_with_retry(
def _completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -96,13 +96,13 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
def __completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.predict(*args, **kwargs)

return _completion_with_retry(*args, **kwargs)
return __completion_with_retry(*args, **kwargs)


def stream_completion_with_retry(
def _stream_completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -118,7 +118,7 @@ def stream_completion_with_retry(
return _completion_with_retry(*args, **kwargs)


async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -128,10 +128,10 @@ async def acompletion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
async def __acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
return await llm.client.predict_async(*args, **kwargs)

return await _acompletion_with_retry(*args, **kwargs)
return await __acompletion_with_retry(*args, **kwargs)


class _VertexAIBase(BaseModel):
@@ -295,7 +295,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
generation += chunk
generations.append([generation])
else:
res = completion_with_retry(
res = _completion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(r) for r in res.candidates])
@@ -311,7 +311,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
params = self._prepare_params(stop=stop, **kwargs)
generations = []
for prompt in prompts:
res = await acompletion_with_retry(
res = await _acompletion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(r) for r in res.candidates])
@@ -325,7 +325,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in stream_completion_with_retry(
for stream_resp in _stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):
chunk = _response_to_generation(stream_resp)
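Several of the renamed functions (for example `completion_with_retry` in `langchain.llms.openai`, which the Anyscale hunk re-imports as `_completion_with_retry`) were importable from public module paths, so out-of-tree callers of the old names would break. A hedged sketch of a compatibility alias such a change could ship with is below; it is not part of this diff.

import warnings
from typing import Any

from langchain.llms.openai import _completion_with_retry


def completion_with_retry(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias kept for callers of the old public name."""
    warnings.warn(
        "completion_with_retry is deprecated; use _completion_with_retry instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _completion_with_retry(*args, **kwargs)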
@@ -5,7 +5,7 @@ from typing import List
import pytest

from langchain.chat_models import ChatAnthropic
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.anthropic import _convert_messages_to_prompt_anthropic
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage

os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -69,5 +69,5 @@ def test_anthropic_initialization() -> None:
],
)
def test_formatting(messages: List[BaseMessage], expected: str) -> None:
result = convert_messages_to_prompt_anthropic(messages)
result = _convert_messages_to_prompt_anthropic(messages)
assert result == expected

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch

import pytest

from langchain.adapters.openai import convert_dict_to_message
from langchain.adapters.openai import _convert_dict_to_message
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.messages import (
AIMessage,
@@ -26,7 +26,7 @@ def test_openai_model_param() -> None:
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = convert_dict_to_message(
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
@@ -40,21 +40,21 @@ def test_function_message_dict_to_function_message() -> None:

def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
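The test hunks above only repoint imports at the new private names; behaviour is unchanged. One additional check that would fit the same test module (a sketch, not part of the diff) is a round trip through the paired converters:

from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict


def test_convert_message_round_trip() -> None:
    # Converting a simple OpenAI-style dict to a message and back should be lossless.
    original = {"role": "user", "content": "foo"}
    assert _convert_message_to_dict(_convert_dict_to_message(original)) == original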