mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
improve llamacpp embeddings (#12972)
- **Description:** Improve llamacpp embedding class by adding the `device` parameter so it can be passed to the model and used with `gpu`, `cpu` or Apple metal (`mps`). Improve performance by making use of the bulk client api to compute embeddings in batches. - **Dependencies:** none - **Tag maintainer:** @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
f882824eac
commit
654da27255
@@ -57,6 +57,9 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     verbose: bool = Field(True, alias="verbose")
     """Print verbose output to stderr."""
 
+    device: Optional[str] = Field(None, alias="device")
+    """Device type to use and pass to the model"""
+
     class Config:
         extra = "forbid"
@@ -75,6 +78,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             "n_threads",
             "n_batch",
             "verbose",
+            "device",
         ]
         model_params = {k: values[k] for k in model_param_names}
         # For backwards compatibility, only include if non-null.
@@ -108,8 +112,8 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = [self.client.embed(text) for text in texts]
-        return [list(map(float, e)) for e in embeddings]
+        embeddings = self.client.create_embedding(texts)
+        return [list(map(float, e["embedding"])) for e in embeddings["data"]]
 
     def embed_query(self, text: str) -> List[float]:
         """Embed a query using the Llama model.
|
Loading…
Reference in New Issue
Block a user