mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
improve openai embeddings (#351)
add more formal support for explicitly specifying each model, but in a backwards compatible way
This commit is contained in:
parent
428508bd75
commit
ed143b598f
@ -22,9 +22,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "babbage"
|
||||
"""Model name to use."""
|
||||
|
||||
document_model_name: str = "text-embedding-ada-002"
|
||||
query_model_name: str = "text-embedding-ada-002"
|
||||
openai_api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
@ -32,6 +31,26 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
# TODO: deprecate this
|
||||
@root_validator(pre=True)
|
||||
def get_model_names(cls, values: Dict) -> Dict:
|
||||
"""Get model names from just old model name."""
|
||||
if "model_name" in values:
|
||||
if "document_model_name" in values:
|
||||
raise ValueError(
|
||||
"Both `model_name` and `document_model_name` were provided, "
|
||||
"but only one should be."
|
||||
)
|
||||
if "query_model_name" in values:
|
||||
raise ValueError(
|
||||
"Both `model_name` and `query_model_name` were provided, "
|
||||
"but only one should be."
|
||||
)
|
||||
model_name = values.pop("model_name")
|
||||
values["document_model_name"] = f"text-search-{model_name}-doc-001"
|
||||
values["query_model_name"] = f"text-search-{model_name}-query-001"
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -66,7 +85,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
responses = [
|
||||
self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001")
|
||||
self._embedding_func(text, engine=self.document_model_name)
|
||||
for text in texts
|
||||
]
|
||||
return responses
|
||||
@ -80,7 +99,5 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embedding = self._embedding_func(
|
||||
text, engine=f"text-search-{self.model_name}-query-001"
|
||||
)
|
||||
embedding = self._embedding_func(text, engine=self.query_model_name)
|
||||
return embedding
|
||||
|
Loading…
Reference in New Issue
Block a user