Add client_name="langchain" to Cohere usage (#11328)

Hey, we're looking to invest more in adding cohere integrations to
langchain so would love to get more of an idea for how it's used.
Hopefully this pr is acceptable. This week I'm also going to be looking
into adding our new [retrieval augmented generation
product](https://txt.cohere.com/chat-with-rag/) to langchain.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
billytrend-cohere 2023-10-30 18:20:55 +00:00 committed by GitHub
parent 37aec1e050
commit b1e3843931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 9 deletions

View File

@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config:
"""Configuration for this pydantic object."""
@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
try:
import cohere
client_name = values["user_agent"]
values["client"] = cohere.Client(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
except ImportError:
raise ValueError(

View File

@ -80,6 +80,9 @@ class BaseCohere(Serializable):
streaming: bool = Field(default=False)
"""Whether to stream the results."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -94,8 +97,11 @@ class BaseCohere(Serializable):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
values["client"] = cohere.Client(cohere_api_key)
values["async_client"] = cohere.AsyncClient(cohere_api_key)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, client_name=client_name
)
return values

View File

@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor):
"""Model to use for reranking."""
cohere_api_key: Optional[str] = None
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config:
"""Configuration for this pydantic object."""
@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor):
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
try:
import cohere
values["client"] = cohere.Client(cohere_api_key)
except ImportError:
raise ImportError(
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
return values
def compress_documents(