mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
Add client_name="langchain" to Cohere usage (#11328)
Hey, we're looking to invest more in adding cohere integrations to langchain so would love to get more of an idea for how it's used. Hopefully this pr is acceptable. This week I'm also going to be looking into adding our new [retrieval augmented generation product](https://txt.cohere.com/chat-with-rag/) to langchain. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
37aec1e050
commit
b1e3843931
@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[float] = None
|
||||
"""Timeout in seconds for the Cohere API request."""
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
import cohere
|
||||
|
||||
client_name = values["user_agent"]
|
||||
values["client"] = cohere.Client(
|
||||
cohere_api_key, max_retries=max_retries, timeout=request_timeout
|
||||
cohere_api_key,
|
||||
max_retries=max_retries,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key, max_retries=max_retries, timeout=request_timeout
|
||||
cohere_api_key,
|
||||
max_retries=max_retries,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
|
@ -80,6 +80,9 @@ class BaseCohere(Serializable):
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to stream the results."""
|
||||
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -94,8 +97,11 @@ class BaseCohere(Serializable):
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
values["async_client"] = cohere.AsyncClient(cohere_api_key)
|
||||
client_name = values["user_agent"]
|
||||
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key, client_name=client_name
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
|
@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
"""Model to use for reranking."""
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
try:
|
||||
import cohere
|
||||
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import cohere python package. "
|
||||
"Please install it with `pip install cohere`."
|
||||
)
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
client_name = values["user_agent"]
|
||||
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
|
||||
return values
|
||||
|
||||
def compress_documents(
|
||||
|
Loading…
Reference in New Issue
Block a user