commit b17b87ae04
parent 0d37b4c27d
Author: Bagatur
Date:   2023-10-13 15:25:07 -07:00

20 changed files with 121 additions and 121 deletions

View File

@@ -28,7 +28,7 @@ from langchain.schema.messages import (
 )
-async def aenumerate(
+async def _aenumerate(
     iterable: AsyncIterator[Any], start: int = 0
 ) -> AsyncIterator[tuple[int, Any]]:
     """Async version of enumerate."""
@@ -38,7 +38,7 @@ async def aenumerate(
         i += 1
-def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
         return HumanMessage(content=_dict["content"])
@@ -59,7 +59,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         return ChatMessage(content=_dict["content"], role=role)
-def convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(message: BaseMessage) -> dict:
     message_dict: Dict[str, Any]
     if isinstance(message, ChatMessage):
         message_dict = {"role": message.role, "content": message.content}
@@ -87,7 +87,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
-def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
+def _convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
     """Convert dictionaries representing OpenAI messages to LangChain format.
     Args:
@@ -96,7 +96,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
     Returns:
         List of LangChain BaseMessage objects.
     """
-    return [convert_dict_to_message(m) for m in messages]
+    return [_convert_dict_to_message(m) for m in messages]
 def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
@@ -155,10 +155,10 @@ class ChatCompletion:
         models = importlib.import_module("langchain.chat_models")
         model_cls = getattr(models, provider)
         model_config = model_cls(**kwargs)
-        converted_messages = convert_openai_messages(messages)
+        converted_messages = _convert_openai_messages(messages)
         if not stream:
             result = model_config.invoke(converted_messages)
-            return {"choices": [{"message": convert_message_to_dict(result)}]}
+            return {"choices": [{"message": _convert_message_to_dict(result)}]}
         else:
             return (
                 _convert_message_chunk_to_delta(c, i)
@@ -198,14 +198,14 @@ class ChatCompletion:
         models = importlib.import_module("langchain.chat_models")
         model_cls = getattr(models, provider)
         model_config = model_cls(**kwargs)
-        converted_messages = convert_openai_messages(messages)
+        converted_messages = _convert_openai_messages(messages)
         if not stream:
             result = await model_config.ainvoke(converted_messages)
-            return {"choices": [{"message": convert_message_to_dict(result)}]}
+            return {"choices": [{"message": _convert_message_to_dict(result)}]}
         else:
             return (
                 _convert_message_chunk_to_delta(c, i)
-                async for i, c in aenumerate(model_config.astream(converted_messages))
+                async for i, c in _aenumerate(model_config.astream(converted_messages))
             )
@@ -214,12 +214,12 @@ def _has_assistant_message(session: ChatSession) -> bool:
     return any([isinstance(m, AIMessage) for m in session["messages"]])
-def convert_messages_for_finetuning(
+def _convert_messages_for_finetuning(
     sessions: Iterable[ChatSession],
 ) -> List[List[dict]]:
     """Convert messages to a list of lists of dictionaries for fine-tuning."""
     return [
-        [convert_message_to_dict(s) for s in session["messages"]]
+        [_convert_message_to_dict(s) for s in session["messages"]]
         for session in sessions
         if _has_assistant_message(session)
     ]
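Note: the hunks above show only the signature and the trailing `i += 1` of the renamed async helper. A minimal sketch of the whole function, reconstructed from those fragments (the loop body is inferred rather than shown in the diff):

from typing import Any, AsyncIterator

async def _aenumerate(
    iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
    """Async version of enumerate."""
    i = start
    async for x in iterable:  # inferred body: count items as they stream in
        yield i, x
        i += 1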

View File

@@ -5,7 +5,7 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult
-def import_aim() -> Any:
+def _import_aim() -> Any:
     """Import the aim python package and raise an error if it is not installed."""
     try:
         import aim
@@ -169,7 +169,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         super().__init__()
-        aim = import_aim()
+        aim = _import_aim()
         self.repo = repo
         self.experiment_name = experiment_name
         self.system_tracking_interval = system_tracking_interval
@@ -184,7 +184,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.action_records: list = []
     def setup(self, **kwargs: Any) -> None:
-        aim = import_aim()
+        aim = _import_aim()
         if not self._run:
             if self._run_hash:
@@ -210,7 +210,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.llm_starts += 1
@@ -229,7 +229,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.llm_ends += 1
         self.ends += 1
@@ -264,7 +264,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.chain_starts += 1
         self.starts += 1
@@ -280,7 +280,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Run when chain ends running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.chain_ends += 1
         self.ends += 1
@@ -303,7 +303,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.tool_starts += 1
         self.starts += 1
@@ -315,7 +315,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.tool_ends += 1
         self.ends += 1
@@ -339,7 +339,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
         """Run when agent ends running."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.agent_ends += 1
         self.ends += 1
@@ -356,7 +356,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
         """Run on agent action."""
-        aim = import_aim()
+        aim = _import_aim()
         self.step += 1
         self.tool_starts += 1
         self.starts += 1
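Every `aim = _import_aim()` call above is the same lazy-import guard, so `aim` stays an optional dependency that is only imported when a handler method actually runs. A minimal sketch of the guard, assuming a typical error message (the exact wording is not visible in this diff):

from typing import Any

def _import_aim() -> Any:
    """Import the aim python package and raise an error if it is not installed."""
    try:
        import aim
    except ImportError:
        # Hypothetical wording; the real message may differ.
        raise ImportError(
            "To use the Aim callback manager you need to have the `aim` python "
            "package installed. Please install it with `pip install aim`"
        )
    return aim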

View File

@@ -2,7 +2,7 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.utils import import_pandas
+from langchain.callbacks.utils import _import_pandas
 from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -60,7 +60,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         pass
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pd = import_pandas()
+        pd = _import_pandas()
         from arize.utils.types import (
             EmbeddingColumnNames,
             Environments,

View File

@@ -8,11 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
     BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
     flatten_dict,
     hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
     load_json,
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
     import pandas as pd
-def import_clearml() -> Any:
+def _import_clearml() -> Any:
     """Import the clearml python package and raise an error if it is not installed."""
     try:
         import clearml  # noqa: F401
@@ -63,8 +63,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     ) -> None:
         """Initialize callback handler."""
-        clearml = import_clearml()
-        spacy = import_spacy()
+        clearml = _import_clearml()
+        spacy = _import_spacy()
         super().__init__()
         self.task_type = task_type
@@ -329,8 +329,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             (dict): A dictionary containing the complexity metrics.
         """
         resp = {}
-        textstat = import_textstat()
-        spacy = import_spacy()
+        textstat = _import_textstat()
+        spacy = _import_spacy()
         if self.complexity_metrics:
             text_complexity_metrics = {
                 "flesch_reading_ease": textstat.flesch_reading_ease(text),
@@ -399,7 +399,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def _create_session_analysis_df(self) -> Any:
         """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
         on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
         llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
@@ -465,8 +465,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         Returns:
             None
         """
-        pd = import_pandas()
-        clearml = import_clearml()
+        pd = _import_pandas()
+        clearml = _import_clearml()
         # Log the action records
         self.logger.report_table(

View File

@@ -7,17 +7,17 @@ import langchain
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
     BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
     flatten_dict,
-    import_pandas,
-    import_spacy,
-    import_textstat,
 )
 from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
 LANGCHAIN_MODEL_NAME = "langchain-model"
-def import_comet_ml() -> Any:
+def _import_comet_ml() -> Any:
     """Import comet_ml and raise an error if it is not installed."""
     try:
         import comet_ml  # noqa: F401
@@ -33,7 +33,7 @@ def import_comet_ml() -> Any:
 def _get_experiment(
     workspace: Optional[str] = None, project_name: Optional[str] = None
 ) -> Any:
-    comet_ml = import_comet_ml()
+    comet_ml = _import_comet_ml()
     experiment = comet_ml.Experiment(  # type: ignore
         workspace=workspace,
@@ -44,7 +44,7 @@ def _get_experiment(
 def _fetch_text_complexity_metrics(text: str) -> dict:
-    textstat = import_textstat()
+    textstat = _import_textstat()
     text_complexity_metrics = {
         "flesch_reading_ease": textstat.flesch_reading_ease(text),
         "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -67,7 +67,7 @@ def _fetch_text_complexity_metrics(text: str) -> dict:
 def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
-    pd = import_pandas()
+    pd = _import_pandas()
     metrics_df = pd.DataFrame(metrics)
     metrics_summary = metrics_df.describe()
@@ -107,7 +107,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     ) -> None:
         """Initialize callback handler."""
-        self.comet_ml = import_comet_ml()
+        self.comet_ml = _import_comet_ml()
         super().__init__()
         self.task_type = task_type
@@ -140,7 +140,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.action_records: list = []
         self.complexity_metrics = complexity_metrics
         if self.visualizations:
-            spacy = import_spacy()
+            spacy = _import_spacy()
             self.nlp = spacy.load("en_core_web_sm")
         else:
             self.nlp = None
@@ -535,7 +535,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         if not (self.visualizations and self.nlp):
             return
-        spacy = import_spacy()
+        spacy = _import_spacy()
         prompts = session_df["prompts"].tolist()
         outputs = session_df["text"].tolist()
@@ -603,7 +603,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.temp_dir = tempfile.TemporaryDirectory()
     def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
-        pd = import_pandas()
+        pd = _import_pandas()
         llm_parameters = self._get_llm_parameters(langchain_asset)
         num_generations_per_prompt = llm_parameters.get("n", 1)

View File

@@ -10,7 +10,7 @@ from langchain.schema import (
 )
-def import_context() -> Any:
+def _import_context() -> Any:
     """Import the `getcontext` package."""
     try:
         import getcontext  # noqa: F401
@@ -98,7 +98,7 @@ class ContextCallbackHandler(BaseCallbackHandler):
             self.message_model,
             self.message_role_model,
             self.rating_model,
-        ) = import_context()
+        ) = _import_context()
         token = token or os.environ.get("CONTEXT_TOKEN") or ""

View File

@@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
     BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
     flatten_dict,
-    import_pandas,
-    import_spacy,
-    import_textstat,
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
-def import_flytekit() -> Tuple[flytekit, renderer]:
+def _import_flytekit() -> Tuple[flytekit, renderer]:
     """Import flytekit and flytekitplugins-deck-standard."""
     try:
         import flytekit  # noqa: F401
@@ -75,7 +75,7 @@ def analyze_text(
     resp.update(text_complexity_metrics)
     if nlp is not None:
-        spacy = import_spacy()
+        spacy = _import_spacy()
         doc = nlp(text)
         dep_out = spacy.displacy.render(  # type: ignore
             doc, style="dep", jupyter=False, page=True
@@ -97,12 +97,12 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def __init__(self) -> None:
         """Initialize callback handler."""
-        flytekit, renderer = import_flytekit()
-        self.pandas = import_pandas()
+        flytekit, renderer = _import_flytekit()
+        self.pandas = _import_pandas()
         self.textstat = None
         try:
-            self.textstat = import_textstat()
+            self.textstat = _import_textstat()
         except ImportError:
             logger.warning(
                 "Textstat library is not installed. \
@@ -112,7 +112,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         spacy = None
         try:
-            spacy = import_spacy()
+            spacy = _import_spacy()
         except ImportError:
             logger.warning(
                 "Spacy library is not installed. \

View File

@@ -6,7 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
 from langchain.schema.messages import BaseMessage
-def import_infino() -> Any:
+def _import_infino() -> Any:
     """Import the infino client."""
     try:
         from infinopy import InfinoClient
@@ -19,7 +19,7 @@ def import_infino() -> Any:
     return InfinoClient()
-def import_tiktoken() -> Any:
+def _import_tiktoken() -> Any:
     """Import tiktoken for counting tokens for OpenAI models."""
     try:
         import tiktoken
@@ -38,7 +38,7 @@ def get_num_tokens(string: str, openai_model_name: str) -> int:
     Official documentation: https://github.com/openai/openai-cookbook/blob/main
     /examples/How_to_count_tokens_with_tiktoken.ipynb
     """
-    tiktoken = import_tiktoken()
+    tiktoken = _import_tiktoken()
     encoding = tiktoken.encoding_for_model(openai_model_name)
     num_tokens = len(encoding.encode(string))
@@ -55,7 +55,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
         verbose: bool = False,
     ) -> None:
         # Set Infino client
-        self.client = import_infino()
+        self.client = _import_infino()
         self.model_id = model_id
         self.model_version = model_version
         self.verbose = verbose
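The partially shown `get_num_tokens` helper counts tokens with tiktoken's model-aware encodings. A sketch of the complete function assembled from the visible fragments (only the final `return` is inferred):

def get_num_tokens(string: str, openai_model_name: str) -> int:
    """Calculate num tokens with the tiktoken package.

    Official documentation: https://github.com/openai/openai-cookbook/blob/main
    /examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    tiktoken = _import_tiktoken()
    # Look up the byte-pair encoding for the given model, then count tokens.
    encoding = tiktoken.encoding_for_model(openai_model_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens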

View File

@@ -23,7 +23,7 @@ class LabelStudioMode(Enum):
     CHAT = "chat"
-def get_default_label_configs(
+def _get_default_label_configs(
     mode: Union[str, LabelStudioMode]
 ) -> Tuple[str, LabelStudioMode]:
     """Get default Label Studio configs for the given mode.
@@ -173,7 +173,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
             self.project_config = project_config
             self.mode = None
         else:
-            self.project_config, self.mode = get_default_label_configs(mode)
+            self.project_config, self.mode = _get_default_label_configs(mode)
         self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
         if self.project_id is not None:

View File

@@ -10,17 +10,17 @@ from typing import Any, Dict, List, Optional, Union
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
     BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
     flatten_dict,
     hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 from langchain.utils import get_from_dict_or_env
-def import_mlflow() -> Any:
+def _import_mlflow() -> Any:
     """Import the mlflow python package and raise an error if it is not installed."""
     try:
         import mlflow
@@ -47,8 +47,8 @@ def analyze_text(
         files serialized to HTML string.
     """
     resp: Dict[str, Any] = {}
-    textstat = import_textstat()
-    spacy = import_spacy()
+    textstat = _import_textstat()
+    spacy = _import_spacy()
     text_complexity_metrics = {
         "flesch_reading_ease": textstat.flesch_reading_ease(text),
         "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -127,7 +127,7 @@ class MlflowLogger:
     """
     def __init__(self, **kwargs: Any):
-        self.mlflow = import_mlflow()
+        self.mlflow = _import_mlflow()
         if "DATABRICKS_RUNTIME_VERSION" in os.environ:
             self.mlflow.set_tracking_uri("databricks")
             self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
@@ -246,10 +246,10 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         tracking_uri: Optional[str] = None,
     ) -> None:
         """Initialize callback handler."""
-        import_pandas()
-        import_textstat()
-        import_mlflow()
-        spacy = import_spacy()
+        _import_pandas()
+        _import_textstat()
+        _import_mlflow()
+        spacy = _import_spacy()
         super().__init__()
         self.name = name
@@ -547,7 +547,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def _create_session_analysis_df(self) -> Any:
         """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
         on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
         on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
@@ -617,7 +617,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         return session_analysis_df
     def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
-        pd = import_pandas()
+        pd = _import_pandas()
         self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
         session_analysis_df = self._create_session_analysis_df()
         chat_html = session_analysis_df.pop("chat_html")

View File

@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Any, Dict, Iterable, Tuple, Union
-def import_spacy() -> Any:
+def _import_spacy() -> Any:
     """Import the spacy python package and raise an error if it is not installed."""
     try:
         import spacy
@@ -15,7 +15,7 @@ def import_spacy() -> Any:
     return spacy
-def import_pandas() -> Any:
+def _import_pandas() -> Any:
     """Import the pandas python package and raise an error if it is not installed."""
     try:
         import pandas
@@ -27,7 +27,7 @@ def import_pandas() -> Any:
     return pandas
-def import_textstat() -> Any:
+def _import_textstat() -> Any:
     """Import the textstat python package and raise an error if it is not installed."""
     try:
         import textstat
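These three guards are what the callback handlers in this commit call at method entry, so pandas, spacy, and textstat stay optional extras rather than hard dependencies. A sketch of the pattern from a caller's side (the class and attribute names here are illustrative, not from the diff):

from typing import Any

class ExampleHandler:  # illustrative stand-in for a callback handler
    def __init__(self) -> None:
        self.action_records: list = []

    def flush_tracker(self) -> Any:
        # Import at call time, not at module import time, so the heavy
        # dependency is only required when this feature is actually used.
        pd = _import_pandas()
        return pd.DataFrame(self.action_records)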

View File

@@ -7,16 +7,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
     BaseMetadataCallbackHandler,
+    _import_pandas,
+    _import_spacy,
+    _import_textstat,
     flatten_dict,
     hash_string,
-    import_pandas,
-    import_spacy,
-    import_textstat,
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
-def import_wandb() -> Any:
+def _import_wandb() -> Any:
     """Import the wandb python package and raise an error if it is not installed."""
     try:
         import wandb  # noqa: F401
@@ -28,7 +28,7 @@ def import_wandb() -> Any:
     return wandb
-def load_json_to_dict(json_path: Union[str, Path]) -> dict:
+def _load_json_to_dict(json_path: Union[str, Path]) -> dict:
     """Load json file to a dictionary.
     Parameters:
@@ -42,7 +42,7 @@ def load_json_to_dict(json_path: Union[str, Path]) -> dict:
     return data
-def analyze_text(
+def _analyze_text(
     text: str,
     complexity_metrics: bool = True,
     visualize: bool = True,
@@ -63,9 +63,9 @@ def analyze_text(
         files serialized in a wandb.Html element.
     """
     resp = {}
-    textstat = import_textstat()
-    wandb = import_wandb()
-    spacy = import_spacy()
+    textstat = _import_textstat()
+    wandb = _import_wandb()
+    spacy = _import_spacy()
     if complexity_metrics:
         text_complexity_metrics = {
             "flesch_reading_ease": textstat.flesch_reading_ease(text),
@@ -120,7 +120,7 @@ def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
     Returns:
         (wandb.Html): The html element."""
-    wandb = import_wandb()
+    wandb = _import_wandb()
     formatted_prompt = prompt.replace("\n", "<br>")
     formatted_generation = generation.replace("\n", "<br>")
@@ -173,10 +173,10 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     ) -> None:
         """Initialize callback handler."""
-        wandb = import_wandb()
-        import_pandas()
-        import_textstat()
-        spacy = import_spacy()
+        wandb = _import_wandb()
+        _import_pandas()
+        _import_textstat()
+        spacy = _import_spacy()
         super().__init__()
         self.job_type = job_type
@@ -269,7 +269,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             generation_resp = deepcopy(resp)
             generation_resp.update(flatten_dict(generation.dict()))
             generation_resp.update(
-                analyze_text(
+                _analyze_text(
                     generation.text,
                     complexity_metrics=self.complexity_metrics,
                     visualize=self.visualize,
@@ -438,7 +438,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
     def _create_session_analysis_df(self) -> Any:
         """Create a dataframe with all the information from the session."""
-        pd = import_pandas()
+        pd = _import_pandas()
         on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
         on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
@@ -533,8 +533,8 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         Returns:
             None
         """
-        pd = import_pandas()
-        wandb = import_wandb()
+        pd = _import_pandas()
+        wandb = _import_wandb()
         action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
         session_analysis_table = wandb.Table(
             dataframe=self._create_session_analysis_df()
@@ -554,11 +554,11 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             try:
                 langchain_asset.save(langchain_asset_path)
                 model_artifact.add_file(str(langchain_asset_path))
-                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+                model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
             except ValueError:
                 langchain_asset.save_agent(langchain_asset_path)
                 model_artifact.add_file(str(langchain_asset_path))
-                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
+                model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
         except NotImplementedError as e:
             print("Could not save model.")
             print(repr(e))
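`_load_json_to_dict`, renamed in this file, is a small wrapper whose body the hunks elide; only its docstring and `return data` are visible. A plausible reconstruction (the file-reading body is an assumption):

import json
from pathlib import Path
from typing import Union

def _load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load json file to a dictionary.

    Parameters:
        json_path (str): The path to the json file.
    """
    with open(json_path, "r") as f:
        data = json.load(f)  # assumed body: read and parse the file
    return data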

View File

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
 diagnostic_logger = logging.getLogger(__name__)
-def import_langkit(
+def _import_langkit(
     sentiment: bool = False,
     toxicity: bool = False,
     themes: bool = False,
@@ -159,7 +159,7 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
         WhyLabs writer.
         """
         # langkit library will import necessary whylogs libraries
-        import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
+        _import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
         import whylogs as why
         from langkit.callback_handler import get_callback_instance

View File

@@ -149,7 +149,7 @@ class LangSmithDatasetChatLoader(BaseChatLoader):
         for data_point in data:
             yield ChatSession(
                 messages=[
-                    oai_adapter.convert_dict_to_message(m)
+                    oai_adapter._convert_dict_to_message(m)
                     for m in data_point.get("messages", [])
                 ],
                 functions=data_point.get("functions"),

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
 import requests
-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
 from langchain.chat_models.openai import (
     ChatOpenAI,
     _import_tiktoken,
@@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
         tokens_per_message = 3
         tokens_per_name = 1
         num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
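The token-counting method above is shown truncated; it follows the OpenAI cookbook scheme that the `ChatOpenAI` base class references: a fixed per-message overhead plus the encoded length of each value. A sketch of how such a loop typically completes, assuming the cookbook's constants (the tail is not shown in this diff):

num_tokens = 0
messages_dict = [_convert_message_to_dict(m) for m in messages]
for message in messages_dict:
    num_tokens += tokens_per_message  # fixed overhead per message (3 here)
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":
            num_tokens += tokens_per_name  # one extra token when a name is set
num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>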

View File

@@ -5,7 +5,7 @@ import logging
 import sys
 from typing import TYPE_CHECKING, Dict, Optional, Set
-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
 from langchain.chat_models.openai import (
     ChatOpenAI,
     _import_tiktoken,
@@ -140,7 +140,7 @@ class ChatEverlyAI(ChatOpenAI):
         tokens_per_message = 3
         tokens_per_name = 1
         num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
         for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():

View File

@@ -10,7 +10,7 @@ from typing import (
     Union,
 )
-from langchain.adapters.openai import convert_message_to_dict
+from langchain.adapters.openai import _convert_message_to_dict
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -168,7 +168,7 @@ class ChatFireworks(BaseChatModel):
     def _create_message_dicts(
         self, messages: List[BaseMessage]
     ) -> List[Dict[str, Any]]:
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts
     def _stream(

View File

@@ -17,7 +17,7 @@ from typing import (
 import requests
-from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
+from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
@@ -246,13 +246,13 @@ class ChatKonko(ChatOpenAI):
         if "stop" in params:
             raise ValueError("`stop` found in both the input and default params.")
         params["stop"] = stop
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []
         for res in response["choices"]:
-            message = convert_dict_to_message(res["message"])
+            message = _convert_dict_to_message(res["message"])
             gen = ChatGeneration(
                 message=message,
                 generation_info=dict(finish_reason=res.get("finish_reason")),

View File

@@ -18,7 +18,7 @@ from typing import (
     Union,
 )
-from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
+from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -370,13 +370,13 @@ class ChatOpenAI(BaseChatModel):
         if "stop" in params:
             raise ValueError("`stop` found in both the input and default params.")
         params["stop"] = stop
-        message_dicts = [convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []
         for res in response["choices"]:
-            message = convert_dict_to_message(res["message"])
+            message = _convert_dict_to_message(res["message"])
             gen = ChatGeneration(
                 message=message,
                 generation_info=dict(finish_reason=res.get("finish_reason")),
@@ -528,7 +528,7 @@ class ChatOpenAI(BaseChatModel):
             "information on how messages are converted to tokens."
         )
         num_tokens = 0
-        messages_dict = [convert_message_to_dict(m) for m in messages]
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
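`_create_message_dicts` and `_create_chat_result` are the two sides of the round trip that the renamed converters perform for `ChatOpenAI`. A condensed illustration built from the fragments above (only the final `append` is inferred):

# Outbound: LangChain messages -> OpenAI-style dicts for the request payload.
message_dicts = [_convert_message_to_dict(m) for m in messages]

# Inbound: each returned choice -> a LangChain message wrapped in a generation.
generations = []
for res in response["choices"]:
    message = _convert_dict_to_message(res["message"])
    gen = ChatGeneration(
        message=message,
        generation_info=dict(finish_reason=res.get("finish_reason")),
    )
    generations.append(gen)  # inferred: collect into the ChatResult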

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
 import pytest
-from langchain.adapters.openai import convert_dict_to_message
+from langchain.adapters.openai import _convert_dict_to_message
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema.messages import (
     AIMessage,
@@ -26,7 +26,7 @@ def test_openai_model_param() -> None:
 def test_function_message_dict_to_function_message() -> None:
     content = json.dumps({"result": "Example #1"})
     name = "test_function"
-    result = convert_dict_to_message(
+    result = _convert_dict_to_message(
         {
             "role": "function",
             "name": name,
@@ -40,21 +40,21 @@ def test_function_message_dict_to_function_message() -> None:
 def test__convert_dict_to_message_human() -> None:
     message = {"role": "user", "content": "foo"}
-    result = convert_dict_to_message(message)
+    result = _convert_dict_to_message(message)
     expected_output = HumanMessage(content="foo")
     assert result == expected_output
 def test__convert_dict_to_message_ai() -> None:
     message = {"role": "assistant", "content": "foo"}
-    result = convert_dict_to_message(message)
+    result = _convert_dict_to_message(message)
     expected_output = AIMessage(content="foo")
     assert result == expected_output
 def test__convert_dict_to_message_system() -> None:
     message = {"role": "system", "content": "foo"}
-    result = convert_dict_to_message(message)
+    result = _convert_dict_to_message(message)
     expected_output = SystemMessage(content="foo")
     assert result == expected_output