mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
update
This commit is contained in:
@@ -28,7 +28,7 @@ from langchain.schema.messages import (
|
||||
)
|
||||
|
||||
|
||||
async def aenumerate(
|
||||
async def _aenumerate(
|
||||
iterable: AsyncIterator[Any], start: int = 0
|
||||
) -> AsyncIterator[tuple[int, Any]]:
|
||||
"""Async version of enumerate."""
|
||||
@@ -38,7 +38,7 @@ async def aenumerate(
|
||||
i += 1
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
@@ -59,7 +59,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
@@ -87,7 +87,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
return message_dict
|
||||
|
||||
|
||||
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
|
||||
def _convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
|
||||
"""Convert dictionaries representing OpenAI messages to LangChain format.
|
||||
|
||||
Args:
|
||||
@@ -96,7 +96,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess
|
||||
Returns:
|
||||
List of LangChain BaseMessage objects.
|
||||
"""
|
||||
return [convert_dict_to_message(m) for m in messages]
|
||||
return [_convert_dict_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
|
||||
@@ -155,10 +155,10 @@ class ChatCompletion:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
converted_messages = _convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = model_config.invoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
return {"choices": [{"message": _convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
@@ -198,14 +198,14 @@ class ChatCompletion:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
converted_messages = _convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = await model_config.ainvoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
return {"choices": [{"message": _convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
async for i, c in aenumerate(model_config.astream(converted_messages))
|
||||
async for i, c in _aenumerate(model_config.astream(converted_messages))
|
||||
)
|
||||
|
||||
|
||||
@@ -214,12 +214,12 @@ def _has_assistant_message(session: ChatSession) -> bool:
|
||||
return any([isinstance(m, AIMessage) for m in session["messages"]])
|
||||
|
||||
|
||||
def convert_messages_for_finetuning(
|
||||
def _convert_messages_for_finetuning(
|
||||
sessions: Iterable[ChatSession],
|
||||
) -> List[List[dict]]:
|
||||
"""Convert messages to a list of lists of dictionaries for fine-tuning."""
|
||||
return [
|
||||
[convert_message_to_dict(s) for s in session["messages"]]
|
||||
[_convert_message_to_dict(s) for s in session["messages"]]
|
||||
for session in sessions
|
||||
if _has_assistant_message(session)
|
||||
]
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_aim() -> Any:
|
||||
def _import_aim() -> Any:
|
||||
"""Import the aim python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import aim
|
||||
@@ -169,7 +169,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
super().__init__()
|
||||
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.repo = repo
|
||||
self.experiment_name = experiment_name
|
||||
self.system_tracking_interval = system_tracking_interval
|
||||
@@ -184,7 +184,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.action_records: list = []
|
||||
|
||||
def setup(self, **kwargs: Any) -> None:
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
|
||||
if not self._run:
|
||||
if self._run_hash:
|
||||
@@ -210,7 +210,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
@@ -229,7 +229,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
@@ -264,7 +264,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
@@ -280,7 +280,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
@@ -303,7 +303,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
@@ -315,7 +315,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
@@ -339,7 +339,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
@@ -356,7 +356,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
aim = import_aim()
|
||||
aim = _import_aim()
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import import_pandas
|
||||
from langchain.callbacks.utils import _import_pandas
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
from arize.utils.types import (
|
||||
EmbeddingColumnNames,
|
||||
Environments,
|
||||
|
||||
@@ -8,11 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
_import_pandas,
|
||||
_import_spacy,
|
||||
_import_textstat,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
load_json,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def import_clearml() -> Any:
|
||||
def _import_clearml() -> Any:
|
||||
"""Import the clearml python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import clearml # noqa: F401
|
||||
@@ -63,8 +63,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
clearml = import_clearml()
|
||||
spacy = import_spacy()
|
||||
clearml = _import_clearml()
|
||||
spacy = _import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.task_type = task_type
|
||||
@@ -329,8 +329,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
(dict): A dictionary containing the complexity metrics.
|
||||
"""
|
||||
resp = {}
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
textstat = _import_textstat()
|
||||
spacy = _import_spacy()
|
||||
if self.complexity_metrics:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
@@ -399,7 +399,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||
|
||||
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
|
||||
@@ -465,8 +465,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
pd = import_pandas()
|
||||
clearml = import_clearml()
|
||||
pd = _import_pandas()
|
||||
clearml = _import_clearml()
|
||||
|
||||
# Log the action records
|
||||
self.logger.report_table(
|
||||
|
||||
@@ -7,17 +7,17 @@ import langchain
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
_import_pandas,
|
||||
_import_spacy,
|
||||
_import_textstat,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
|
||||
|
||||
LANGCHAIN_MODEL_NAME = "langchain-model"
|
||||
|
||||
|
||||
def import_comet_ml() -> Any:
|
||||
def _import_comet_ml() -> Any:
|
||||
"""Import comet_ml and raise an error if it is not installed."""
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
@@ -33,7 +33,7 @@ def import_comet_ml() -> Any:
|
||||
def _get_experiment(
|
||||
workspace: Optional[str] = None, project_name: Optional[str] = None
|
||||
) -> Any:
|
||||
comet_ml = import_comet_ml()
|
||||
comet_ml = _import_comet_ml()
|
||||
|
||||
experiment = comet_ml.Experiment( # type: ignore
|
||||
workspace=workspace,
|
||||
@@ -44,7 +44,7 @@ def _get_experiment(
|
||||
|
||||
|
||||
def _fetch_text_complexity_metrics(text: str) -> dict:
|
||||
textstat = import_textstat()
|
||||
textstat = _import_textstat()
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
@@ -67,7 +67,7 @@ def _fetch_text_complexity_metrics(text: str) -> dict:
|
||||
|
||||
|
||||
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
metrics_df = pd.DataFrame(metrics)
|
||||
metrics_summary = metrics_df.describe()
|
||||
|
||||
@@ -107,7 +107,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
self.comet_ml = import_comet_ml()
|
||||
self.comet_ml = _import_comet_ml()
|
||||
super().__init__()
|
||||
|
||||
self.task_type = task_type
|
||||
@@ -140,7 +140,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.action_records: list = []
|
||||
self.complexity_metrics = complexity_metrics
|
||||
if self.visualizations:
|
||||
spacy = import_spacy()
|
||||
spacy = _import_spacy()
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
else:
|
||||
self.nlp = None
|
||||
@@ -535,7 +535,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if not (self.visualizations and self.nlp):
|
||||
return
|
||||
|
||||
spacy = import_spacy()
|
||||
spacy = _import_spacy()
|
||||
|
||||
prompts = session_df["prompts"].tolist()
|
||||
outputs = session_df["text"].tolist()
|
||||
@@ -603,7 +603,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
|
||||
llm_parameters = self._get_llm_parameters(langchain_asset)
|
||||
num_generations_per_prompt = llm_parameters.get("n", 1)
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.schema import (
|
||||
)
|
||||
|
||||
|
||||
def import_context() -> Any:
|
||||
def _import_context() -> Any:
|
||||
"""Import the `getcontext` package."""
|
||||
try:
|
||||
import getcontext # noqa: F401
|
||||
@@ -98,7 +98,7 @@ class ContextCallbackHandler(BaseCallbackHandler):
|
||||
self.message_model,
|
||||
self.message_role_model,
|
||||
self.rating_model,
|
||||
) = import_context()
|
||||
) = _import_context()
|
||||
|
||||
token = token or os.environ.get("CONTEXT_TOKEN") or ""
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
_import_pandas,
|
||||
_import_spacy,
|
||||
_import_textstat,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_flytekit() -> Tuple[flytekit, renderer]:
|
||||
def _import_flytekit() -> Tuple[flytekit, renderer]:
|
||||
"""Import flytekit and flytekitplugins-deck-standard."""
|
||||
try:
|
||||
import flytekit # noqa: F401
|
||||
@@ -75,7 +75,7 @@ def analyze_text(
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if nlp is not None:
|
||||
spacy = import_spacy()
|
||||
spacy = _import_spacy()
|
||||
doc = nlp(text)
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
@@ -97,12 +97,12 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize callback handler."""
|
||||
flytekit, renderer = import_flytekit()
|
||||
self.pandas = import_pandas()
|
||||
flytekit, renderer = _import_flytekit()
|
||||
self.pandas = _import_pandas()
|
||||
|
||||
self.textstat = None
|
||||
try:
|
||||
self.textstat = import_textstat()
|
||||
self.textstat = _import_textstat()
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Textstat library is not installed. \
|
||||
@@ -112,7 +112,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
spacy = None
|
||||
try:
|
||||
spacy = import_spacy()
|
||||
spacy = _import_spacy()
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Spacy library is not installed. \
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
def import_infino() -> Any:
|
||||
def _import_infino() -> Any:
|
||||
"""Import the infino client."""
|
||||
try:
|
||||
from infinopy import InfinoClient
|
||||
@@ -19,7 +19,7 @@ def import_infino() -> Any:
|
||||
return InfinoClient()
|
||||
|
||||
|
||||
def import_tiktoken() -> Any:
|
||||
def _import_tiktoken() -> Any:
|
||||
"""Import tiktoken for counting tokens for OpenAI models."""
|
||||
try:
|
||||
import tiktoken
|
||||
@@ -38,7 +38,7 @@ def get_num_tokens(string: str, openai_model_name: str) -> int:
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/main
|
||||
/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
"""
|
||||
tiktoken = import_tiktoken()
|
||||
tiktoken = _import_tiktoken()
|
||||
|
||||
encoding = tiktoken.encoding_for_model(openai_model_name)
|
||||
num_tokens = len(encoding.encode(string))
|
||||
@@ -55,7 +55,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
# Set Infino client
|
||||
self.client = import_infino()
|
||||
self.client = _import_infino()
|
||||
self.model_id = model_id
|
||||
self.model_version = model_version
|
||||
self.verbose = verbose
|
||||
|
||||
@@ -23,7 +23,7 @@ class LabelStudioMode(Enum):
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
def get_default_label_configs(
|
||||
def _get_default_label_configs(
|
||||
mode: Union[str, LabelStudioMode]
|
||||
) -> Tuple[str, LabelStudioMode]:
|
||||
"""Get default Label Studio configs for the given mode.
|
||||
@@ -173,7 +173,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
self.project_config = project_config
|
||||
self.mode = None
|
||||
else:
|
||||
self.project_config, self.mode = get_default_label_configs(mode)
|
||||
self.project_config, self.mode = _get_default_label_configs(mode)
|
||||
|
||||
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
|
||||
if self.project_id is not None:
|
||||
|
||||
@@ -10,17 +10,17 @@ from typing import Any, Dict, List, Optional, Union
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
_import_pandas,
|
||||
_import_spacy,
|
||||
_import_textstat,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
def import_mlflow() -> Any:
|
||||
def _import_mlflow() -> Any:
|
||||
"""Import the mlflow python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import mlflow
|
||||
@@ -47,8 +47,8 @@ def analyze_text(
|
||||
files serialized to HTML string.
|
||||
"""
|
||||
resp: Dict[str, Any] = {}
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
textstat = _import_textstat()
|
||||
spacy = _import_spacy()
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
@@ -127,7 +127,7 @@ class MlflowLogger:
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
self.mlflow = import_mlflow()
|
||||
self.mlflow = _import_mlflow()
|
||||
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
|
||||
self.mlflow.set_tracking_uri("databricks")
|
||||
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
|
||||
@@ -246,10 +246,10 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
import_mlflow()
|
||||
spacy = import_spacy()
|
||||
_import_pandas()
|
||||
_import_textstat()
|
||||
_import_mlflow()
|
||||
spacy = _import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
@@ -547,7 +547,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
|
||||
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
|
||||
|
||||
@@ -617,7 +617,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
return session_analysis_df
|
||||
|
||||
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
|
||||
session_analysis_df = self._create_session_analysis_df()
|
||||
chat_html = session_analysis_df.pop("chat_html")
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Tuple, Union
|
||||
|
||||
|
||||
def import_spacy() -> Any:
|
||||
def _import_spacy() -> Any:
|
||||
"""Import the spacy python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import spacy
|
||||
@@ -15,7 +15,7 @@ def import_spacy() -> Any:
|
||||
return spacy
|
||||
|
||||
|
||||
def import_pandas() -> Any:
|
||||
def _import_pandas() -> Any:
|
||||
"""Import the pandas python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import pandas
|
||||
@@ -27,7 +27,7 @@ def import_pandas() -> Any:
|
||||
return pandas
|
||||
|
||||
|
||||
def import_textstat() -> Any:
|
||||
def _import_textstat() -> Any:
|
||||
"""Import the textstat python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import textstat
|
||||
|
||||
@@ -7,16 +7,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
_import_pandas,
|
||||
_import_spacy,
|
||||
_import_textstat,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
def _import_wandb() -> Any:
|
||||
"""Import the wandb python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import wandb # noqa: F401
|
||||
@@ -28,7 +28,7 @@ def import_wandb() -> Any:
|
||||
return wandb
|
||||
|
||||
|
||||
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
|
||||
def _load_json_to_dict(json_path: Union[str, Path]) -> dict:
|
||||
"""Load json file to a dictionary.
|
||||
|
||||
Parameters:
|
||||
@@ -42,7 +42,7 @@ def load_json_to_dict(json_path: Union[str, Path]) -> dict:
|
||||
return data
|
||||
|
||||
|
||||
def analyze_text(
|
||||
def _analyze_text(
|
||||
text: str,
|
||||
complexity_metrics: bool = True,
|
||||
visualize: bool = True,
|
||||
@@ -63,9 +63,9 @@ def analyze_text(
|
||||
files serialized in a wandb.Html element.
|
||||
"""
|
||||
resp = {}
|
||||
textstat = import_textstat()
|
||||
wandb = import_wandb()
|
||||
spacy = import_spacy()
|
||||
textstat = _import_textstat()
|
||||
wandb = _import_wandb()
|
||||
spacy = _import_spacy()
|
||||
if complexity_metrics:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
@@ -120,7 +120,7 @@ def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> A
|
||||
|
||||
Returns:
|
||||
(wandb.Html): The html element."""
|
||||
wandb = import_wandb()
|
||||
wandb = _import_wandb()
|
||||
formatted_prompt = prompt.replace("\n", "<br>")
|
||||
formatted_generation = generation.replace("\n", "<br>")
|
||||
|
||||
@@ -173,10 +173,10 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
wandb = import_wandb()
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
spacy = import_spacy()
|
||||
wandb = _import_wandb()
|
||||
_import_pandas()
|
||||
_import_textstat()
|
||||
spacy = _import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.job_type = job_type
|
||||
@@ -269,7 +269,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
generation_resp.update(
|
||||
analyze_text(
|
||||
_analyze_text(
|
||||
generation.text,
|
||||
complexity_metrics=self.complexity_metrics,
|
||||
visualize=self.visualize,
|
||||
@@ -438,7 +438,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
pd = _import_pandas()
|
||||
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
|
||||
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||
|
||||
@@ -533,8 +533,8 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
pd = import_pandas()
|
||||
wandb = import_wandb()
|
||||
pd = _import_pandas()
|
||||
wandb = _import_wandb()
|
||||
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
|
||||
session_analysis_table = wandb.Table(
|
||||
dataframe=self._create_session_analysis_df()
|
||||
@@ -554,11 +554,11 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
try:
|
||||
langchain_asset.save(langchain_asset_path)
|
||||
model_artifact.add_file(str(langchain_asset_path))
|
||||
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
|
||||
model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
|
||||
except ValueError:
|
||||
langchain_asset.save_agent(langchain_asset_path)
|
||||
model_artifact.add_file(str(langchain_asset_path))
|
||||
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
|
||||
model_artifact.metadata = _load_json_to_dict(langchain_asset_path)
|
||||
except NotImplementedError as e:
|
||||
print("Could not save model.")
|
||||
print(repr(e))
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
diagnostic_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_langkit(
|
||||
def _import_langkit(
|
||||
sentiment: bool = False,
|
||||
toxicity: bool = False,
|
||||
themes: bool = False,
|
||||
@@ -159,7 +159,7 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
|
||||
WhyLabs writer.
|
||||
"""
|
||||
# langkit library will import necessary whylogs libraries
|
||||
import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
|
||||
_import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
|
||||
|
||||
import whylogs as why
|
||||
from langkit.callback_handler import get_callback_instance
|
||||
|
||||
@@ -149,7 +149,7 @@ class LangSmithDatasetChatLoader(BaseChatLoader):
|
||||
for data_point in data:
|
||||
yield ChatSession(
|
||||
messages=[
|
||||
oai_adapter.convert_dict_to_message(m)
|
||||
oai_adapter._convert_dict_to_message(m)
|
||||
for m in data_point.get("messages", [])
|
||||
],
|
||||
functions=data_point.get("functions"),
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.adapters.openai import convert_message_to_dict
|
||||
from langchain.adapters.openai import _convert_message_to_dict
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_import_tiktoken,
|
||||
@@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
|
||||
from langchain.adapters.openai import convert_message_to_dict
|
||||
from langchain.adapters.openai import _convert_message_to_dict
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_import_tiktoken,
|
||||
@@ -140,7 +140,7 @@ class ChatEverlyAI(ChatOpenAI):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.adapters.openai import convert_message_to_dict
|
||||
from langchain.adapters.openai import _convert_message_to_dict
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -168,7 +168,7 @@ class ChatFireworks(BaseChatModel):
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
def _stream(
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import (
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
|
||||
from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
@@ -246,13 +246,13 @@ class ChatKonko(ChatOpenAI):
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = convert_dict_to_message(res["message"])
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
|
||||
from langchain.adapters.openai import _convert_dict_to_message, _convert_message_to_dict
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -370,13 +370,13 @@ class ChatOpenAI(BaseChatModel):
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = convert_dict_to_message(res["message"])
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
@@ -528,7 +528,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.adapters.openai import convert_dict_to_message
|
||||
from langchain.adapters.openai import _convert_dict_to_message
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
@@ -26,7 +26,7 @@ def test_openai_model_param() -> None:
|
||||
def test_function_message_dict_to_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = convert_dict_to_message(
|
||||
result = _convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
@@ -40,21 +40,21 @@ def test_function_message_dict_to_function_message() -> None:
|
||||
|
||||
def test__convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = convert_dict_to_message(message)
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = convert_dict_to_message(message)
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = convert_dict_to_message(message)
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user