Compare commits

...

14 Commits

Author SHA1 Message Date
Bagatur
5f88060a08 src 2023-10-13 18:35:10 -07:00
Bagatur
ccb0821dc7 conf 2023-10-13 18:30:18 -07:00
Bagatur
972a2735d1 fix 2023-10-13 18:23:54 -07:00
Bagatur
4f2c85a7ae conf 2023-10-13 18:19:57 -07:00
Bagatur
0d3ae97c11 cat 2023-10-13 18:15:44 -07:00
Bagatur
883ff006bb full 2023-10-13 18:11:05 -07:00
Bagatur
f7d7ed3a47 fix 2023-10-13 18:07:14 -07:00
Bagatur
86d346ce89 req 2023-10-13 18:03:34 -07:00
Bagatur
a9af504f9c cmd 2023-10-13 18:00:56 -07:00
Bagatur
520be95168 custom build 2023-10-13 17:50:28 -07:00
Bagatur
1a869d0ef2 more 2023-10-13 15:44:11 -07:00
Bagatur
7dda1bf45a more 2023-10-13 15:38:35 -07:00
Bagatur
26b66a59fa more 2023-10-13 15:34:28 -07:00
Bagatur
b17b87ae04 update 2023-10-13 15:25:07 -07:00
47 changed files with 298 additions and 291 deletions

View File

@@ -9,9 +9,14 @@ build:
os: ubuntu-22.04
tools:
python: "3.11"
jobs:
pre_build:
commands:
- python -mvirtualenv $READTHEDOCS_VIRTUALENV_PATH
- python -m pip install --upgrade --no-cache-dir pip setuptools
- python -m pip install --upgrade --no-cache-dir sphinx readthedocs-sphinx-ext
- python -m pip install --exists-action=w --no-cache-dir -r docs/api_reference/requirements.txt
- python docs/api_reference/create_api_rst.py
- cat docs/api_reference/conf.py
- python -m sphinx -T -E -b html -d _build/doctrees -c docs/api_reference docs/api_reference $READTHEDOCS_OUTPUT/html -j auto
# Build documentation in the docs/ directory with Sphinx
sphinx:

View File

