community: OCI GenAI embedding batch size (#22986)
Thank you for contributing to LangChain!

- [x] **PR title**: "community: OCI GenAI embedding batch size"
- [x] **PR message**:
  - **Issue:** #22985
- [ ] **Add tests and docs**: N/A
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:

- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Parent: 8235bae48e
Commit: aacc6198b9
```diff
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
```
```diff
@@ -80,6 +80,10 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     truncate: Optional[str] = "END"
     """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
 
+    batch_size: int = 96
+    """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
+    per request"""
+
     class Config:
         """Configuration for this pydantic object."""
 
```
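For context, here is a minimal usage sketch of the new field. The constructor arguments follow the class's documented parameters, but the model ID, endpoint URL, and compartment OCID below are illustrative placeholders, not values from this PR.

```python
# Hypothetical usage sketch: model_id, service_endpoint, and compartment_id
# are placeholders. batch_size is the field added in this commit.
from langchain_community.embeddings import OCIGenAIEmbeddings

embedder = OCIGenAIEmbeddings(
    model_id="cohere.embed-english-light-v3.0",  # placeholder model
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    batch_size=96,  # OCI GenAI may handle up to 96 texts per request
)

vectors = embedder.embed_documents(["first document", "second document"])
```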
```diff
@@ -182,16 +186,23 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
         else:
             serving_mode = models.OnDemandServingMode(model_id=self.model_id)
 
-        invocation_obj = models.EmbedTextDetails(
-            serving_mode=serving_mode,
-            compartment_id=self.compartment_id,
-            truncate=self.truncate,
-            inputs=texts,
-        )
-
-        response = self.client.embed_text(invocation_obj)
-
-        return response.data.embeddings
+        embeddings = []
+
+        def split_texts() -> Iterator[List[str]]:
+            for i in range(0, len(texts), self.batch_size):
+                yield texts[i : i + self.batch_size]
+
+        for chunk in split_texts():
+            invocation_obj = models.EmbedTextDetails(
+                serving_mode=serving_mode,
+                compartment_id=self.compartment_id,
+                truncate=self.truncate,
+                inputs=chunk,
+            )
+            response = self.client.embed_text(invocation_obj)
+            embeddings.extend(response.data.embeddings)
+
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OCIGenAI's embedding endpoint.
```