mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
This commit is contained in:
255
libs/community/langchain_community/utilities/arcee.py
Normal file
255
libs/community/langchain_community/utilities/arcee.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# This module contains utility classes and functions for interacting with Arcee API.
|
||||
# For more information and updates, refer to the Arcee utils page:
|
||||
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.retrievers import Document
|
||||
|
||||
|
||||
class ArceeRoute(str, Enum):
|
||||
"""Routes available for the Arcee API as enumerator."""
|
||||
|
||||
generate = "models/generate"
|
||||
retrieve = "models/retrieve"
|
||||
model_training_status = "models/status/{id_or_name}"
|
||||
|
||||
|
||||
class DALMFilterType(str, Enum):
|
||||
"""Filter types available for a DALM retrieval as enumerator."""
|
||||
|
||||
fuzzy_search = "fuzzy_search"
|
||||
strict_search = "strict_search"
|
||||
|
||||
|
||||
class DALMFilter(BaseModel):
|
||||
"""Filters available for a DALM retrieval and generation.
|
||||
|
||||
Arguments:
|
||||
field_name: The field to filter on. Can be 'document' or 'name' to filter
|
||||
on your document's raw text or title. Any other field will be presumed
|
||||
to be a metadata field you included when uploading your context data
|
||||
filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
|
||||
'fuzzy_search' means a fuzzy search on the provided field is performed.
|
||||
The exact strict doesn't need to exist in the document
|
||||
for this to find a match.
|
||||
Very useful for scanning a document for some keyword terms.
|
||||
'strict_search' means that the exact string must appear
|
||||
in the provided field.
|
||||
This is NOT an exact eq filter. ie a document with content
|
||||
"the happy dog crossed the street" will match on a strict_search of
|
||||
"dog" but won't match on "the dog".
|
||||
Python equivalent of `return search_string in full_string`.
|
||||
value: The actual value to search for in the context data/metadata
|
||||
"""
|
||||
|
||||
field_name: str
|
||||
filter_type: DALMFilterType
|
||||
value: str
|
||||
_is_metadata: bool = False
|
||||
|
||||
@root_validator()
|
||||
def set_meta(cls, values: Dict) -> Dict:
|
||||
"""document and name are reserved arcee keys. Anything else is metadata"""
|
||||
values["_is_meta"] = values.get("field_name") not in ["document", "name"]
|
||||
return values
|
||||
|
||||
|
||||
class ArceeDocumentSource(BaseModel):
|
||||
"""Source of an Arcee document."""
|
||||
|
||||
document: str
|
||||
name: str
|
||||
id: str
|
||||
|
||||
|
||||
class ArceeDocument(BaseModel):
|
||||
"""Arcee document."""
|
||||
|
||||
index: str
|
||||
id: str
|
||||
score: float
|
||||
source: ArceeDocumentSource
|
||||
|
||||
|
||||
class ArceeDocumentAdapter:
|
||||
"""Adapter for Arcee documents"""
|
||||
|
||||
@classmethod
|
||||
def adapt(cls, arcee_document: ArceeDocument) -> Document:
|
||||
"""Adapts an `ArceeDocument` to a langchain's `Document` object."""
|
||||
return Document(
|
||||
page_content=arcee_document.source.document,
|
||||
metadata={
|
||||
# arcee document; source metadata
|
||||
"name": arcee_document.source.name,
|
||||
"source_id": arcee_document.source.id,
|
||||
# arcee document metadata
|
||||
"index": arcee_document.index,
|
||||
"id": arcee_document.id,
|
||||
"score": arcee_document.score,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ArceeWrapper:
|
||||
"""Wrapper for Arcee API.
|
||||
|
||||
For more details, see: https://www.arcee.ai/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arcee_api_key: Union[str, SecretStr],
|
||||
arcee_api_url: str,
|
||||
arcee_api_version: str,
|
||||
model_kwargs: Optional[Dict[str, Any]],
|
||||
model_name: str,
|
||||
):
|
||||
"""Initialize ArceeWrapper.
|
||||
|
||||
Arguments:
|
||||
arcee_api_key: API key for Arcee API.
|
||||
arcee_api_url: URL for Arcee API.
|
||||
arcee_api_version: Version of Arcee API.
|
||||
model_kwargs: Keyword arguments for Arcee API.
|
||||
model_name: Name of an Arcee model.
|
||||
"""
|
||||
if isinstance(arcee_api_key, str):
|
||||
arcee_api_key_ = SecretStr(arcee_api_key)
|
||||
else:
|
||||
arcee_api_key_ = arcee_api_key
|
||||
self.arcee_api_key: SecretStr = arcee_api_key_
|
||||
self.model_kwargs = model_kwargs
|
||||
self.arcee_api_url = arcee_api_url
|
||||
self.arcee_api_version = arcee_api_version
|
||||
|
||||
try:
|
||||
route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
|
||||
response = self._make_request("get", route)
|
||||
self.model_id = response.get("model_id")
|
||||
self.model_training_status = response.get("status")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error while validating model training status for '{model_name}': {e}"
|
||||
) from e
|
||||
|
||||
def validate_model_training_status(self) -> None:
|
||||
if self.model_training_status != "training_complete":
|
||||
raise Exception(
|
||||
f"Model {self.model_id} is not ready. "
|
||||
"Please wait for training to complete."
|
||||
)
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
method: Literal["post", "get"],
|
||||
route: Union[ArceeRoute, str],
|
||||
body: Optional[Mapping[str, Any]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Make a request to the Arcee API
|
||||
Args:
|
||||
method: The HTTP method to use
|
||||
route: The route to call
|
||||
body: The body of the request
|
||||
params: The query params of the request
|
||||
headers: The headers of the request
|
||||
"""
|
||||
headers = self._make_request_headers(headers=headers)
|
||||
url = self._make_request_url(route=route)
|
||||
|
||||
req_type = getattr(requests, method)
|
||||
|
||||
response = req_type(url, json=body, params=params, headers=headers)
|
||||
if response.status_code not in (200, 201):
|
||||
raise Exception(f"Failed to make request. Response: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
||||
headers = headers or {}
|
||||
if not isinstance(self.arcee_api_key, SecretStr):
|
||||
raise TypeError(
|
||||
f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
|
||||
)
|
||||
api_key = self.arcee_api_key.get_secret_value()
|
||||
internal_headers = {
|
||||
"X-Token": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers.update(internal_headers)
|
||||
return headers
|
||||
|
||||
def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
|
||||
return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"
|
||||
|
||||
def _make_request_body_for_models(
|
||||
self, prompt: str, **kwargs: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
"""Make the request body for generate/retrieve models endpoint"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_params = {**_model_kwargs, **kwargs}
|
||||
|
||||
filters = [DALMFilter(**f) for f in _params.get("filters", [])]
|
||||
return dict(
|
||||
model_id=self.model_id,
|
||||
query=prompt,
|
||||
size=_params.get("size", 3),
|
||||
filters=filters,
|
||||
id=self.model_id,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text from Arcee DALM.
|
||||
|
||||
Args:
|
||||
prompt: Prompt to generate text from.
|
||||
size: The max number of context results to retrieve. Defaults to 3.
|
||||
(Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
response = self._make_request(
|
||||
method="post",
|
||||
route=ArceeRoute.generate.value,
|
||||
body=self._make_request_body_for_models(
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
return response["text"]
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Retrieve {size} contexts with your retriever for a given query
|
||||
|
||||
Args:
|
||||
query: Query to submit to the model
|
||||
size: The max number of context results to retrieve. Defaults to 3.
|
||||
(Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
response = self._make_request(
|
||||
method="post",
|
||||
route=ArceeRoute.retrieve.value,
|
||||
body=self._make_request_body_for_models(
|
||||
prompt=query,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
return [
|
||||
ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
|
||||
for doc in response["results"]
|
||||
]
|
Reference in New Issue
Block a user