mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-26 13:21:40 +00:00 
			
		
		
		
	Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories 
community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
		
			
				
	
	
		
			256 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			256 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # This module contains utility classes and functions for interacting with Arcee API.
 | |
| # For more information and updates, refer to the Arcee utils page:
 | |
| # [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
 | |
| 
 | |
| from enum import Enum
 | |
| from typing import Any, Dict, List, Literal, Mapping, Optional, Union
 | |
| 
 | |
| import requests
 | |
| from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
 | |
| from langchain_core.retrievers import Document
 | |
| 
 | |
| 
 | |
class ArceeRoute(str, Enum):
    """Enumeration of the Arcee API endpoint paths used by this module.

    Each value is a path that gets appended after ``{api_url}/{api_version}/``.
    ``model_training_status`` is a template containing an ``{id_or_name}``
    placeholder that callers fill in with ``str.format``.
    """

    # Requested with POST to generate text from a DALM model.
    generate = "models/generate"
    # Requested with POST to retrieve context documents for a query.
    retrieve = "models/retrieve"
    # Requested with GET to look up a model's training status by id or name.
    model_training_status = "models/status/{id_or_name}"
 | |
| 
 | |
| 
 | |
class DALMFilterType(str, Enum):
    """Kinds of match a ``DALMFilter`` can perform during DALM retrieval.

    Because this is a ``str`` enum, each member compares equal to its raw value.
    """

    # Approximate match: the search string need not appear verbatim.
    fuzzy_search = "fuzzy_search"
    # Substring match: the search string must appear in the target field.
    strict_search = "strict_search"
 | |
| 
 | |
| 
 | |
class DALMFilter(BaseModel):
    """Filters available for a DALM retrieval and generation.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data.
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter, i.e. a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata.
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    # Derived flag: True when ``field_name`` is not one of the reserved Arcee
    # keys ('document', 'name'). The leading underscore keeps it out of the
    # model's declared fields; ``set_meta`` writes the computed value into the
    # validated values under this exact name.
    _is_metadata: bool = False

    @root_validator()
    def set_meta(cls, values: Dict) -> Dict:
        """Mark the filter as a metadata filter when appropriate.

        'document' and 'name' are reserved Arcee keys; any other field name is
        treated as metadata.
        """
        # Store under the declared attribute name. The flag used to be written
        # to "_is_meta", so ``_is_metadata`` always read the class default.
        values["_is_metadata"] = values.get("field_name") not in ["document", "name"]
        return values
 | |
| 
 | |
| 
 | |
class ArceeDocumentSource(BaseModel):
    """Source record attached to a single Arcee retrieval result.

    Attributes:
        document: Raw text of the source document.
        name: Name (title) of the source document.
        id: Identifier of the source document.
    """

    document: str
    name: str
    id: str
 | |
| 
 | |
| 
 | |
class ArceeDocument(BaseModel):
    """One entry of the ``results`` list returned by the Arcee retrieve route.

    Attributes:
        index: The result's ``index`` value as returned by the API
            (presumably the index it was retrieved from — confirm with API docs).
        id: Identifier of the result.
        score: Score the API assigned to the result.
        source: The source document the result points to.
    """

    index: str
    id: str
    score: float
    source: ArceeDocumentSource
 | |
| 
 | |
| 
 | |
class ArceeDocumentAdapter:
    """Converts Arcee retrieval results into LangChain documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a LangChain `Document` from an `ArceeDocument`.

        The source text becomes ``page_content``. The source's name and id and
        the result's index, id and score are copied into ``metadata``.
        """
        src = arcee_document.source
        meta = {
            # taken from the result's source record
            "name": src.name,
            "source_id": src.id,
            # taken from the retrieval result itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=src.document, metadata=meta)
 | |
| 
 | |
| 
 | |
class ArceeWrapper:
    """Wrapper for Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper and fetch the model's training status.

        Arguments:
            arcee_api_key: API key for Arcee API. A plain string is wrapped in
                ``SecretStr``.
            arcee_api_url: Base URL for Arcee API.
            arcee_api_version: Version of Arcee API (used as a URL path segment).
            model_kwargs: Default parameters merged under per-call kwargs for
                generate/retrieve; only ``size`` and ``filters`` are read.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If fetching the model's training status fails.
        """
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key
        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise unless the model's training has completed.

        Raises:
            Exception: If the status fetched in ``__init__`` is not
                ``"training_complete"``.
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API.

        Args:
            method: The HTTP method to use ("post" or "get").
            route: The route to call, appended after ``{url}/{version}/``.
            body: JSON body of the request.
            params: The query params of the request.
            headers: Extra headers. The auth and content-type headers are always
                added and take precedence.

        Returns:
            The decoded JSON response.

        Raises:
            Exception: If the response status code is not 200 or 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Build request headers carrying the Arcee API token.

        The caller's ``headers`` are copied, never modified. ``X-Token`` and
        ``Content-Type`` override any caller-supplied values.

        Raises:
            TypeError: If ``arcee_api_key`` is not a ``SecretStr``.
        """
        # Copy so a dict passed in by the caller is not mutated in place.
        headers = dict(headers or {})
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        api_key = self.arcee_api_key.get_secret_value()
        internal_headers = {
            "X-Token": api_key,
            "Content-Type": "application/json",
        }
        headers.update(internal_headers)
        return headers

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join the base URL, API version and route into a full URL."""
        # Use the member's value for enum routes. On Python 3.11+, formatting a
        # (str, Enum) member directly yields "ArceeRoute.<name>", not its value.
        if isinstance(route, ArceeRoute):
            route = route.value
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint.

        Per-call ``kwargs`` override ``self.model_kwargs``. Keys read:
        ``size`` (default 3) and ``filters``. ``filters`` is a list of
        ``DALMFilter`` instances or of dicts accepted by ``DALMFilter``;
        ``None`` is treated as no filters.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        filters = [
            self._filter_to_payload(
                f if isinstance(f, DALMFilter) else DALMFilter(**f)
            )
            for f in _params.get("filters") or []
        ]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    @staticmethod
    def _filter_to_payload(dalm_filter: DALMFilter) -> Dict[str, Any]:
        """Convert a ``DALMFilter`` into a JSON-serializable dict for the API."""
        # The body is sent through ``requests``' ``json=`` argument, which cannot
        # encode pydantic model instances, so filters must be plain dicts.
        # Underscore-prefixed entries are client-side bookkeeping written by the
        # filter's validator, not API fields, so they are dropped.
        return {
            key: value
            for key, value in dalm_filter.dict().items()
            if not key.startswith("_")
        }

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The ``text`` field of the API response.
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query.

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One LangChain ``Document`` per entry in the response's ``results``.
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]
 |