Arcee.ai LLM & Retriever integration (#11579)

- **Description:** This PR introduces a new LLM and Retriever API to
https://arcee.ai for the python client
  - **Issue:** implements the integrations as requested in #11578 ,
  - **Dependencies:** no dependencies are required,
  - **Tag maintainer:** @hwchase17
  - **Twitter handle:** shwooobham 


** `make format`, `make lint` and `make test` runs locally.**
```shell
=========== 1245 passed, 277 skipped, 20 warnings in 16.26s ===========
./scripts/check_pydantic.sh .
./scripts/check_imports.sh
poetry run ruff .
[ "." = "" ] || poetry run black . --check
All done!  🍰 
1818 files would be left unchanged.
[ "." = "" ] || poetry run mypy .
Success: no issues found in 1815 source files
[ "." = "" ] || poetry run black .
All done!  🍰 
1818 files left unchanged.
[ "." = "" ] || poetry run ruff --select I --fix .
poetry run codespell --toml pyproject.toml
poetry run codespell --toml pyproject.toml -w
```


**Contributions**
1. Arcee (langchain/llms), ArceeRetriever (langchain/retrievers),
ArceeWrapper (langchain/utilities)
2. docs for Arcee (llms/arcee.py) and
ArceeRetriever(retrievers/arcee.py)
3.

cc: @jacobsolawetz @ben-epstein

---------

Co-authored-by: Shubham <shubham@sORo.local>
This commit is contained in:
Shubham Kushwaha 2023-10-10 22:50:45 +05:30 committed by GitHub
parent b6a2507794
commit 49de862076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 773 additions and 0 deletions

View File

@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Arcee\n",
"This notebook demonstrates how to use the `Arcee` class for generating text using Arcee's Domain Adapted Language Models (DALMs)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Arcee\n",
"\n",
"# Create an instance of the Arcee class\n",
"arcee = Arcee(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional Configuration\n",
"\n",
"You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
"Setting the `model_kwargs` at the object initialization uses the parameters as default for all the subsequent calls to the generate response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"arcee = Arcee(\n",
" model=\"DALM-Patent\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
" model_kwargs={\n",
" \"size\": 5,\n",
" \"filters\": [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" }\n",
" ]\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating Text\n",
"\n",
"You can generate text from Arcee by providing a prompt. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate text\n",
"prompt = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
"response = arcee(prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional parameters\n",
"\n",
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s) to aid text generation. Filters help narrow down the results. Here's how to use these parameters:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define filters\n",
"filters = [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" },\n",
" {\n",
" \"field_name\": \"year\",\n",
" \"filter_type\": \"strict_search\",\n",
" \"value\": \"1905\"\n",
" }\n",
"]\n",
"\n",
"# Generate text with filters and size params\n",
"response = arcee(prompt, size=5, filters=filters)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Arcee Retriever\n",
"This notebook demonstrates how to use the `ArceeRetriever` class to retrieve relevant document(s) for Arcee's Domain Adapted Language Models (DALMs)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Before using `ArceeRetriever`, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import ArceeRetriever\n",
"\n",
"retriever = ArceeRetriever(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional Configuration\n",
"\n",
"You can also configure `ArceeRetriever`'s parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
"Setting the `model_kwargs` at the object initialization uses the filters and size as default for all the subsequent retrievals."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = ArceeRetriever(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
" model_kwargs={\n",
" \"size\": 5,\n",
" \"filters\": [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" }\n",
" ]\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Retrieving documents\n",
"You can retrieve relevant documents from uploaded contexts by providing a query. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
"documents = retriever.get_relevant_documents(query=query)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional parameters\n",
"\n",
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s). Filters help narrow down the results. Here's how to use these parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define filters\n",
"filters = [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Music\"\n",
" },\n",
" {\n",
" \"field_name\": \"year\",\n",
" \"filter_type\": \"strict_search\",\n",
" \"value\": \"1905\"\n",
" }\n",
"]\n",
"\n",
"# Retrieve documents with filters and size params\n",
"documents = retriever.get_relevant_documents(query=query, size=5, filters=filters)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -52,6 +52,12 @@ def _import_anyscale() -> Any:
return Anyscale return Anyscale
def _import_arcee() -> Any:
from langchain.llms.arcee import Arcee
return Arcee
def _import_aviary() -> Any: def _import_aviary() -> Any:
from langchain.llms.aviary import Aviary from langchain.llms.aviary import Aviary
@ -479,6 +485,8 @@ def __getattr__(name: str) -> Any:
return _import_anthropic() return _import_anthropic()
elif name == "Anyscale": elif name == "Anyscale":
return _import_anyscale() return _import_anyscale()
elif name == "Arcee":
return _import_arcee()
elif name == "Aviary": elif name == "Aviary":
return _import_aviary() return _import_aviary()
elif name == "AzureMLOnlineEndpoint": elif name == "AzureMLOnlineEndpoint":
@ -633,6 +641,7 @@ __all__ = [
"AmazonAPIGateway", "AmazonAPIGateway",
"Anthropic", "Anthropic",
"Anyscale", "Anyscale",
"Arcee",
"Aviary", "Aviary",
"AzureMLOnlineEndpoint", "AzureMLOnlineEndpoint",
"AzureOpenAI", "AzureOpenAI",
@ -713,6 +722,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"amazon_bedrock": _import_bedrock, "amazon_bedrock": _import_bedrock,
"anthropic": _import_anthropic, "anthropic": _import_anthropic,
"anyscale": _import_anyscale, "anyscale": _import_anyscale,
"arcee": _import_arcee,
"aviary": _import_aviary, "aviary": _import_aviary,
"azure": _import_azure_openai, "azure": _import_azure_openai,
"azureml_endpoint": _import_azureml_endpoint, "azureml_endpoint": _import_azureml_endpoint,

View File

@ -0,0 +1,147 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Extra, root_validator
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
class Arcee(LLM):
"""Arcee's Domain Adapted Language Models (DALMs).
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter.
Example:
.. code-block:: python
from langchain.llms import Arcee
arcee = Arcee(
model="DALM-PubMed",
arcee_api_key="ARCEE-API-KEY"
)
response = arcee("AI-driven music therapy")
"""
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client."""
arcee_api_key: str = ""
"""Arcee API Key"""
model: str
"""Arcee DALM name"""
arcee_api_url: str = "https://api.arcee.ai"
"""Arcee API URL"""
arcee_api_version: str = "v2"
"""Arcee API Version"""
arcee_app_url: str = "https://app.arcee.ai"
"""Arcee App URL"""
model_id: str = ""
"""Arcee Model ID"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Keyword arguments to pass to the model."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
underscore_attrs_are_private = True
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "arcee"
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
super().__init__(**data)
self._client = None
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
values["arcee_api_url"] = get_from_dict_or_env(
values,
"arcee_api_url",
"ARCEE_API_URL",
)
values["arcee_app_url"] = get_from_dict_or_env(
values,
"arcee_app_url",
"ARCEE_APP_URL",
)
values["arcee_api_version"] = get_from_dict_or_env(
values,
"arcee_api_version",
"ARCEE_API_VERSION",
)
# validate model kwargs
if values["model_kwargs"]:
kw = values["model_kwargs"]
# validate size
if kw.get("size") is not None:
if not kw.get("size") >= 0:
raise ValueError("`size` must be positive")
# validate filters
if kw.get("filters") is not None:
if not isinstance(kw.get("filters"), List):
raise ValueError("`filters` must be a list")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Generate text from Arcee DALM.
Args:
prompt: Prompt to generate text from.
size: The max number of context results to retrieve.
Defaults to 3. (Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
try:
if not self._client:
raise ValueError("Client is not initialized.")
return self._client.generate(prompt=prompt, **kwargs)
except Exception as e:
raise Exception(f"Failed to generate text: {e}") from e

View File

@ -18,6 +18,7 @@ the backbone of a retriever, but there are other types of retrievers as well.
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
""" """
from langchain.retrievers.arcee import ArceeRetriever
from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.arxiv import ArxivRetriever
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.bm25 import BM25Retriever
@ -66,6 +67,7 @@ from langchain.retrievers.zilliz import ZillizRetriever
__all__ = [ __all__ = [
"AmazonKendraRetriever", "AmazonKendraRetriever",
"ArceeRetriever",
"ArxivRetriever", "ArxivRetriever",
"AzureCognitiveSearchRetriever", "AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever", "ChatGPTPluginRetriever",

View File

@ -0,0 +1,136 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BaseRetriever
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
class ArceeRetriever(BaseRetriever):
"""Document retriever for Arcee's Domain Adapted Language Models (DALMs).
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter.
Example:
.. code-block:: python
from langchain.retrievers import ArceeRetriever
retriever = ArceeRetriever(
model="DALM-PubMed",
arcee_api_key="ARCEE-API-KEY"
)
documents = retriever.get_relevant_documents("AI-driven music therapy")
"""
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee client."""
arcee_api_key: str = ""
"""Arcee API Key"""
model: str
"""Arcee DALM name"""
arcee_api_url: str = "https://api.arcee.ai"
"""Arcee API URL"""
arcee_api_version: str = "v2"
"""Arcee API Version"""
arcee_app_url: str = "https://app.arcee.ai"
"""Arcee App URL"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Keyword arguments to pass to the model."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
underscore_attrs_are_private = True
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
super().__init__(**data)
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
values["arcee_api_url"] = get_from_dict_or_env(
values,
"arcee_api_url",
"ARCEE_API_URL",
)
values["arcee_app_url"] = get_from_dict_or_env(
values,
"arcee_app_url",
"ARCEE_APP_URL",
)
values["arcee_api_version"] = get_from_dict_or_env(
values,
"arcee_api_version",
"ARCEE_API_VERSION",
)
# validate model kwargs
if values["model_kwargs"]:
kw = values["model_kwargs"]
# validate size
if kw.get("size") is not None:
if not kw.get("size") >= 0:
raise ValueError("`size` must not be negative.")
# validate filters
if kw.get("filters") is not None:
if not isinstance(kw.get("filters"), List):
raise ValueError("`filters` must be a list.")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _get_relevant_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Retrieve {size} contexts with your retriever for a given query
Args:
query: Query to submit to the model
size: The max number of context results to retrieve.
Defaults to 3. (Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
try:
if not self._client:
raise ValueError("Client is not initialized.")
return self._client.retrieve(query=query, **kwargs)
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e

View File

@ -5,6 +5,7 @@ and packages.
""" """
from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper
from langchain.utilities.apify import ApifyWrapper from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.arcee import ArceeWrapper
from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.awslambda import LambdaWrapper from langchain.utilities.awslambda import LambdaWrapper
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
@ -41,6 +42,7 @@ from langchain.utilities.zapier import ZapierNLAWrapper
__all__ = [ __all__ = [
"AlphaVantageAPIWrapper", "AlphaVantageAPIWrapper",
"ApifyWrapper", "ApifyWrapper",
"ArceeWrapper",
"ArxivAPIWrapper", "ArxivAPIWrapper",
"BashProcess", "BashProcess",
"BibtexparserWrapper", "BibtexparserWrapper",

View File

@ -0,0 +1,189 @@
# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.retriever import Document
class ArceeRoute(str, Enum):
generate = "models/generate"
retrieve = "models/retrieve"
model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
fuzzy_search = "fuzzy_search"
strict_search = "strict_search"
class DALMFilter(BaseModel):
"""Filters available for a dalm retrieval and generation
Arguments:
field_name: The field to filter on. Can be 'document' or 'name' to filter
on your document's raw text or title. Any other field will be presumed
to be a metadata field you included when uploading your context data
filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
'fuzzy_search' means a fuzzy search on the provided field is performed.
The exact strict doesn't need to exist in the document
for this to find a match.
Very useful for scanning a document for some keyword terms.
'strict_search' means that the exact string must appear
in the provided field.
This is NOT an exact eq filter. ie a document with content
"the happy dog crossed the street" will match on a strict_search of
"dog" but won't match on "the dog".
Python equivalent of `return search_string in full_string`.
value: The actual value to search for in the context data/metadata
"""
field_name: str
filter_type: DALMFilterType
value: str
_is_metadata: bool = False
@root_validator()
def set_meta(cls, values: Dict) -> Dict:
"""document and name are reserved arcee keys. Anything else is metadata"""
values["_is_meta"] = values.get("field_name") not in ["document", "name"]
return values
class ArceeWrapper:
def __init__(
self,
arcee_api_key: str,
arcee_api_url: str,
arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]],
model_name: str,
):
self.arcee_api_key = arcee_api_key
self.model_kwargs = model_kwargs
self.arcee_api_url = arcee_api_url
self.arcee_api_version = arcee_api_version
try:
route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
response = self._make_request("get", route)
self.model_id = response.get("model_id")
self.model_training_status = response.get("status")
except Exception as e:
raise ValueError(
f"Error while validating model training status for '{model_name}': {e}"
) from e
def validate_model_training_status(self) -> None:
if self.model_training_status != "training_complete":
raise Exception(
f"Model {self.model_id} is not ready. "
"Please wait for training to complete."
)
def _make_request(
self,
method: Literal["post", "get"],
route: Union[ArceeRoute, str],
body: Optional[Mapping[str, Any]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
) -> dict:
"""Make a request to the Arcee API
Args:
method: The HTTP method to use
route: The route to call
body: The body of the request
params: The query params of the request
headers: The headers of the request
"""
headers = self._make_request_headers(headers=headers)
url = self._make_request_url(route=route)
req_type = getattr(requests, method)
response = req_type(url, json=body, params=params, headers=headers)
if response.status_code not in (200, 201):
raise Exception(f"Failed to make request. Response: {response.text}")
return response.json()
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {}
internal_headers = {
"X-Token": self.arcee_api_key,
"Content-Type": "application/json",
}
headers.update(internal_headers)
return headers
def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"
def _make_request_body_for_models(
self, prompt: str, **kwargs: Mapping[str, Any]
) -> Mapping[str, Any]:
"""Make the request body for generate/retrieve models endpoint"""
_model_kwargs = self.model_kwargs or {}
_params = {**_model_kwargs, **kwargs}
filters = [DALMFilter(**f) for f in _params.get("filters", [])]
return dict(
model_id=self.model_id,
query=prompt,
size=_params.get("size", 3),
filters=filters,
id=self.model_id,
)
def generate(
self,
prompt: str,
**kwargs: Any,
) -> str:
"""Generate text from Arcee DALM.
Args:
prompt: Prompt to generate text from.
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
response = self._make_request(
method="post",
route=ArceeRoute.generate,
body=self._make_request_body_for_models(
prompt=prompt,
**kwargs,
),
)
return response["text"]
def retrieve(
self,
query: str,
**kwargs: Any,
) -> List[Document]:
"""Retrieve {size} contexts with your retriever for a given query
Args:
query: Query to submit to the model
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
response = self._make_request(
method="post",
route=ArceeRoute.retrieve,
body=self._make_request_body_for_models(
prompt=query,
**kwargs,
),
)
return [Document(**doc) for doc in response["documents"]]