mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 02:33:05 +00:00
core[minor], langchain[patch], experimental[patch]: Added missing py.typed
to langchain_core
(#14143)
See PR title. From what I can see, `poetry` will auto-include this. Please let me know if I am missing something here. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
f7c257553d
commit
24385a00de
0
libs/core/langchain_core/py.typed
Normal file
0
libs/core/langchain_core/py.typed
Normal file
@ -1,5 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, List, cast
|
||||||
|
|
||||||
from langchain.prompts.chat import (
|
from langchain.prompts.chat import (
|
||||||
BaseChatPromptTemplate,
|
BaseChatPromptTemplate,
|
||||||
@ -68,9 +68,9 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
|||||||
time_prompt = SystemMessage(
|
time_prompt = SystemMessage(
|
||||||
content=f"The current time and date is {time.strftime('%c')}"
|
content=f"The current time and date is {time.strftime('%c')}"
|
||||||
)
|
)
|
||||||
used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
|
used_tokens = self.token_counter(
|
||||||
time_prompt.content
|
cast(str, base_prompt.content)
|
||||||
)
|
) + self.token_counter(cast(str, time_prompt.content))
|
||||||
memory: VectorStoreRetriever = kwargs["memory"]
|
memory: VectorStoreRetriever = kwargs["memory"]
|
||||||
previous_messages = kwargs["messages"]
|
previous_messages = kwargs["messages"]
|
||||||
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
|
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
|
||||||
@ -88,7 +88,7 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
|||||||
f"from your past:\n{relevant_memory}\n\n"
|
f"from your past:\n{relevant_memory}\n\n"
|
||||||
)
|
)
|
||||||
memory_message = SystemMessage(content=content_format)
|
memory_message = SystemMessage(content=content_format)
|
||||||
used_tokens += self.token_counter(memory_message.content)
|
used_tokens += self.token_counter(cast(str, memory_message.content))
|
||||||
historical_messages: List[BaseMessage] = []
|
historical_messages: List[BaseMessage] = []
|
||||||
for message in previous_messages[-10:][::-1]:
|
for message in previous_messages[-10:][::-1]:
|
||||||
message_tokens = self.token_counter(message.content)
|
message_tokens = self.token_counter(message.content)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Generic Wrapper for chat LLMs, with sample implementations
|
"""Generic Wrapper for chat LLMs, with sample implementations
|
||||||
for Llama-2-chat, Llama-2-instruct and Vicuna models.
|
for Llama-2-chat, Llama-2-instruct and Vicuna models.
|
||||||
"""
|
"""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, cast
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -90,8 +90,12 @@ class ChatWrapper(BaseChatModel):
|
|||||||
if self.usr_0_end is None:
|
if self.usr_0_end is None:
|
||||||
self.usr_0_end = self.usr_n_end
|
self.usr_0_end = self.usr_n_end
|
||||||
|
|
||||||
prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
|
prompt_parts.append(
|
||||||
prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)
|
self.sys_beg + cast(str, messages[0].content) + self.sys_end
|
||||||
|
)
|
||||||
|
prompt_parts.append(
|
||||||
|
self.usr_0_beg + cast(str, messages[1].content) + self.usr_0_end
|
||||||
|
)
|
||||||
|
|
||||||
for ai_message, human_message in zip(messages[2::2], messages[3::2]):
|
for ai_message, human_message in zip(messages[2::2], messages[3::2]):
|
||||||
if not isinstance(ai_message, AIMessage) or not isinstance(
|
if not isinstance(ai_message, AIMessage) or not isinstance(
|
||||||
@ -102,8 +106,12 @@ class ChatWrapper(BaseChatModel):
|
|||||||
"optionally prepended by a system message"
|
"optionally prepended by a system message"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
|
prompt_parts.append(
|
||||||
prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)
|
self.ai_n_beg + cast(str, ai_message.content) + self.ai_n_end
|
||||||
|
)
|
||||||
|
prompt_parts.append(
|
||||||
|
self.usr_n_beg + cast(str, human_message.content) + self.usr_n_end
|
||||||
|
)
|
||||||
|
|
||||||
return "".join(prompt_parts)
|
return "".join(prompt_parts)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional, cast
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.schema import AIMessage, HumanMessage
|
from langchain.schema import AIMessage, HumanMessage
|
||||||
@ -54,10 +54,10 @@ class BaseModeration:
|
|||||||
message = prompt.messages[-1]
|
message = prompt.messages[-1]
|
||||||
self.chat_message_index = len(prompt.messages) - 1
|
self.chat_message_index = len(prompt.messages) - 1
|
||||||
if isinstance(message, HumanMessage):
|
if isinstance(message, HumanMessage):
|
||||||
input_text = message.content
|
input_text = cast(str, message.content)
|
||||||
|
|
||||||
if isinstance(message, AIMessage):
|
if isinstance(message, AIMessage):
|
||||||
input_text = message.content
|
input_text = cast(str, message.content)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid input type {type(input_text)}. "
|
f"Invalid input type {type(input_text)}. "
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from html.parser import HTMLParser
|
from html.parser import HTMLParser
|
||||||
from typing import Any, DefaultDict, Dict, List, Optional
|
from typing import Any, DefaultDict, Dict, List, Optional, cast
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -176,7 +176,7 @@ class AnthropicFunctions(BaseChatModel):
|
|||||||
response = self.model.predict_messages(
|
response = self.model.predict_messages(
|
||||||
messages, stop=stop, callbacks=run_manager, **kwargs
|
messages, stop=stop, callbacks=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
completion = response.content
|
completion = cast(str, response.content)
|
||||||
if forced:
|
if forced:
|
||||||
tag_parser = TagParser()
|
tag_parser = TagParser()
|
||||||
|
|
||||||
@ -210,7 +210,7 @@ class AnthropicFunctions(BaseChatModel):
|
|||||||
message = AIMessage(content=msg, additional_kwargs=kwargs)
|
message = AIMessage(content=msg, additional_kwargs=kwargs)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
else:
|
else:
|
||||||
response.content = response.content.strip()
|
response.content = cast(str, response.content).strip()
|
||||||
return ChatResult(generations=[ChatGeneration(message=response)])
|
return ChatResult(generations=[ChatGeneration(message=response)])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -239,7 +239,7 @@ def __getattr__(name: str) -> Any:
|
|||||||
|
|
||||||
return FewShotPromptTemplate
|
return FewShotPromptTemplate
|
||||||
elif name == "Prompt":
|
elif name == "Prompt":
|
||||||
from langchain_core.prompts import Prompt
|
from langchain.prompts import Prompt
|
||||||
|
|
||||||
_warn_on_import(name, replacement="langchain.prompts.Prompt")
|
_warn_on_import(name, replacement="langchain.prompts.Prompt")
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
|
||||||
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -33,9 +33,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
|
|
||||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||||
content = _dict.get("choice", {}).get("message", {}).get("content", "")
|
content = _dict.get("choice", {}).get("message", {}).get("content", "")
|
||||||
return AIMessage(
|
return AIMessage(content=content)
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
||||||
@ -118,7 +116,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
|||||||
msg = convert_dict_to_message(res)
|
msg = convert_dict_to_message(res)
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
|
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(msg.content)
|
run_manager.on_llm_new_token(cast(str, msg.content))
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
@ -135,7 +133,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
|||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
res = self.client.chat(params)
|
res = self.client.chat(params)
|
||||||
msg = convert_dict_to_message(res)
|
msg = convert_dict_to_message(res)
|
||||||
completion = msg.content
|
completion = cast(str, msg.content)
|
||||||
|
|
||||||
message = AIMessage(content=completion)
|
message = AIMessage(content=completion)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
@ -41,7 +41,7 @@ class ForefrontAI(LLM):
|
|||||||
repetition_penalty: int = 1
|
repetition_penalty: int = 1
|
||||||
"""Penalizes repeated tokens according to frequency."""
|
"""Penalizes repeated tokens according to frequency."""
|
||||||
|
|
||||||
forefrontai_api_key: SecretStr = None
|
forefrontai_api_key: SecretStr
|
||||||
|
|
||||||
base_url: Optional[str] = None
|
base_url: Optional[str] = None
|
||||||
"""Base url to use, if None decides based on model name."""
|
"""Base url to use, if None decides based on model name."""
|
||||||
@ -51,7 +51,7 @@ class ForefrontAI(LLM):
|
|||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key exists in environment."""
|
"""Validate that api key exists in environment."""
|
||||||
values["forefrontai_api_key"] = convert_to_secret_str(
|
values["forefrontai_api_key"] = convert_to_secret_str(
|
||||||
|
@ -159,6 +159,7 @@ class _VertexAIBase(BaseModel):
|
|||||||
|
|
||||||
class _VertexAICommon(_VertexAIBase):
|
class _VertexAICommon(_VertexAIBase):
|
||||||
client: "_LanguageModel" = None #: :meta private:
|
client: "_LanguageModel" = None #: :meta private:
|
||||||
|
client_preview: "_LanguageModel" = None #: :meta private:
|
||||||
model_name: str
|
model_name: str
|
||||||
"Underlying model name."
|
"Underlying model name."
|
||||||
temperature: float = 0.0
|
temperature: float = 0.0
|
||||||
@ -406,13 +407,16 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
|
|||||||
values["async_client"] = PredictionServiceAsyncClient(
|
values["async_client"] = PredictionServiceAsyncClient(
|
||||||
client_options=client_options, client_info=client_info
|
client_options=client_options, client_info=client_info
|
||||||
)
|
)
|
||||||
values["endpoint_path"] = values["client"].endpoint_path(
|
|
||||||
project=values["project"],
|
|
||||||
location=values["location"],
|
|
||||||
endpoint=values["endpoint_id"],
|
|
||||||
)
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint_path(self) -> str:
|
||||||
|
return self.client.endpoint_path(
|
||||||
|
project=self.project,
|
||||||
|
location=self.location,
|
||||||
|
endpoint=self.endpoint_id,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "vertexai_model_garden"
|
return "vertexai_model_garden"
|
||||||
|
Loading…
Reference in New Issue
Block a user