@@ -28,7 +28,7 @@ from langchain.schema.messages import (
)
async def aenumerate(
async def _aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate."""
@@ -38,7 +38,7 @@ async def aenumerate(
i += 1
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
@@ -59,7 +59,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
@@ -87,7 +87,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
def _convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
@@ -96,7 +96,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess
Returns:
List of LangChain BaseMessage objects.
"""
return [convert_dict_to_message(m) for m in messages]
return [_convert_dict_to_message(m) for m in messages]
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
@@ -155,10 +155,10 @@ class ChatCompletion:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
converted_messages = _convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
return {"choices": [{"message": _convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
@@ -198,14 +198,14 @@ class ChatCompletion:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
converted_messages = _convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
return {"choices": [{"message": _convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
async for i, c in aenumerate(model_config.astream(converted_messages))
async for i, c in _aenumerate(model_config.astream(converted_messages))
)
@@ -214,12 +214,12 @@ def _has_assistant_message(session: ChatSession) -> bool:
return any([isinstance(m, AIMessage) for m in session["messages"]])
def convert_messages_for_finetuning(
def _convert_messages_for_finetuning(
sessions: Iterable[ChatSession],
) -> List[List[dict]]:
"""Convert messages to a list of lists of dictionaries for fine-tuning."""
return [
[convert_message_to_dict(s) for s in session["messages"]]
[_convert_message_to_dict(s) for s in session["messages"]]
for session in sessions
if _has_assistant_message(session)
]

View File

@@ -5,7 +5,7 @@ from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
def import_aim() -> Any:
def _import_aim() -> Any:
"""Import the aim python package and raise an error if it is not installed."""
try:
import aim
@@ -169,7 +169,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
super().__init__()
aim = import_aim()
aim = _import_aim()
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
@@ -184,7 +184,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.action_records: list = []
def setup(self, **kwargs: Any) -> None:
aim = import_aim()
aim = _import_aim()
if not self._run:
if self._run_hash:
@@ -210,7 +210,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.llm_starts += 1
@@ -229,7 +229,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.llm_ends += 1
self.ends += 1
@@ -264,7 +264,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.chain_starts += 1
self.starts += 1
@@ -280,7 +280,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.chain_ends += 1
self.ends += 1
@@ -303,7 +303,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
@@ -315,7 +315,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.tool_ends += 1
self.ends += 1
@@ -339,7 +339,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.agent_ends += 1
self.ends += 1
@@ -356,7 +356,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
aim = import_aim()
aim = _import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
from langchain.callbacks.utils import _import_pandas
from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -60,7 +60,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pd = import_pandas()
pd = _import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,

View File

@@ -8,11 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
_import_pandas,
_import_spacy,
_import_textstat,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any:
def _import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
try:
import clearml # noqa: F401
@@ -63,8 +63,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None:
"""Initialize callback handler."""
clearml = import_clearml()
spacy = import_spacy()
clearml = _import_clearml()
spacy = _import_spacy()
super().__init__()
self.task_type = task_type
@@ -329,8 +329,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
(dict): A dictionary containing the complexity metrics.
"""
resp = {}
textstat = import_textstat()
spacy = import_spacy()
textstat = _import_textstat()
spacy = _import_spacy()
if self.complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
@@ -399,7 +399,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
pd = _import_pandas()
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
@@ -465,8 +465,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
Returns:
None
"""
pd = import_pandas()
clearml = import_clearml()
pd = _import_pandas()
clearml = _import_clearml()
# Log the action records
self.logger.report_table(

View File

@@ -7,17 +7,17 @@ import langchain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
_import_pandas,
_import_spacy,
_import_textstat,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
LANGCHAIN_MODEL_NAME = "langchain-model"
def import_comet_ml() -> Any:
def _import_comet_ml() -> Any:
"""Import comet_ml and raise an error if it is not installed."""
try:
import comet_ml # noqa: F401
@@ -33,7 +33,7 @@ def import_comet_ml() -> Any:
def _get_experiment(
workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
comet_ml = import_comet_ml()
comet_ml = _import_comet_ml()
experiment = comet_ml.Experiment( # type: ignore
workspace=workspace,
@@ -44,7 +44,7 @@ def _get_experiment(
def _fetch_text_complexity_metrics(text: str) -> dict:
textstat = import_textstat()
textstat = _import_textstat()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -67,7 +67,7 @@ def _fetch_text_complexity_metrics(text: str) -> dict:
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
pd = import_pandas()
pd = _import_pandas()
metrics_df = pd.DataFrame(metrics)
metrics_summary = metrics_df.describe()
@@ -107,7 +107,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None:
"""Initialize callback handler."""
self.comet_ml = import_comet_ml()
self.comet_ml = _import_comet_ml()
super().__init__()
self.task_type = task_type
@@ -140,7 +140,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.action_records: list = []
self.complexity_metrics = complexity_metrics
if self.visualizations:
spacy = import_spacy()
spacy = _import_spacy()
self.nlp = spacy.load("en_core_web_sm")
else:
self.nlp = None
@@ -535,7 +535,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if not (self.visualizations and self.nlp):
return
spacy = import_spacy()
spacy = _import_spacy()
prompts = session_df["prompts"].tolist()
outputs = session_df["text"].tolist()
@@ -603,7 +603,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.temp_dir = tempfile.TemporaryDirectory()
def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
pd = import_pandas()
pd = _import_pandas()
llm_parameters = self._get_llm_parameters(langchain_asset)
num_generations_per_prompt = llm_parameters.get("n", 1)

View File

@@ -10,7 +10,7 @@ from langchain.schema import (
)
def import_context() -> Any:
def _import_context() -> Any:
"""Import the `getcontext` package."""
try:
import getcontext # noqa: F401
@@ -98,7 +98,7 @@ class ContextCallbackHandler(BaseCallbackHandler):
self.message_model,
self.message_role_model,
self.rating_model,
) = import_context()
) = _import_context()
token = token or os.environ.get("CONTEXT_TOKEN") or ""

View File

@@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
_import_pandas,
_import_spacy,
_import_textstat,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def import_flytekit() -> Tuple[flytekit, renderer]:
def _import_flytekit() -> Tuple[flytekit, renderer]:
"""Import flytekit and flytekitplugins-deck-standard."""
try:
import flytekit # noqa: F401
@@ -75,7 +75,7 @@ def analyze_text(
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
spacy = _import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
@@ -97,12 +97,12 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def __init__(self) -> None:
"""Initialize callback handler."""
flytekit, renderer = import_flytekit()
self.pandas = import_pandas()
flytekit, renderer = _import_flytekit()
self.pandas = _import_pandas()
self.textstat = None
try:
self.textstat = import_textstat()
self.textstat = _import_textstat()
except ImportError:
logger.warning(
"Textstat library is not installed. \
@@ -112,7 +112,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
spacy = None
try:
spacy = import_spacy()
spacy = _import_spacy()
except ImportError:
logger.warning(
"Spacy library is not installed. \

View File

@@ -6,7 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.messages import BaseMessage
def import_infino() -> Any:
def _import_infino() -> Any:
"""Import the infino client."""
try:
from infinopy import InfinoClient
@@ -19,7 +19,7 @@ def import_infino() -> Any:
return InfinoClient()
def import_tiktoken() -> Any:
def _import_tiktoken() -> Any:
"""Import tiktoken for counting tokens for OpenAI models."""
try:
import tiktoken
@@ -38,7 +38,7 @@ def get_num_tokens(string: str, openai_model_name: str) -> int:
Official documentation: https://github.com/openai/openai-cookbook/blob/main
/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
tiktoken = import_tiktoken()
tiktoken = _import_tiktoken()
encoding = tiktoken.encoding_for_model(openai_model_name)
num_tokens = len(encoding.encode(string))
@@ -55,7 +55,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
verbose: bool = False,
) -> None:
# Set Infino client
self.client = import_infino()
self.client = _import_infino()
self.model_id = model_id
self.model_version = model_version
self.verbose = verbose

View File

@@ -23,7 +23,7 @@ class LabelStudioMode(Enum):
CHAT = "chat"
def get_default_label_configs(
def _get_default_label_configs(
mode: Union[str, LabelStudioMode]
) -> Tuple[str, LabelStudioMode]:
"""Get default Label Studio configs for the given mode.
@@ -173,7 +173,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
self.project_config = project_config
self.mode = None
else:
self.project_config, self.mode = get_default_label_configs(mode)
self.project_config, self.mode = _get_default_label_configs(mode)
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
if self.project_id is not None:

View File

@@ -10,17 +10,17 @@ from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
_import_pandas,
_import_spacy,
_import_textstat,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils import get_from_dict_or_env
def import_mlflow() -> Any:
def _import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
try:
import mlflow
@@ -47,8 +47,8 @@ def analyze_text(
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
textstat = import_textstat()
spacy = import_spacy()
textstat = _import_textstat()
spacy = _import_spacy()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -127,7 +127,7 @@ class MlflowLogger:
"""
def __init__(self, **kwargs: Any):
self.mlflow = import_mlflow()
self.mlflow = _import_mlflow()
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
self.mlflow.set_tracking_uri("databricks")
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
@@ -246,10 +246,10 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
tracking_uri: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
import_pandas()
import_textstat()
import_mlflow()
spacy = import_spacy()
_import_pandas()
_import_textstat()
_import_mlflow()
spacy = _import_spacy()
super().__init__()
self.name = name
@@ -547,7 +547,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
pd = _import_pandas()
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
@@ -617,7 +617,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
return session_analysis_df
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
pd = import_pandas()
pd = _import_pandas()
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
session_analysis_df = self._create_session_analysis_df()
chat_html = session_analysis_df.pop("chat_html")

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, Union
def import_spacy() -> Any:
def _import_spacy() -> Any:
"""Import the spacy python package and raise an error if it is not installed."""
try:
import spacy
@@ -15,7 +15,7 @@ def import_spacy() -> Any:
return spacy
def import_pandas() -> Any:
def _import_pandas() -> Any:
"""Import the pandas python package and raise an error if it is not installed."""
try:
import pandas
@@ -27,7 +27,7 @@ def import_pandas() -> Any:
return pandas
def import_textstat() -> Any:
def _import_textstat() -> Any:
"""Import the textstat python package and raise an error if it is not installed."""
try:
import textstat

View File

@@ -7,16 +7,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
BaseMetadataCallbackHandler,
_import_pandas,
_import_spacy,
_import_textstat,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
def import_wandb() -> Any:
def _import_wandb() -> Any:
"""Import the wandb python package and raise an error if it is not installed."""
try:
import wandb # noqa: F401
@@ -28,7 +28,7 @@ def import_wandb() -> Any:
return wandb
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
def _load_json_to_dict(json_path: Union[str, Path]) -> dict:
"""Load json file to a dictionary.
Parameters:
@@ -42,7 +42,7 @@ def load_json_to_dict(json_path: Union[str, Path]) -> dict:
return data
def analyze_text(
def _analyze_text(
text: str,
complexity_metrics: bool = True,
visualize: bool = True,
@@ -63,9 +63,9 @@ def analyze_text(
files serialized in a wandb.Html element.
"""
resp = {}
textstat = import_textstat()
wandb = import_wandb()
spacy = import_spacy()
textstat = _import_textstat()
wandb = _import_wandb()
spacy = _import_spacy()
if complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
@@ -120,7 +120,7 @@ def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> A
Returns:
(wandb.Html): The html element."""
wandb = import_wandb()
wandb = _import_wandb()
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
@@ -173,10 +173,10 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None:
"""Initialize callback handler."""
wandb = import_wandb()
import_pandas()
import_textstat()
spacy = import_spacy()
wandb = _import_wandb()
_import_pandas()
_import_textstat()
spacy = _import_spacy()
super().__init__()
self.job_type = job_type
@@ -269,7 +269,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(
analyze_text(
_analyze_text(
generation.text,
complexity_metrics=self.complexity_metrics,
visualize=self.visualize,
@@ -438,7 +438,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
pd = _import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
@@ -533,8 +533,8 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
Returns:
None
"""
pd = import_pandas()
wandb = import_wandb()
pd = _import_pandas()
wandb = _import_wandb()
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
session_analysis_table = wandb.Table(
dataframe=self._create_session_analysis_df()
@@ -554,11 +554,11 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
try:
langchain_asset.save(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
except NotImplementedError as e:
print("Could not save model.")
print(repr(e))

View File

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
diagnostic_logger = logging.getLogger(__name__)
def import_langkit(
def _import_langkit(
sentiment: bool = False,
toxicity: bool = False,
themes: bool = False,
@@ -159,7 +159,7 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
WhyLabs writer.
"""
# langkit library will import necessary whylogs libraries
import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
_import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
import whylogs as why
from langkit.callback_handler import get_callback_instance

View File

@@ -149,7 +149,7 @@ class LangSmithDatasetChatLoader(BaseChatLoader):
for data_point in data:
yield ChatSession(
messages=[
oai_adapter.convert_dict_to_message(m)
oai_adapter._convert_dict_to_message(m)
for m in data_point.get("messages", [])
],
functions=data_point.get("functions"),

View File

@@ -40,7 +40,7 @@ def _convert_one_message_to_text(
return message_text
def convert_messages_to_prompt_anthropic(
def _convert_messages_to_prompt_anthropic(
messages: List[BaseMessage],
*,
human_prompt: str = "\n\nHuman:",
@@ -115,7 +115,7 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
prompt_params["human_prompt"] = self.HUMAN_PROMPT
if self.AI_PROMPT:
prompt_params["ai_prompt"] = self.AI_PROMPT
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
return _convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
def convert_prompt(self, prompt: PromptValue) -> str:
return self._convert_messages_to_prompt(prompt.to_messages())

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
import requests
from langchain.adapters.openai import convert_message_to_dict
from langchain.adapters.openai import _convert_message_to_dict
from langchain.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
@@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
messages_dict = [_convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():

View File

@@ -41,7 +41,7 @@ def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
)
def convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dictionary that can be passed to the API."""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
@@ -194,7 +194,7 @@ class QianfanChatEndpoint(BaseChatModel):
"""
messages_dict: Dict[str, Any] = {
"messages": [
convert_message_to_dict(m)
_convert_message_to_dict(m)
for m in messages
if not isinstance(m, SystemMessage)
]

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.anthropic import _convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
@@ -25,7 +25,7 @@ class ChatPromptAdapter:
cls, provider: str, messages: List[BaseMessage]
) -> str:
if provider == "anthropic":
prompt = convert_messages_to_prompt_anthropic(messages=messages)
prompt = _convert_messages_to_prompt_anthropic(messages=messages)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."

View File

@@ -5,7 +5,7 @@ import logging
import sys
from typing import TYPE_CHECKING, Dict, Optional, Set
from langchain.adapters.openai import convert_message_to_dict
from langchain.adapters.openai import _convert_message_to_dict
from langchain.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
@@ -140,7 +140,7 @@ class ChatEverlyAI(ChatOpenAI):
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
messages_dict = [_convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():

View File

@@ -10,7 +10,7 @@ from typing import (
Union,
)
from langchain.adapters.openai import convert_message_to_dict
from langchain.adapters.openai import _convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -58,7 +58,7 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def convert_dict_to_message(_dict: Any) -> BaseMessage:
def _convert_dict_to_message(_dict: Any) -> BaseMessage:
"""Convert a dict response to a message."""
role = _dict.role
content = _dict.content or ""
@@ -125,7 +125,7 @@ class ChatFireworks(BaseChatModel):
"messages": message_dicts,
**self.model_kwargs,
}
response = completion_with_retry(
response = _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)
@@ -143,7 +143,7 @@ class ChatFireworks(BaseChatModel):
"messages": message_dicts,
**self.model_kwargs,
}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)
@@ -156,7 +156,7 @@ class ChatFireworks(BaseChatModel):
def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for res in response.choices:
message = convert_dict_to_message(res.message)
message = _convert_dict_to_message(res.message)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.finish_reason),
@@ -168,7 +168,7 @@ class ChatFireworks(BaseChatModel):
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
message_dicts = [convert_message_to_dict(m) for m in messages]
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts
def _stream(
@@ -186,7 +186,7 @@ class ChatFireworks(BaseChatModel):
"stream": True,
**self.model_kwargs,
}
for chunk in completion_with_retry(
for chunk in _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
@@ -215,7 +215,7 @@ class ChatFireworks(BaseChatModel):
"stream": True,
**self.model_kwargs,
}
async for chunk in await acompletion_with_retry_streaming(
async for chunk in await _acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
@@ -230,7 +230,7 @@ class ChatFireworks(BaseChatModel):
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
def completion_with_retry(
def _completion_with_retry(
llm: ChatFireworks,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -242,15 +242,15 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: ChatFireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -270,7 +270,7 @@ async def acompletion_with_retry(
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
async def _acompletion_with_retry_streaming(
llm: ChatFireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,

View File

@@ -190,27 +190,27 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
def _chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
def __chat_with_retry(**kwargs: Any) -> Any:
return llm.client.chat(**kwargs)
return _chat_with_retry(**kwargs)
return __chat_with_retry(**kwargs)
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
async def _achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
async def __achat_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.chat_async(**kwargs)
return await _achat_with_retry(**kwargs)
return await __achat_with_retry(**kwargs)
class ChatGooglePalm(BaseChatModel, BaseModel):
@@ -294,7 +294,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = chat_with_retry(
response: genai.types.ChatResponse = _chat_with_retry(
self,
model=self.model_name,
prompt=prompt,
@@ -316,7 +316,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = await achat_with_retry(
response: genai.types.ChatResponse = await _achat_with_retry(
self,
model=self.model_name,
prompt=prompt,

View File

@@ -79,7 +79,7 @@ def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
)
async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
async def _acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
@@ -274,15 +274,15 @@ class JinaChat(BaseChatModel):
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
def _completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
@@ -309,7 +309,7 @@ class JinaChat(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
for chunk in self._completion_with_retry(messages=message_dicts, **params):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
@@ -332,7 +332,7 @@ class JinaChat(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
response = self._completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
def _create_message_dicts(
@@ -366,7 +366,7 @@ class JinaChat(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
async for chunk in await _acompletion_with_retry(
self, messages=message_dicts, **params
):
delta = chunk["choices"][0]["delta"]
@@ -391,7 +391,7 @@ class JinaChat(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
response = await _acompletion_with_retry(self, messages=message_dicts, **params)
return self._create_chat_result(response)
@property

View File

@@ -17,7 +17,7 @@ from typing import (
import requests
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
@@ -165,13 +165,13 @@ class ChatKonko(ChatOpenAI):
return {model["id"] for model in models_response.json()["data"]}
def completion_with_retry(
def _completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
@@ -198,7 +198,7 @@ class ChatKonko(ChatOpenAI):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
for chunk in self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
@@ -233,7 +233,7 @@ class ChatKonko(ChatOpenAI):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
response = self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@@ -246,13 +246,13 @@ class ChatKonko(ChatOpenAI):
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = convert_dict_to_message(res["message"])
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),

View File

@@ -97,7 +97,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return ChatMessage(content=_dict["content"], role=role)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: ChatLiteLLM,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -225,17 +225,17 @@ class ChatLiteLLM(BaseChatModel):
}
return {**self._default_params, **creds}
def completion_with_retry(
def _completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return self.client.completion(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -302,7 +302,7 @@ class ChatLiteLLM(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
response = self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@@ -345,7 +345,7 @@ class ChatLiteLLM(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
for chunk in self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
@@ -368,7 +368,7 @@ class ChatLiteLLM(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
async for chunk in await _acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
@@ -397,7 +397,7 @@ class ChatLiteLLM(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)

View File

@@ -18,7 +18,7 @@ from typing import (
Union,
)
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -81,7 +81,7 @@ def _create_retry_decorator(
)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: ChatOpenAI,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -286,17 +286,17 @@ class ChatOpenAI(BaseChatModel):
**self.model_kwargs,
}
def completion_with_retry(
def _completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
@@ -323,7 +323,7 @@ class ChatOpenAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
for chunk in self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
@@ -357,7 +357,7 @@ class ChatOpenAI(BaseChatModel):
return _generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
response = self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@@ -370,13 +370,13 @@ class ChatOpenAI(BaseChatModel):
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = convert_dict_to_message(res["message"])
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
@@ -397,7 +397,7 @@ class ChatOpenAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
async for chunk in await _acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
@@ -432,7 +432,7 @@ class ChatOpenAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@@ -528,7 +528,7 @@ class ChatOpenAI(BaseChatModel):
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
messages_dict = [_convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():

View File

@@ -67,7 +67,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
@@ -266,14 +266,14 @@ class ChatTongyi(BaseChatModel):
**self.model_kwargs,
}
def completion_with_retry(
def _completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**_kwargs: Any) -> Any:
def __completion_with_retry(**_kwargs: Any) -> Any:
resp = self.client.call(**_kwargs)
if resp.status_code == 200:
return resp
@@ -289,19 +289,19 @@ class ChatTongyi(BaseChatModel):
response=resp,
)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
def stream_completion_with_retry(
def _stream_completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _stream_completion_with_retry(**_kwargs: Any) -> Any:
def __stream_completion_with_retry(**_kwargs: Any) -> Any:
return self.client.call(**_kwargs)
return _stream_completion_with_retry(**kwargs)
return __stream_completion_with_retry(**kwargs)
def _generate(
self,
@@ -320,7 +320,7 @@ class ChatTongyi(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
response = self._completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@@ -337,7 +337,7 @@ class ChatTongyi(BaseChatModel):
# Mark current chunk total length
length = 0
default_chunk_class = AIMessageChunk
for chunk in self.stream_completion_with_retry(
for chunk in self._stream_completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["output"]["choices"]) == 0:
@@ -368,7 +368,7 @@ class ChatTongyi(BaseChatModel):
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _client_params(self) -> Dict[str, Any]:

View File

@@ -43,7 +43,7 @@ class _FileType(str, Enum):
PDF = "pdf"
def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
def _fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
mime_types_mapping = {}
for file_type in file_types:
if file_type.value == "doc":
@@ -73,7 +73,7 @@ class O365BaseLoader(BaseLoader, BaseModel):
@property
def _fetch_mime_types(self) -> Dict[str, str]:
"""Return a dict of supported file types to corresponding mime types."""
return fetch_mime_types(self._file_types)
return _fetch_mime_types(self._file_types)
@property
@abstractmethod

View File

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
def concatenate_rows(message: dict, title: str) -> str:
def _concatenate_rows(message: dict, title: str) -> str:
"""
Combine message information in a readable format ready to be used.
Args:
@@ -50,7 +50,7 @@ class ChatGPTLoader(BaseLoader):
messages = d["mapping"]
text = "".join(
[
concatenate_rows(messages[key]["message"], title)
_concatenate_rows(messages[key]["message"], title)
for idx, key in enumerate(messages)
if not (
idx == 0

View File

@@ -7,7 +7,7 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
def concatenate_rows(row: dict) -> str:
def _concatenate_rows(row: dict) -> str:
"""Combine message information in a readable format ready to be used.
Args:
@@ -36,7 +36,7 @@ class FacebookChatLoader(BaseLoader):
d = json.load(f)
text = "".join(
concatenate_rows(message)
_concatenate_rows(message)
for message in d["messages"]
if message.get("content") and isinstance(message["content"], str)
)

View File

@@ -7,7 +7,7 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
def concatenate_cells(
def _concatenate_cells(
cell: dict, include_outputs: bool, max_output_length: int, traceback: bool
) -> str:
"""Combine cells information in a readable format ready to be used.
@@ -55,16 +55,16 @@ def concatenate_cells(
return ""
def remove_newlines(x: Any) -> Any:
def _remove_newlines(x: Any) -> Any:
"""Recursively remove newlines, no matter the data structure they are stored in."""
import pandas as pd
if isinstance(x, str):
return x.replace("\n", "")
elif isinstance(x, list):
return [remove_newlines(elem) for elem in x]
return [_remove_newlines(elem) for elem in x]
elif isinstance(x, pd.DataFrame):
return x.applymap(remove_newlines)
return x.applymap(_remove_newlines)
else:
return x
@@ -118,10 +118,10 @@ class NotebookLoader(BaseLoader):
data = pd.json_normalize(d["cells"])
filtered_data = data[["cell_type", "source", "outputs"]]
if self.remove_newline:
filtered_data = filtered_data.applymap(remove_newlines)
filtered_data = filtered_data.applymap(_remove_newlines)
text = filtered_data.apply(
lambda x: concatenate_cells(
lambda x: _concatenate_cells(
x, self.include_outputs, self.max_output_length, self.traceback
),
axis=1,

View File

@@ -4,7 +4,7 @@ from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
def default_joiner(docs: List[Tuple[str, Any]]) -> str:
def _default_joiner(docs: List[Tuple[str, Any]]) -> str:
"""Default joiner for content columns."""
return "\n".join([doc[1] for doc in docs])
@@ -47,7 +47,9 @@ class RocksetLoader(BaseLoader):
query: Any,
content_keys: List[str],
metadata_keys: Optional[List[str]] = None,
content_columns_joiner: Callable[[List[Tuple[str, Any]]], str] = default_joiner,
content_columns_joiner: Callable[
[List[Tuple[str, Any]]], str
] = _default_joiner,
):
"""Initialize with Rockset client.

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
from telethon.hints import EntityLike
def concatenate_rows(row: dict) -> str:
def _concatenate_rows(row: dict) -> str:
"""Combine message information in a readable format ready to be used."""
date = row["date"]
sender = row["from"]
@@ -37,7 +37,7 @@ class TelegramChatFileLoader(BaseLoader):
d = json.load(f)
text = "".join(
concatenate_rows(message)
_concatenate_rows(message)
for message in d["messages"]
if message["type"] == "message" and isinstance(message["text"], str)
)
@@ -46,7 +46,7 @@ class TelegramChatFileLoader(BaseLoader):
return [Document(page_content=text, metadata=metadata)]
def text_to_docs(text: Union[str, List[str]]) -> List[Document]:
def _text_to_docs(text: Union[str, List[str]]) -> List[Document]:
"""Convert a string or list of strings to a list of Documents with metadata."""
if isinstance(text, str):
# Take a single string as one page
@@ -258,4 +258,4 @@ class TelegramChatApiLoader(BaseLoader):
message_threads = self._get_message_threads(df)
combined_texts = self._combine_message_texts(message_threads, df)
return text_to_docs(combined_texts)
return _text_to_docs(combined_texts)

View File

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
def concatenate_rows(date: str, sender: str, text: str) -> str:
def _concatenate_rows(date: str, sender: str, text: str) -> str:
"""Combine message information in a readable format ready to be used."""
return f"{sender} on {date}: {text}\n\n"
@@ -57,7 +57,7 @@ class WhatsAppChatLoader(BaseLoader):
if result:
date, sender, text = result.groups()
if text not in ignore_lines:
text_content += concatenate_rows(date, sender, text)
text_content += _concatenate_rows(date, sender, text)
metadata = {"source": str(p)}

View File

@@ -40,12 +40,12 @@ def _create_retry_decorator(embeddings: DashScopeEmbeddings) -> Callable[[Any],
)
def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
resp = embeddings.client.call(**kwargs)
if resp.status_code == 200:
return resp.output["embeddings"]
@@ -61,7 +61,7 @@ def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
response=resp,
)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
class DashScopeEmbeddings(BaseModel, Embeddings):
@@ -135,7 +135,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, input=texts, text_type="document", model=self.model
)
embedding_list = [item["embedding"] for item in embeddings]
@@ -150,7 +150,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = embed_with_retry(
embedding = _embed_with_retry(
self, input=text, text_type="query", model=self.model
)[0]["embedding"]
return embedding

View File

@@ -40,17 +40,17 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def embed_with_retry(
def _embed_with_retry(
embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.client.generate_embeddings(*args, **kwargs)
return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)
class GooglePalmEmbeddings(BaseModel, Embeddings):
@@ -83,5 +83,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
embedding = embed_with_retry(self, self.model_name, text)
embedding = _embed_with_retry(self, self.model_name, text)
return embedding["embedding"]

View File

@@ -94,27 +94,27 @@ def _check_response(response: dict) -> dict:
return response
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response)
return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)
class LocalAIEmbeddings(BaseModel, Embeddings):
@@ -265,13 +265,13 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
return _embed_with_retry(
self,
input=[text],
**self._invocation_params,
)["data"][
0
]["embedding"]
)[
"data"
][0]["embedding"]
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
@@ -281,7 +281,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input=[text],
**self._invocation_params,

View File

@@ -34,15 +34,15 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.embed(*args, **kwargs)
return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)
class MiniMaxEmbeddings(BaseModel, Embeddings):
@@ -144,7 +144,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
embeddings = _embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
return embeddings
def embed_query(self, text: str) -> List[float]:
@@ -156,7 +156,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, texts=[text], embed_type=self.embed_type_query
)
return embeddings[0]

View File

@@ -95,27 +95,27 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
return response
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)
return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)
class OpenAIEmbeddings(BaseModel, Embeddings):
@@ -371,7 +371,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_iter = range(0, len(tokens), _chunk_size)
for i in _iter:
response = embed_with_retry(
response = _embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -389,7 +389,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
average = embed_with_retry(
average = _embed_with_retry(
self,
input="",
**self._invocation_params,
@@ -443,7 +443,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
batched_embeddings: List[List[float]] = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size):
response = await async_embed_with_retry(
response = await _async_embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -460,7 +460,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_result = results[i]
if len(_result) == 0:
average = (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input="",
**self._invocation_params,

View File

@@ -17,8 +17,8 @@ from langchain.callbacks.manager import (
)
from langchain.llms.openai import (
BaseOpenAI,
acompletion_with_retry,
completion_with_retry,
_acompletion_with_retry,
_completion_with_retry,
)
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
@@ -162,7 +162,7 @@ class Anyscale(BaseOpenAI):
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -180,7 +180,7 @@ class Anyscale(BaseOpenAI):
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -223,7 +223,7 @@ class Anyscale(BaseOpenAI):
else:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs}
response = completion_with_retry(
response = _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -264,7 +264,7 @@ class Anyscale(BaseOpenAI):
else:
messages, params = self._get_chat_messages([prompt], stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
choices.extend(response["choices"])

View File

@@ -40,18 +40,18 @@ def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
)
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
def _completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return llm.client.generate(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
def _acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@@ -206,7 +206,7 @@ class Cohere(LLM, BaseCohere):
response = cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = completion_with_retry(
response = _completion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")
@@ -234,7 +234,7 @@ class Cohere(LLM, BaseCohere):
response = await cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")

View File

@@ -73,7 +73,7 @@ class Fireworks(LLM):
"prompt": prompt,
**self.model_kwargs,
}
response = completion_with_retry(
response = _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
@@ -92,7 +92,7 @@ class Fireworks(LLM):
"prompt": prompt,
**self.model_kwargs,
}
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
@@ -111,7 +111,7 @@ class Fireworks(LLM):
"stream": True,
**self.model_kwargs,
}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -132,7 +132,7 @@ class Fireworks(LLM):
"stream": True,
**self.model_kwargs,
}
async for stream_resp in await acompletion_with_retry_streaming(
async for stream_resp in await _acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -177,7 +177,7 @@ class Fireworks(LLM):
assert generation is not None
def completion_with_retry(
def _completion_with_retry(
llm: Fireworks,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -189,15 +189,15 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: Fireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -217,7 +217,7 @@ async def acompletion_with_retry(
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
async def _acompletion_with_retry_streaming(
llm: Fireworks,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,

View File

@@ -100,7 +100,7 @@ def _create_retry_decorator(
)
def completion_with_retry(
def _completion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -109,13 +109,13 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def __completion_with_retry(**kwargs: Any) -> Any:
return llm.client.create(**kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -305,7 +305,7 @@ class BaseOpenAI(BaseLLM):
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -329,7 +329,7 @@ class BaseOpenAI(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutate params
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
@@ -398,7 +398,7 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = completion_with_retry(
response = _completion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -447,7 +447,7 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = await acompletion_with_retry(
response = await _acompletion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
@@ -847,7 +847,7 @@ class OpenAIChat(BaseLLM):
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(
for stream_resp in _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -865,7 +865,7 @@ class OpenAIChat(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
async for stream_resp in await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
@@ -893,7 +893,7 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(
full_response = _completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {
@@ -926,7 +926,7 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(
full_response = await _acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {

View File

@@ -167,7 +167,7 @@ class Nebula(LLM):
else:
raise ValueError("Prompt must contain instruction and conversation.")
response = completion_with_retry(
response = _completion_with_retry(
self,
instruction=instruction,
conversation=conversation,
@@ -232,12 +232,12 @@ def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]:
)
def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
def _completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**_kwargs: Any) -> Any:
def __completion_with_retry(**_kwargs: Any) -> Any:
return make_request(llm, **_kwargs)
return _completion_with_retry(**kwargs)
return __completion_with_retry(**kwargs)

View File

@@ -86,7 +86,7 @@ def _create_retry_decorator(
return decorator
def completion_with_retry(
def _completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -96,13 +96,13 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
def __completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.predict(*args, **kwargs)
return _completion_with_retry(*args, **kwargs)
return __completion_with_retry(*args, **kwargs)
def stream_completion_with_retry(
def _stream_completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
@@ -118,7 +118,7 @@ def stream_completion_with_retry(
return _completion_with_retry(*args, **kwargs)
async def acompletion_with_retry(
async def _acompletion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
@@ -128,10 +128,10 @@ async def acompletion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
async def __acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
return await llm.client.predict_async(*args, **kwargs)
return await _acompletion_with_retry(*args, **kwargs)
return await __acompletion_with_retry(*args, **kwargs)
class _VertexAIBase(BaseModel):
@@ -295,7 +295,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
generation += chunk
generations.append([generation])
else:
res = completion_with_retry(
res = _completion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(r) for r in res.candidates])
@@ -311,7 +311,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
params = self._prepare_params(stop=stop, **kwargs)
generations = []
for prompt in prompts:
res = await acompletion_with_retry(
res = await _acompletion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(r) for r in res.candidates])
@@ -325,7 +325,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in stream_completion_with_retry(
for stream_resp in _stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):
chunk = _response_to_generation(stream_resp)

View File

@@ -5,7 +5,7 @@ from typing import List
import pytest
from langchain.chat_models import ChatAnthropic
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.anthropic import _convert_messages_to_prompt_anthropic
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -69,5 +69,5 @@ def test_anthropic_initialization() -> None:
],
)
def test_formatting(messages: List[BaseMessage], expected: str) -> None:
result = convert_messages_to_prompt_anthropic(messages)
result = _convert_messages_to_prompt_anthropic(messages)
assert result == expected

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from langchain.adapters.openai import convert_dict_to_message
from langchain.adapters.openai import _convert_dict_to_message
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.messages import (
AIMessage,
@@ -26,7 +26,7 @@ def test_openai_model_param() -> None:
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = convert_dict_to_message(
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
@@ -40,21 +40,21 @@ def test_function_message_dict_to_function_message() -> None:
def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = convert_dict_to_message(message)
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output