mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-11-04 02:03:32 +00:00 
			
		
		
		
	Addded missed docstrings. Fixed inconsistency in docstrings. **Note** CC @efriis There were PR errors on `langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py` But, I didn't touch this file in this PR! Can it be some cache problems? I fixed this error.
		
			
				
	
	
		
			362 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			362 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import (
 | 
						|
    Any,
 | 
						|
    AsyncIterator,
 | 
						|
    Callable,
 | 
						|
    Dict,
 | 
						|
    Iterator,
 | 
						|
    List,
 | 
						|
    Optional,
 | 
						|
    Type,
 | 
						|
    Union,
 | 
						|
)
 | 
						|
 | 
						|
from langchain_core.callbacks import (
 | 
						|
    AsyncCallbackManagerForLLMRun,
 | 
						|
    CallbackManagerForLLMRun,
 | 
						|
)
 | 
						|
from langchain_core.language_models.chat_models import BaseChatModel
 | 
						|
from langchain_core.language_models.llms import create_base_retry_decorator
 | 
						|
from langchain_core.messages import (
 | 
						|
    AIMessage,
 | 
						|
    AIMessageChunk,
 | 
						|
    BaseMessage,
 | 
						|
    BaseMessageChunk,
 | 
						|
    ChatMessage,
 | 
						|
    ChatMessageChunk,
 | 
						|
    FunctionMessage,
 | 
						|
    FunctionMessageChunk,
 | 
						|
    HumanMessage,
 | 
						|
    HumanMessageChunk,
 | 
						|
    SystemMessage,
 | 
						|
    SystemMessageChunk,
 | 
						|
)
 | 
						|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 | 
						|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 | 
						|
from langchain_core.utils import convert_to_secret_str
 | 
						|
from langchain_core.utils.env import get_from_dict_or_env
 | 
						|
 | 
						|
from langchain_community.adapters.openai import convert_message_to_dict
 | 
						|
 | 
						|
 | 
						|
def _convert_delta_to_message_chunk(
 | 
						|
    _dict: Any, default_class: Type[BaseMessageChunk]
 | 
						|
) -> BaseMessageChunk:
 | 
						|
    """Convert a delta response to a message chunk."""
 | 
						|
    role = _dict.role
 | 
						|
    content = _dict.content or ""
 | 
						|
    additional_kwargs: Dict = {}
 | 
						|
 | 
						|
    if role == "user" or default_class == HumanMessageChunk:
 | 
						|
        return HumanMessageChunk(content=content)
 | 
						|
    elif role == "assistant" or default_class == AIMessageChunk:
 | 
						|
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
 | 
						|
    elif role == "system" or default_class == SystemMessageChunk:
 | 
						|
        return SystemMessageChunk(content=content)
 | 
						|
    elif role == "function" or default_class == FunctionMessageChunk:
 | 
						|
        return FunctionMessageChunk(content=content, name=_dict.name)
 | 
						|
    elif role or default_class == ChatMessageChunk:
 | 
						|
        return ChatMessageChunk(content=content, role=role)
 | 
						|
    else:
 | 
						|
        return default_class(content=content)
 | 
						|
 | 
						|
 | 
						|
def convert_dict_to_message(_dict: Any) -> BaseMessage:
 | 
						|
    """Convert a dict response to a message."""
 | 
						|
    role = _dict.role
 | 
						|
    content = _dict.content or ""
 | 
						|
    if role == "user":
 | 
						|
        return HumanMessage(content=content)
 | 
						|
    elif role == "assistant":
 | 
						|
        content = _dict.content
 | 
						|
        additional_kwargs: Dict = {}
 | 
						|
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
 | 
						|
    elif role == "system":
 | 
						|
        return SystemMessage(content=content)
 | 
						|
    elif role == "function":
 | 
						|
        return FunctionMessage(content=content, name=_dict.name)
 | 
						|
    else:
 | 
						|
        return ChatMessage(content=content, role=role)
 | 
						|
 | 
						|
 | 
						|
