core[minor], langchain[patch], experimental[patch]: Added missing py.typed to langchain_core (#14143)

See PR title.

From what I can see, `poetry` will auto-include this marker file when building the package. Please let me know
if I am missing something here.
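
For context: PEP 561 has a package advertise its inline type hints by shipping an empty `py.typed` marker file inside the package directory, and the expectation here is that Poetry's build backend includes non-Python files that live under the package tree, so no explicit `include` entry should be needed. A quick, illustrative way to confirm the marker actually lands in an installed copy of `langchain_core` (this check is not part of the PR):

```python
# Illustrative check, not part of this PR: verify the installed langchain_core
# package ships the PEP 561 marker so type checkers trust its annotations.
import importlib.resources

import langchain_core

marker = importlib.resources.files(langchain_core).joinpath("py.typed")
print(marker.is_file())  # expected to print True once py.typed is packaged
```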

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
James Braza 2023-12-01 22:15:23 -05:00 committed by GitHub
parent f7c257553d
commit 24385a00de
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 40 additions and 30 deletions

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Callable, List
+from typing import Any, Callable, List, cast

 from langchain.prompts.chat import (
     BaseChatPromptTemplate,
@@ -68,9 +68,9 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
         time_prompt = SystemMessage(
             content=f"The current time and date is {time.strftime('%c')}"
         )
-        used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
-            time_prompt.content
-        )
+        used_tokens = self.token_counter(
+            cast(str, base_prompt.content)
+        ) + self.token_counter(cast(str, time_prompt.content))
         memory: VectorStoreRetriever = kwargs["memory"]
         previous_messages = kwargs["messages"]
         relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
@@ -88,7 +88,7 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
             f"from your past:\n{relevant_memory}\n\n"
         )
         memory_message = SystemMessage(content=content_format)
-        used_tokens += self.token_counter(memory_message.content)
+        used_tokens += self.token_counter(cast(str, memory_message.content))
         historical_messages: List[BaseMessage] = []
         for message in previous_messages[-10:][::-1]:
             message_tokens = self.token_counter(message.content)
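
The `cast(str, ...)` calls in this and the following hunks all address the same issue: with `py.typed` in place, type checkers now see that `BaseMessage.content` in `langchain_core` is not a plain `str` (it can also be a list of content blocks), so passing it to a `Callable[[str], int]` token counter needs explicit narrowing. A minimal sketch of the situation, with the types abbreviated rather than copied from `langchain_core`:

```python
# Minimal sketch: why mypy wants a cast before handing .content to a str-only
# callable. The content type here is abbreviated, not the real definition.
from typing import Callable, Dict, List, Union, cast


class Message:
    # In langchain_core, message content may be a string or a list of blocks.
    content: Union[str, List[Union[str, Dict]]]


token_counter: Callable[[str], int] = len

msg = Message()
msg.content = "hello"

# token_counter(msg.content) is rejected: the argument might be a list.
used_tokens = token_counter(cast(str, msg.content))
print(used_tokens)
```

An `isinstance` check (or an explicit `str(...)`) would satisfy the checker as well; `cast` keeps runtime behaviour identical, on the assumption that these call sites only ever see plain-string content.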

@@ -1,7 +1,7 @@
 """Generic Wrapper for chat LLMs, with sample implementations
 for Llama-2-chat, Llama-2-instruct and Vicuna models.
 """
-from typing import Any, List, Optional
+from typing import Any, List, Optional, cast

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -90,8 +90,12 @@ class ChatWrapper(BaseChatModel):
         if self.usr_0_end is None:
             self.usr_0_end = self.usr_n_end

-        prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
-        prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)
+        prompt_parts.append(
+            self.sys_beg + cast(str, messages[0].content) + self.sys_end
+        )
+        prompt_parts.append(
+            self.usr_0_beg + cast(str, messages[1].content) + self.usr_0_end
+        )

         for ai_message, human_message in zip(messages[2::2], messages[3::2]):
             if not isinstance(ai_message, AIMessage) or not isinstance(
@@ -102,8 +106,12 @@ class ChatWrapper(BaseChatModel):
                     "optionally prepended by a system message"
                 )

-            prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
-            prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)
+            prompt_parts.append(
+                self.ai_n_beg + cast(str, ai_message.content) + self.ai_n_end
+            )
+            prompt_parts.append(
+                self.usr_n_beg + cast(str, human_message.content) + self.usr_n_end
+            )

         return "".join(prompt_parts)

@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, cast

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import AIMessage, HumanMessage
@@ -54,10 +54,10 @@ class BaseModeration:
             message = prompt.messages[-1]
             self.chat_message_index = len(prompt.messages) - 1
             if isinstance(message, HumanMessage):
-                input_text = message.content
+                input_text = cast(str, message.content)

             if isinstance(message, AIMessage):
-                input_text = message.content
+                input_text = cast(str, message.content)
             else:
                 raise ValueError(
                     f"Invalid input type {type(input_text)}. "

@@ -1,7 +1,7 @@
 import json
 from collections import defaultdict
 from html.parser import HTMLParser
-from typing import Any, DefaultDict, Dict, List, Optional
+from typing import Any, DefaultDict, Dict, List, Optional, cast

 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
@@ -176,7 +176,7 @@ class AnthropicFunctions(BaseChatModel):
         response = self.model.predict_messages(
             messages, stop=stop, callbacks=run_manager, **kwargs
         )
-        completion = response.content
+        completion = cast(str, response.content)
         if forced:
             tag_parser = TagParser()
@@ -210,7 +210,7 @@ class AnthropicFunctions(BaseChatModel):
             message = AIMessage(content=msg, additional_kwargs=kwargs)
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
-            response.content = response.content.strip()
+            response.content = cast(str, response.content).strip()
             return ChatResult(generations=[ChatGeneration(message=response)])

     @property
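
Both branches above hand the reply back in the standard envelope: an `AIMessage` wrapped in a `ChatGeneration`, wrapped in a `ChatResult`. In isolation that packaging looks roughly like this (a sketch; the imports assume `langchain.schema` still re-exports these names):

```python
# Sketch of the result envelope a chat model's _generate() returns.
from langchain.schema import AIMessage, ChatGeneration, ChatResult

message = AIMessage(content="The answer is 42.")
result = ChatResult(generations=[ChatGeneration(message=message)])
print(result.generations[0].message.content)  # -> The answer is 42.
```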

@@ -239,7 +239,7 @@ def __getattr__(name: str) -> Any:
         return FewShotPromptTemplate
     elif name == "Prompt":
-        from langchain_core.prompts import Prompt
+        from langchain.prompts import Prompt

         _warn_on_import(name, replacement="langchain.prompts.Prompt")
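
This hunk lives inside `langchain`'s module-level `__getattr__`, the PEP 562 hook that lazily imports legacy top-level names and warns that they have moved; with `langchain_core` now typed, the old `langchain_core.prompts` import was flagged, presumably because `Prompt` is only re-exported from `langchain.prompts`. A condensed sketch of the pattern (the warning wording and the `AttributeError` fallback are illustrative, not copied from `langchain/__init__.py`):

```python
# Condensed sketch of the lazy-import-and-warn pattern (PEP 562 module
# __getattr__). The warning text and the fallback are illustrative.
import warnings
from typing import Any


def _warn_on_import(name: str, replacement: str) -> None:
    warnings.warn(
        f"Importing {name} from the langchain root module is deprecated; "
        f"import {replacement} instead."
    )


def __getattr__(name: str) -> Any:
    if name == "Prompt":
        from langchain.prompts import Prompt

        _warn_on_import(name, replacement="langchain.prompts.Prompt")
        return Prompt
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```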

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional, cast

 from langchain_core.messages import (
     AIMessage,
@@ -33,9 +33,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
     content = _dict.get("choice", {}).get("message", {}).get("content", "")
-    return AIMessage(
-        content=content,
-    )
+    return AIMessage(content=content)


 class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
@@ -118,7 +116,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
                 msg = convert_dict_to_message(res)
                 yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
                 if run_manager:
-                    run_manager.on_llm_new_token(msg.content)
+                    run_manager.on_llm_new_token(cast(str, msg.content))

     def _generate(
         self,
@@ -135,7 +133,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
         params = self._convert_prompt_msg_params(messages, **kwargs)
         res = self.client.chat(params)
         msg = convert_dict_to_message(res)
-        completion = msg.content
+        completion = cast(str, msg.content)
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])

@@ -41,7 +41,7 @@ class ForefrontAI(LLM):
     repetition_penalty: int = 1
     """Penalizes repeated tokens according to frequency."""

-    forefrontai_api_key: SecretStr = None
+    forefrontai_api_key: SecretStr

     base_url: Optional[str] = None
     """Base url to use, if None decides based on model name."""
@@ -51,7 +51,7 @@ class ForefrontAI(LLM):
         extra = Extra.forbid

-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key exists in environment."""
         values["forefrontai_api_key"] = convert_to_secret_str(

@@ -159,6 +159,7 @@ class _VertexAIBase(BaseModel):
 class _VertexAICommon(_VertexAIBase):
     client: "_LanguageModel" = None  #: :meta private:
+    client_preview: "_LanguageModel" = None  #: :meta private:
     model_name: str
     "Underlying model name."
     temperature: float = 0.0
@@ -406,13 +407,16 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
-        values["endpoint_path"] = values["client"].endpoint_path(
-            project=values["project"],
-            location=values["location"],
-            endpoint=values["endpoint_id"],
-        )
         return values

+    @property
+    def endpoint_path(self) -> str:
+        return self.client.endpoint_path(
+            project=self.project,
+            location=self.location,
+            endpoint=self.endpoint_id,
+        )
+
     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"