import itertools
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any, TextIO

import gradio as gr  # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate  # type: ignore
from llama_index.llms import ChatMessage, ChatResponse, MessageRole

from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService
from private_gpt.server.chunks.chunks_service import ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg

ingest_service = root_injector.get(IngestService)
chat_service = root_injector.get(ChatService)
chunks_service = root_injector.get(ChunksService)


def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
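    """Stream a reply to `message`, dispatching on the UI-selected `mode`."""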
    def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
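        # Gradio re-renders the chat message on every yield, so emit the
        # accumulated response so far rather than the raw delta.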
        full_response: str = ""
        for delta in stream:
            if isinstance(delta, str):
                full_response += delta
            elif isinstance(delta, ChatResponse):
                full_response += delta.delta or ""
            yield full_response

    def build_history() -> list[ChatMessage]:
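        # Flatten Gradio's [user, assistant] pairs into a flat list of ChatMessages.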
        history_messages: list[ChatMessage] = list(
            itertools.chain.from_iterable(
                [
                    ChatMessage(content=interaction[0], role=MessageRole.USER),
                    ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT),
                ]
                for interaction in history
            )
        )

        # Keep only the most recent 20 messages to try to avoid context overflow
        return history_messages[-20:]

    new_message = ChatMessage(content=message, role=MessageRole.USER)
    all_messages = [*build_history(), new_message]
    match mode:
        case "Query Documents":
            query_stream = chat_service.stream_chat(
                messages=all_messages,
                use_context=True,
            )
            yield from yield_deltas(query_stream)

        case "LLM Chat":
            llm_stream = chat_service.stream_chat(
                messages=all_messages,
                use_context=False,
            )
            yield from yield_deltas(llm_stream)

        case "Context Chunks":
            response = chunks_service.retrieve_relevant(
                text=message,
                limit=2,
                prev_next_chunks=1,
            ).__iter__()
            yield "```" + json.dumps(
                [node.__dict__ for node in response],
                default=lambda o: o.__dict__,
                indent=2,
            )


def _list_ingested_files() -> list[str]:
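    """Return the distinct file names of all ingested documents."""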
    files = set()
    for ingested_document in ingest_service.list_ingested():
        if ingested_document.doc_metadata is not None:
            files.add(
                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
            )
    return list(files)


# Global state
_uploaded_file_list = [[row] for row in _list_ingested_files()]


def _upload_file(file: TextIO) -> list[list[str]]:
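    """Ingest the uploaded file and return the updated ingested-files table."""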
    path = Path(file.name)
    ingest_service.ingest(file_name=path.name, file_data=path)
    _uploaded_file_list.append([path.name])
    return _uploaded_file_list


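# Build the UI: ingestion controls in the left column, the chat in the right.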
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue=slate),
    css=".logo { "
    "display:flex;"
    "background-color: #C7BAFF;"
    "height: 80px;"
    "border-radius: 8px;"
    "align-content: center;"
    "justify-content: center;"
    "align-items: center;"
    "}"
    ".logo img { height: 25% }",
) as blocks:
    with gr.Row():
        gr.HTML(
            f"<div class='logo'><img src='{logo_svg}' alt='PrivateGPT'></div>"
        )

    with gr.Row():
        with gr.Column(scale=3, variant="compact"):
            mode = gr.Radio(
                ["Query Documents", "LLM Chat", "Context Chunks"],
                label="Mode",
                value="Query Documents",
            )
            upload_button = gr.components.UploadButton(
                "Upload a File",
                type="file",
                file_count="single",
                size="sm",
            )
            ingested_dataset = gr.List(
                _uploaded_file_list,
                headers=["File name"],
                label="Ingested Files",
                interactive=False,
                render=False,  # Rendered under the button
            )
            upload_button.upload(
                _upload_file, inputs=upload_button, outputs=ingested_dataset
            )
            ingested_dataset.render()
        with gr.Column(scale=7):
            chatbot = gr.ChatInterface(
                _chat,
                chatbot=gr.Chatbot(
                    label="Chat",
                    show_copy_button=True,
                    render=False,
                    avatar_images=(
                        None,
                        "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
                        "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
                        "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
                    ),
                ),
                additional_inputs=[mode, upload_button],
            )


def mount_in_app(app: FastAPI) -> None:
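    """Mount the Gradio UI into the FastAPI app at the configured UI path."""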
    blocks.queue()
    gr.mount_gradio_app(app, blocks, path=settings.ui.path)


if __name__ == "__main__":
    blocks.queue()
    blocks.launch(debug=False, show_api=False)