class ChatFireworks(BaseChatModel):
 | 
						|
    """Fireworks Chat models."""
 | 
						|
 | 
						|
    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
 | 
						|
    model_kwargs: dict = Field(
 | 
						|
        default_factory=lambda: {
 | 
						|
            "temperature": 0.7,
 | 
						|
            "max_tokens": 512,
 | 
						|
            "top_p": 1,
 | 
						|
        }.copy()
 | 
						|
    )
 | 
						|
    fireworks_api_key: Optional[SecretStr] = None
 | 
						|
    max_retries: int = 20
 | 
						|
    use_retry: bool = True
 | 
						|
 | 
						|
    @property
 | 
						|
    def lc_secrets(self) -> Dict[str, str]:
 | 
						|
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def is_lc_serializable(cls) -> bool:
 | 
						|
        return True
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def get_lc_namespace(cls) -> List[str]:
 | 
						|
        """Get the namespace of the langchain object."""
 | 
						|
        return ["langchain", "chat_models", "fireworks"]
 | 
						|
 | 
						|
    @root_validator()
 | 
						|
    def validate_environment(cls, values: Dict) -> Dict:
 | 
						|
        """Validate that api key in environment."""
 | 
						|
        try:
 | 
						|
            import fireworks.client
 | 
						|
        except ImportError as e:
 | 
						|
            raise ImportError(
 | 
						|
                "Could not import fireworks-ai python package. "
 | 
						|
                "Please install it with `pip install fireworks-ai`."
 | 
						|
            ) from e
 | 
						|
        fireworks_api_key = convert_to_secret_str(
 | 
						|
            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
 | 
						|
        )
 | 
						|
        fireworks.client.api_key = fireworks_api_key.get_secret_value()
 | 
						|
        return values
 | 
						|
 | 
						|
    @property
 | 
						|
    def _llm_type(self) -> str:
 | 
						|
        """Return type of llm."""
 | 
						|
        return "fireworks-chat"
 | 
						|
 | 
						|
    def _generate(
 | 
						|
        self,
 | 
						|
        messages: List[BaseMessage],
 | 
						|
        stop: Optional[List[str]] = None,
 | 
						|
        run_manager: Optional[CallbackManagerForLLMRun] = None,
 | 
						|
        **kwargs: Any,
 | 
						|
    ) -> ChatResult:
 | 
						|
        message_dicts = self._create_message_dicts(messages)
 | 
						|
 | 
						|
        params = {
 | 
						|
            "model": self.model,
 | 
						|
            "messages": message_dicts,
 | 
						|
            **self.model_kwargs,
 | 
						|
            **kwargs,
 | 
						|
        }
 | 
						|
        response = completion_with_retry(
 | 
						|
            self,
 | 
						|
            self.use_retry,
 | 
						|
            run_manager=run_manager,
 | 
						|
            stop=stop,
 | 
						|
            **params,
 | 
						|
        )
 | 
						|
        return self._create_chat_result(response)
 | 
						|
 | 
						|
    async def _agenerate(
 | 
						|
        self,
 | 
						|
        messages: List[BaseMessage],
 | 
						|
        stop: Optional[List[str]] = None,
 | 
						|
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 | 
						|
        **kwargs: Any,
 | 
						|
    ) -> ChatResult:
 | 
						|
        message_dicts = self._create_message_dicts(messages)
 | 
						|
        params = {
 | 
						|
            "model": self.model,
 | 
						|
            "messages": message_dicts,
 | 
						|
            **self.model_kwargs,
 | 
						|
            **kwargs,
 | 
						|
        }
 | 
						|
        response = await acompletion_with_retry(
 | 
						|
            self, self.use_retry, run_manager=run_manager, stop=stop, **params
 | 
						|
        )
 | 
						|
        return self._create_chat_result(response)
 | 
						|
 | 
						|
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
 | 
						|
        if llm_outputs[0] is None:
 | 
						|
            return {}
 | 
						|
        return llm_outputs[0]
 | 
						|
 | 
						|
    def _create_chat_result(self, response: Any) -> ChatResult:
 | 
						|
        generations = []
 | 
						|
        for res in response.choices:
 | 
						|
            message = convert_dict_to_message(res.message)
 | 
						|
            gen = ChatGeneration(
 | 
						|
                message=message,
 | 
						|
                generation_info=dict(finish_reason=res.finish_reason),
 | 
						|
            )
 | 
						|
            generations.append(gen)
 | 
						|
        llm_output = {"model": self.model}
 | 
						|
        return ChatResult(generations=generations, llm_output=llm_output)
 | 
						|
 | 
						|
    def _create_message_dicts(
 | 
						|
        self, messages: List[BaseMessage]
 | 
						|
    ) -> List[Dict[str, Any]]:
 | 
						|
        message_dicts = [convert_message_to_dict(m) for m in messages]
 | 
						|
        return message_dicts
 | 
						|
 | 
						|
    def _stream(
 | 
						|
        self,
 | 
						|
        messages: List[BaseMessage],
 | 
						|
        stop: Optional[List[str]] = None,
 | 
						|
        run_manager: Optional[CallbackManagerForLLMRun] = None,
 | 
						|
        **kwargs: Any,
 | 
						|
    ) -> Iterator[ChatGenerationChunk]:
 | 
						|
        message_dicts = self._create_message_dicts(messages)
 | 
						|
        default_chunk_class = AIMessageChunk
 | 
						|
        params = {
 | 
						|
            "model": self.model,
 | 
						|
            "messages": message_dicts,
 | 
						|
            "stream": True,
 | 
						|
            **self.model_kwargs,
 | 
						|
            **kwargs,
 | 
						|
        }
 | 
						|
        for chunk in completion_with_retry(
 | 
						|
            self, self.use_retry, run_manager=run_manager, stop=stop, **params
 | 
						|
        ):
 | 
						|
            choice = chunk.choices[0]
 | 
						|
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
 | 
						|
            finish_reason = choice.finish_reason
 | 
						|
            generation_info = (
 | 
						|
                dict(finish_reason=finish_reason) if finish_reason is not None else None
 | 
						|
            )
 | 
						|
            default_chunk_class = chunk.__class__
 | 
						|
            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
 | 
						|
            yield chunk
 | 
						|
            if run_manager:
 | 
						|
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 | 
						|
 | 
						|
    async def _astream(
 | 
						|
        self,
 | 
						|
        messages: List[BaseMessage],
 | 
						|
        stop: Optional[List[str]] = None,
 | 
						|
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 | 
						|
        **kwargs: Any,
 | 
						|
    ) -> AsyncIterator[ChatGenerationChunk]:
 | 
						|
        message_dicts = self._create_message_dicts(messages)
 | 
						|
        default_chunk_class = AIMessageChunk
 | 
						|
        params = {
 | 
						|
            "model": self.model,
 | 
						|
            "messages": message_dicts,
 | 
						|
            "stream": True,
 | 
						|
            **self.model_kwargs,
 | 
						|
            **kwargs,
 | 
						|
        }
 | 
						|
        async for chunk in await acompletion_with_retry_streaming(
 | 
						|
            self, self.use_retry, run_manager=run_manager, stop=stop, **params
 | 
						|
        ):
 | 
						|
            choice = chunk.choices[0]
 | 
						|
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
 | 
						|
            finish_reason = choice.finish_reason
 | 
						|
            generation_info = (
 | 
						|
                dict(finish_reason=finish_reason) if finish_reason is not None else None
 | 
						|
            )
 | 
						|
            default_chunk_class = chunk.__class__
 | 
						|
            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
 | 
						|
            yield chunk
 | 
						|
            if run_manager:
 | 
						|
                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
 | 
						|
 | 
						|
 | 
						|
