community: OCI GenAI embedding batch size (#22986)
Thank you for contributing to LangChain!

- [x] **PR title**: "community: OCI GenAI embedding batch size"
- [x] **PR message**:
  - **Issue:** #22985
- [ ] **Add tests and docs**: N/A
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
parent 8235bae48e
commit aacc6198b9
```diff
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
```
```diff
@@ -80,6 +80,10 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     truncate: Optional[str] = "END"
     """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
 
+    batch_size: int = 96
+    """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
+    per request"""
+
     class Config:
         """Configuration for this pydantic object."""
 
```
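The chunking behind the new `batch_size` field is plain stride slicing over the input list. Here is a minimal standalone sketch of that generator (the sample sizes and `doc-…` strings are illustrative, not part of the PR):

```python
# Standalone sketch of the stride-slicing generator added in this PR
# (sample data below is illustrative only).
from typing import Iterator, List


def split_texts(texts: List[str], batch_size: int = 96) -> Iterator[List[str]]:
    """Yield successive slices of `texts`, each at most `batch_size` long."""
    for i in range(0, len(texts), batch_size):
        yield texts[i : i + batch_size]


texts = [f"doc-{n}" for n in range(200)]
print([len(chunk) for chunk in split_texts(texts)])  # -> [96, 96, 8]
```

Using a generator keeps memory flat: each slice is materialized only when the loop requests it, so at most one batch beyond the original list exists at a time.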
```diff
@@ -182,16 +186,23 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
         else:
             serving_mode = models.OnDemandServingMode(model_id=self.model_id)
 
-        invocation_obj = models.EmbedTextDetails(
-            serving_mode=serving_mode,
-            compartment_id=self.compartment_id,
-            truncate=self.truncate,
-            inputs=texts,
-        )
-
-        response = self.client.embed_text(invocation_obj)
-
-        return response.data.embeddings
+        embeddings = []
+
+        def split_texts() -> Iterator[List[str]]:
+            for i in range(0, len(texts), self.batch_size):
+                yield texts[i : i + self.batch_size]
+
+        for chunk in split_texts():
+            invocation_obj = models.EmbedTextDetails(
+                serving_mode=serving_mode,
+                compartment_id=self.compartment_id,
+                truncate=self.truncate,
+                inputs=chunk,
+            )
+            response = self.client.embed_text(invocation_obj)
+            embeddings.extend(response.data.embeddings)
+
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OCIGenAI's embedding endpoint.
```
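End to end, the change means `embed_documents` can accept arbitrarily many texts and fan them out over multiple `embed_text` calls. A hedged usage sketch, assuming the class is exported from `langchain_community.embeddings`; the `model_id`, `service_endpoint`, and `compartment_id` values below are placeholders, not from the PR:

```python
# Usage sketch; all OCI config values below are placeholders.
from langchain_community.embeddings import OCIGenAIEmbeddings

embedder = OCIGenAIEmbeddings(
    model_id="cohere.embed-english-v3.0",  # placeholder model id
    service_endpoint="https://inference.generativeai.<region>.oci.oraclecloud.com",  # placeholder
    compartment_id="ocid1.compartment.oc1...",  # placeholder OCID
    batch_size=96,  # the new field; OCI GenAI may handle up to 96 texts per request
)

# 300 texts -> four requests behind the scenes (96 + 96 + 96 + 12),
# returned as a single flat list of vectors.
docs = [f"document {n}" for n in range(300)]
vectors = embedder.embed_documents(docs)
assert len(vectors) == len(docs)
```

Because `embeddings.extend(...)` appends each response in request order, the returned vectors still line up index-for-index with the input texts.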