mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
Update ArgillaCallbackHandler
as of latest argilla
release (#9043)
Hi @agola11, or whoever is reviewing this PR 😄 ## What's in this PR? As of the latest Argilla release, we're changing and refactoring a few things to make some workflows easier. One of those changes is how everything is pushed to Argilla: there is no longer any need to call `push_to_argilla` again on a `FeedbackDataset` once either `push_to_argilla` has been called for the first time or `from_argilla` has been called, among other changes. We also add some class variables so that those values are easy to update if we change them internally in the future, and so that the `warnings.warn` messages are lighter to read in the code. P.S. Regarding the Twitter/X mention, feel free to use either https://twitter.com/argilla_io or https://twitter.com/alvarobartt, or both if applicable; otherwise, just the first Twitter/X handle.
This commit is contained in:
parent
8d351bfc20
commit
08a0741d82
@ -2,6 +2,8 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
"Argilla, no doubt about it."
|
"Argilla, no doubt about it."
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
REPO_URL = "https://github.com/argilla-io/argilla"
|
||||||
|
ISSUES_URL = f"{REPO_URL}/issues"
|
||||||
|
BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
|
||||||
|
|
||||||
|
DEFAULT_API_URL = "http://localhost:6900"
|
||||||
|
DEFAULT_API_KEY = "argilla.apikey"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_name: str,
|
dataset_name: str,
|
||||||
@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
api_url: Optional[str] = None,
|
api_url: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes the `ArgillaCallbackHandler`.
|
f"""Initializes the `ArgillaCallbackHandler`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
||||||
exist in advance. If you need help on how to create a `FeedbackDataset`
|
exist in advance. If you need help on how to create a `FeedbackDataset`
|
||||||
in Argilla, please visit
|
in Argilla, please visit {self.BLOG_URL}.
|
||||||
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
|
|
||||||
workspace_name: name of the workspace in Argilla where the specified
|
workspace_name: name of the workspace in Argilla where the specified
|
||||||
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
||||||
default workspace will be used.
|
default workspace will be used.
|
||||||
api_url: URL of the Argilla Server that we want to use, and where the
|
api_url: URL of the Argilla Server that we want to use, and where the
|
||||||
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
||||||
`ARGILLA_API_URL` environment variable or the default
|
`ARGILLA_API_URL` environment variable or `{self.DEFAULT_API_URL}` will
|
||||||
http://localhost:6900 will be used.
|
be used.
|
||||||
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
||||||
means that either `ARGILLA_API_KEY` environment variable or the default
|
means that either `ARGILLA_API_KEY` environment variable or the default
|
||||||
`argilla.apikey` will be used.
|
`{self.DEFAULT_API_KEY}` will be used.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError: if the `argilla` package is not installed.
|
ImportError: if the `argilla` package is not installed.
|
||||||
@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
|
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
|
||||||
try:
|
try:
|
||||||
import argilla as rg # noqa: F401
|
import argilla as rg # noqa: F401
|
||||||
|
|
||||||
|
self.ARGILLA_VERSION = rg.__version__
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"To use the Argilla callback manager you need to have the `argilla` "
|
"To use the Argilla callback manager you need to have the `argilla` "
|
||||||
"Python package installed. Please install it with `pip install argilla`"
|
"Python package installed. Please install it with `pip install argilla`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check whether the Argilla version is compatible
|
||||||
|
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
|
||||||
|
raise ImportError(
|
||||||
|
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
|
||||||
|
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
|
||||||
|
"upgrade `argilla` with `pip install --upgrade argilla`."
|
||||||
|
)
|
||||||
|
|
||||||
# Show a warning message if Argilla will assume the default values will be used
|
# Show a warning message if Argilla will assume the default values will be used
|
||||||
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
|
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
(
|
(
|
||||||
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
|
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
|
||||||
" set, it will default to `http://localhost:6900`."
|
f" set, it will default to `{self.DEFAULT_API_URL}`."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
|
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
(
|
(
|
||||||
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
|
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
|
||||||
" set, it will default to `argilla.apikey`."
|
f" set, it will default to `{self.DEFAULT_API_KEY}`."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Connect to Argilla with the provided credentials, if applicable
|
# Connect to Argilla with the provided credentials, if applicable
|
||||||
try:
|
try:
|
||||||
rg.init(
|
rg.init(api_key=api_key, api_url=api_url)
|
||||||
api_key=api_key,
|
|
||||||
api_url=api_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Argilla with exception: '{e}'.\n"
|
f"Could not connect to Argilla with exception: '{e}'.\n"
|
||||||
"Please check your `api_key` and `api_url`, and make sure that "
|
"Please check your `api_key` and `api_url`, and make sure that "
|
||||||
"the Argilla server is up and running. If the problem persists "
|
"the Argilla server is up and running. If the problem persists "
|
||||||
"please report it to https://github.com/argilla-io/argilla/issues "
|
f"please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||||
"with the label `langchain`."
|
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
# Set the Argilla variables
|
# Set the Argilla variables
|
||||||
@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
|
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
|
||||||
try:
|
try:
|
||||||
|
extra_args = {}
|
||||||
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||||
|
warnings.warn(
|
||||||
|
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
|
||||||
|
" higher is recommended.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
extra_args = {"with_records": False}
|
||||||
self.dataset = rg.FeedbackDataset.from_argilla(
|
self.dataset = rg.FeedbackDataset.from_argilla(
|
||||||
name=self.dataset_name,
|
name=self.dataset_name,
|
||||||
workspace=self.workspace_name,
|
workspace=self.workspace_name,
|
||||||
with_records=False,
|
**extra_args,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
"`FeedbackDataset` retrieval from Argilla failed with exception:"
|
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
|
||||||
f" '{e}'.\nPlease check that the dataset with"
|
f"\nPlease check that the dataset with name={self.dataset_name} in the"
|
||||||
f" name={self.dataset_name} in the"
|
|
||||||
f" workspace={self.workspace_name} exists in advance. If you need help"
|
f" workspace={self.workspace_name} exists in advance. If you need help"
|
||||||
" on how to create a `langchain`-compatible `FeedbackDataset` in"
|
" on how to create a `langchain`-compatible `FeedbackDataset` in"
|
||||||
" Argilla, please visit"
|
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
|
||||||
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
|
f" please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||||
" If the problem persists please report it to"
|
|
||||||
" https://github.com/argilla-io/argilla/issues with the label"
|
|
||||||
" `langchain`."
|
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
supported_fields = ["prompt", "response"]
|
supported_fields = ["prompt", "response"]
|
||||||
if supported_fields != [field.name for field in self.dataset.fields]:
|
if supported_fields != [field.name for field in self.dataset.fields]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`FeedbackDataset` with name={self.dataset_name} in the"
|
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
|
||||||
f" workspace={self.workspace_name} "
|
f"{self.workspace_name} had fields that are not supported yet for the"
|
||||||
"had fields that are not supported yet for the `langchain` integration."
|
f"`langchain` integration. Supported fields are: {supported_fields},"
|
||||||
" Supported fields are: "
|
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
|
||||||
f"{supported_fields}, and the current `FeedbackDataset` fields are"
|
" For more information on how to create a `langchain`-compatible"
|
||||||
f" {[field.name for field in self.dataset.fields]}. "
|
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
|
||||||
"For more information on how to create a `langchain`-compatible"
|
|
||||||
" `FeedbackDataset` in Argilla, please visit"
|
|
||||||
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prompts: Dict[str, List[str]] = {}
|
self.prompts: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
(
|
(
|
||||||
"The `ArgillaCallbackHandler` is currently in beta and is subject to "
|
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
|
||||||
"change based on updates to `langchain`. Please report any issues to "
|
" change based on updates to `langchain`. Please report any issues to"
|
||||||
"https://github.com/argilla-io/argilla/issues with the tag `langchain`."
|
f" {self.ISSUES_URL} as an `integration` issue."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Push the records to Argilla
|
|
||||||
self.dataset.push_to_argilla()
|
|
||||||
|
|
||||||
# Pop current run from `self.runs`
|
# Pop current run from `self.runs`
|
||||||
self.prompts.pop(str(kwargs["run_id"]))
|
self.prompts.pop(str(kwargs["run_id"]))
|
||||||
|
|
||||||
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||||
|
# Push the records to Argilla
|
||||||
|
self.dataset.push_to_argilla()
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Push the records to Argilla
|
|
||||||
self.dataset.push_to_argilla()
|
|
||||||
|
|
||||||
# Pop current run from `self.runs`
|
# Pop current run from `self.runs`
|
||||||
if str(kwargs["parent_run_id"]) in self.prompts:
|
if str(kwargs["parent_run_id"]) in self.prompts:
|
||||||
self.prompts.pop(str(kwargs["parent_run_id"]))
|
self.prompts.pop(str(kwargs["parent_run_id"]))
|
||||||
if str(kwargs["run_id"]) in self.prompts:
|
if str(kwargs["run_id"]) in self.prompts:
|
||||||
self.prompts.pop(str(kwargs["run_id"]))
|
self.prompts.pop(str(kwargs["run_id"]))
|
||||||
|
|
||||||
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||||
|
# Push the records to Argilla
|
||||||
|
self.dataset.push_to_argilla()
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user