def conditional_decorator(
 | 
						|
    condition: bool, decorator: Callable[[Any], Any]
 | 
						|
) -> Callable[[Any], Any]:
 | 
						|
    """Define conditional decorator.
 | 
						|
 | 
						|
    Args:
 | 
						|
        condition: The condition.
 | 
						|
        decorator: The decorator.
 | 
						|
 | 
						|
    Returns:
 | 
						|
        The decorated function.
 | 
						|
    """
 | 
						|
 | 
						|
    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
 | 
						|
        if condition:
 | 
						|
            return decorator(func)
 | 
						|
        return func
 | 
						|
 | 
						|
    return actual_decorator
 | 
						|
 | 
						|
 | 
						|
def completion_with_retry(
 | 
						|
    llm: ChatFireworks,
 | 
						|
    use_retry: bool,
 | 
						|
    *,
 | 
						|
    run_manager: Optional[CallbackManagerForLLMRun] = None,
 | 
						|
    **kwargs: Any,
 | 
						|
) -> Any:
 | 
						|
    """Use tenacity to retry the completion call."""
 | 
						|
    import fireworks.client
 | 
						|
 | 
						|
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 | 
						|
 | 
						|
    @conditional_decorator(use_retry, retry_decorator)
 | 
						|
    def _completion_with_retry(**kwargs: Any) -> Any:
 | 
						|
        """Use tenacity to retry the completion call."""
 | 
						|
        return fireworks.client.ChatCompletion.create(
 | 
						|
            **kwargs,
 | 
						|
        )
 | 
						|
 | 
						|
    return _completion_with_retry(**kwargs)
 | 
						|
 | 
						|
 | 
						|
