community: Add Baichuan Embeddings batch size (#22942)

- **Support batch size** 
Baichuan updates the document, indicating that up to 16 documents can be
imported at a time

- **Standardized model init arg names**
    - baichuan_api_key -> api_key
    - model_name  -> model
This commit is contained in:
maang-h
2024-06-18 02:11:04 +08:00
committed by GitHub
parent 722c8f50ea
commit c6b7db6587
3 changed files with 66 additions and 22 deletions

View File

@@ -17,3 +17,13 @@ def test_baichuan_embedding_query() -> None:
embedding = BaichuanTextEmbeddings() # type: ignore[call-arg]
output = embedding.embed_query(document)
assert len(output) == 1024 # type: ignore[arg-type]
def test_baichuan_embeddings_multi_documents() -> None:
"""Test Baichuan Text Embedding for documents with multi texts."""
document = "午餐吃了螺蛳粉"
doc_amount = 35
embeddings = BaichuanTextEmbeddings() # type: ignore[call-arg]
output = embeddings.embed_documents([document] * doc_amount)
assert len(output) == doc_amount # type: ignore[arg-type]
assert len(output[0]) == 1024 # type: ignore[index]

View File

@@ -0,0 +1,18 @@
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from langchain_community.embeddings import BaichuanTextEmbeddings
def test_sparkllm_initialization_by_alias() -> None:
# Effective initialization
embeddings = BaichuanTextEmbeddings( # type: ignore[call-arg]
model="embedding_model", # type: ignore[arg-type]
api_key="your-api-key", # type: ignore[arg-type]
)
assert embeddings.model_name == "embedding_model"
assert (
cast(SecretStr, embeddings.baichuan_api_key).get_secret_value()
== "your-api-key"
)