together: fix chat model and embedding classes (#21353)
commit bb81ae5c8c (parent d6ef5fe86a)
@@ -59,7 +59,7 @@ class ChatTogether(BaseChatOpenAI):
     together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     """Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
     together_api_base: Optional[str] = Field(
-        default="https://api.together.ai/v1/chat/completions", alias="base_url"
+        default="https://api.together.ai/v1/", alias="base_url"
     )

     @root_validator()
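Why the default moved to the bare /v1/ root: ChatTogether builds on BaseChatOpenAI, and the openai-python 1.x client appends the endpoint path ("/chat/completions", "/embeddings", ...) to base_url itself, so a default that already ends in "/chat/completions" produces a doubled path. A minimal sketch of that behavior, assuming the openai 1.x SDK (the key is a placeholder):

import openai

# base_url stops at the /v1/ root; the SDK appends the endpoint path itself.
client = openai.OpenAI(
    api_key="...",  # placeholder
    base_url="https://api.together.ai/v1/",
)
# This request resolves to https://api.together.ai/v1/chat/completions.
resp = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "hello"}],
)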
@@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
-    """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
+    """Embeddings model name to use.
     Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
     """
     dimensions: Optional[int] = None
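Note on the docstring change: with the suffix handling removed further down in this diff, the model id is now passed through verbatim, so the "do not add suffixes" warning no longer applies. A hedged usage sketch, using the published import path for this partner package:

from langchain_together import TogetherEmbeddings

# The configured model id is sent to the API unchanged.
emb = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")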
@@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
     together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     """API Key for Solar API."""
     together_api_base: str = Field(
-        default="https://api.together.ai/v1/embeddings", alias="base_url"
+        default="https://api.together.ai/v1/", alias="base_url"
     )
     """Endpoint URL to use."""
     embedding_ctx_length: int = 4096
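Same reasoning as the chat default above: the client appends "/embeddings" to base_url, so the default must stop at the /v1/ root. Both fields keep their aliases, so callers can still override the endpoint. A small sketch, assuming pydantic populates by the aliases shown above (the key is a placeholder):

from langchain_together import TogetherEmbeddings

emb = TogetherEmbeddings(
    api_key="...",                            # alias for together_api_key
    base_url="https://api.together.ai/v1/",   # alias for together_api_base
)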
@@ -166,12 +166,18 @@ class TogetherEmbeddings(BaseModel, Embeddings):
             "default_query": values["default_query"],
         }
         if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
+            sync_specific = (
+                {"http_client": values["http_client"]} if values["http_client"] else {}
+            )
             values["client"] = openai.OpenAI(
                 **client_params, **sync_specific
             ).embeddings
         if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
+            async_specific = (
+                {"http_client": values["http_async_client"]}
+                if values["http_async_client"]
+                else {}
+            )
             values["async_client"] = openai.AsyncOpenAI(
                 **client_params, **async_specific
             ).embeddings
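The old code always forwarded http_client to the SDK, even when no client had been configured; the new conditional only passes the kwarg when a caller actually supplied one, leaving the SDK to build its own transport otherwise. A self-contained sketch of the pattern, assuming httpx and openai-python (the helper name is hypothetical):

from typing import Any, Dict, Optional

import httpx
import openai

def make_client(api_key: str, http_client: Optional[httpx.Client] = None) -> openai.OpenAI:
    # Forward http_client only when the caller supplied one; otherwise let
    # the SDK construct its own default httpx client.
    extra: Dict[str, Any] = {"http_client": http_client} if http_client else {}
    return openai.OpenAI(api_key=api_key, **extra)

default_client = make_client("...")                      # SDK-managed transport
custom_client = make_client("...", httpx.Client(timeout=30.0))  # caller-managed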
@@ -179,8 +185,6 @@ class TogetherEmbeddings(BaseModel, Embeddings):

     @property
     def _invocation_params(self) -> Dict[str, Any]:
-        self.model = self.model.replace("-query", "").replace("-passage", "")
-
         params: Dict = {"model": self.model, **self.model_kwargs}
         if self.dimensions is not None:
             params["dimensions"] = self.dimensions
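This hunk removes a side effect: reading the _invocation_params property used to rewrite self.model by stripping the "-query" and "-passage" suffixes. A minimal sketch (not the library's code, a toy stand-in) of the side-effect-free shape the property now has:

from typing import Any, Dict, Optional

class EmbeddingsParams:
    """Toy stand-in for TogetherEmbeddings' parameter handling."""

    def __init__(self, model: str, dimensions: Optional[int] = None, **model_kwargs: Any) -> None:
        self.model = model
        self.dimensions = dimensions
        self.model_kwargs = model_kwargs

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        # Pure read: builds the request params without mutating self.model.
        params: Dict[str, Any] = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params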
@@ -197,7 +201,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
         """
         embeddings = []
         params = self._invocation_params
-        params["model"] = params["model"] + "-passage"
+        params["model"] = params["model"]

         for text in texts:
             response = self.client.create(input=text, **params)
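Behavior change in the embedding calls: previously embed_documents appended "-passage" (and embed_query, below, appended "-query") to the model id before each request; now the configured id is sent verbatim. A hedged usage sketch:

from langchain_together import TogetherEmbeddings

emb = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
# Before this commit these calls hit "...-passage" / "...-query" variants;
# after it, both send model="togethercomputer/m2-bert-80M-8k-retrieval".
doc_vectors = emb.embed_documents(["first doc", "second doc"])  # one request per text
query_vector = emb.embed_query("a question")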
@@ -217,7 +221,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
             Embedding for the text.
         """
         params = self._invocation_params
-        params["model"] = params["model"] + "-query"
+        params["model"] = params["model"]

         response = self.client.create(input=text, **params)

@@ -236,7 +240,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
         """
         embeddings = []
         params = self._invocation_params
-        params["model"] = params["model"] + "-passage"
+        params["model"] = params["model"]

         for text in texts:
             response = await self.async_client.create(input=text, **params)
@@ -256,7 +260,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
             Embedding for the text.
         """
         params = self._invocation_params
-        params["model"] = params["model"] + "-query"
+        params["model"] = params["model"]

         response = await self.async_client.create(input=text, **params)

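The async methods mirror the sync change exactly. A short sketch of the async path, using the aembed_* methods this class implements:

import asyncio

from langchain_together import TogetherEmbeddings

async def main() -> None:
    emb = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
    docs = await emb.aembed_documents(["first doc", "second doc"])
    query = await emb.aembed_query("a question")
    print(len(docs), len(query))

asyncio.run(main())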
@@ -17,5 +17,5 @@ class TestTogetherStandard(ChatModelIntegrationTests):
     @pytest.fixture
     def chat_model_params(self) -> dict:
         return {
-            "model": "meta-llama/Llama-3-8b-chat-hf",
+            "model": "mistralai/Mistral-7B-Instruct-v0.1",
         }
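The integration test swaps in a different Together-served model id; the shared suite builds the chat model from these fixtures. A hedged sketch of the surrounding test class, assuming the standard-tests convention of a chat_model_class fixture (the import paths and the chat_model_class fixture are assumptions, not shown in this diff):

import pytest
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_together import ChatTogether

class TestTogetherStandard(ChatModelIntegrationTests):
    @pytest.fixture
    def chat_model_class(self) -> type:
        # The suite instantiates chat_model_class(**chat_model_params).
        return ChatTogether

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}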