async def acompletion_with_retry(
 | 
						|
    llm: ChatFireworks,
 | 
						|
    use_retry: bool,
 | 
						|
    *,
 | 
						|
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 | 
						|
    **kwargs: Any,
 | 
						|
) -> Any:
 | 
						|
    """Use tenacity to retry the async completion call."""
 | 
						|
    import fireworks.client
 | 
						|
 | 
						|
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 | 
						|
 | 
						|
    @conditional_decorator(use_retry, retry_decorator)
 | 
						|
    async def _completion_with_retry(**kwargs: Any) -> Any:
 | 
						|
        return await fireworks.client.ChatCompletion.acreate(
 | 
						|
            **kwargs,
 | 
						|
        )
 | 
						|
 | 
						|
    return await _completion_with_retry(**kwargs)
 | 
						|
 | 
						|
 | 
						|
async def acompletion_with_retry_streaming(
 | 
						|
    llm: ChatFireworks,
 | 
						|
    use_retry: bool,
 | 
						|
    *,
 | 
						|
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 | 
						|
    **kwargs: Any,
 | 
						|
) -> Any:
 | 
						|
    """Use tenacity to retry the completion call for streaming."""
 | 
						|
    import fireworks.client
 | 
						|
 | 
						|
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 | 
						|
 | 
						|
    @conditional_decorator(use_retry, retry_decorator)
 | 
						|
    async def _completion_with_retry(**kwargs: Any) -> Any:
 | 
						|
        return fireworks.client.ChatCompletion.acreate(
 | 
						|
            **kwargs,
 | 
						|
        )
 | 
						|
 | 
						|
    return await _completion_with_retry(**kwargs)
 | 
						|
 | 
						|
 | 
						|
def _create_retry_decorator(
 | 
						|
    llm: ChatFireworks,
 | 
						|
    run_manager: Optional[
 | 
						|
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
 | 
						|
    ] = None,
 | 
						|
) -> Callable[[Any], Any]:
 | 
						|
    """Define retry mechanism."""
 | 
						|
    import fireworks.client
 | 
						|
 | 
						|
    errors = [
 | 
						|
        fireworks.client.error.RateLimitError,
 | 
						|
        fireworks.client.error.InternalServerError,
 | 
						|
        fireworks.client.error.BadGatewayError,
 | 
						|
        fireworks.client.error.ServiceUnavailableError,
 | 
						|
    ]
 | 
						|
    return create_base_retry_decorator(
 | 
						|
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
 | 
						|
    )
 |