"""
title: DigDash Chatbot Function
author: DigDash
version: 0.3.0
required_open_webui_version: 0.6.5
requirements: fastmcp, aiohttp
"""

import aiohttp
import ast
import asyncio
import json
from fastapi import Request
from fastmcp import Client
from fastmcp.client.transports import StreamableHttpTransport
from open_webui.models.users import Users
from open_webui.utils.chat import generate_chat_completion
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Callable
from urllib.parse import urlparse


async def call_data_viz_generator_tool_async(
        api_key: str, url: str, user_message: str
) -> dict:
    """
    Invoke the DataVizGeneratorJson MCP tool over Streamable HTTP.

    Args:
        api_key: Authentication key, sent in the ``X-API-KEY`` header.
        url: The MCP endpoint URL, e.g. https://yourserver.com/mcp
        user_message: Value forwarded as the tool's ``arg0`` argument.

    Returns:
        The tool output parsed as JSON, or a dict with error details when
        the call or the JSON parsing fails.
    """
    transport = StreamableHttpTransport(url=url, headers={"X-API-KEY": api_key})
    async with Client(transport) as client:
        try:
            chunks = await client.call_tool(
                "DataVizGeneratorJson", {"arg0": user_message}
            )
        except Exception as e:
            return {"error": "Tool call failed", "details": str(e)}

        # Flatten the streamed chunks into one string; chunks may carry
        # plain text, ready-made JSON, or something else entirely.
        pieces = []
        for part in chunks:
            text = getattr(part, "text", None)
            if text is not None:
                pieces.append(text)
                continue
            payload = getattr(part, "json", None)
            if payload is not None:
                pieces.append(json.dumps(payload))
            else:
                # Fallback: best-effort string conversion.
                pieces.append(str(part))

        raw = "".join(pieces)

        try:
            return json.loads(raw)
        except json.JSONDecodeError as e:
            return {
                "error": "Unable to parse result as JSON",
                "partial_text": raw,
                "details": str(e),
            }


def clean_tool_output(text: str) -> str:
    """
    Unwrap a tool output that arrives as a repr'd Python literal.

    Some tool results come back quoted like ``"'{...}'"``;
    ``ast.literal_eval`` safely unwraps that without executing code.
    Text that is not a valid Python literal is returned unchanged
    (the original implementation raised in that case).
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Plain text, not a repr'd literal: pass through as-is.
        return text


def get_user_intent_prompt():
    """Return the system prompt that distils a dataviz intent from USER messages.

    The caller appends the formatted conversation after the trailing
    ``### Input:`` marker. The prompt instructs the model to answer with
    the literal string ``Not related to dataviz`` for off-topic requests.
    """
    # Plain string: there are no placeholders, so the former f-prefix was
    # redundant (and risky should braces ever be added to the text).
    return """
**Task:** Extract the final intent from a list of USER messages, keeping only those related to **data visualization**.

---

### ❌ Not Related to Dataviz

Ignore factual or general knowledge questions (e.g., weather, translations, definitions, simple math).

**Examples:**
- "Pourquoi le ciel est-il bleu ?"
- "Quelle est la capitale de l’Italie ?"
- "Combien font 2 + 2 ?"

### ✅ Related to Dataviz

Keep messages about trends, comparisons, distributions, or metrics that benefit from being visualized.

**Examples:**
- "Quelle est la répartition des âges de nos utilisateurs ?"
- "Compare les performances des campagnes marketing."
- "Quel est le taux de conversion par canal ?"

### **Instructions:**

1. Analyze the last USER message.
2. If **not suitable for dataviz**, respond **exactly** with:  
   `Not related to dataviz`
3. If **suitable**, extract the **final dataviz intent** in the user's language.

#### Final Intent Rules:
- Use only the last complete dataviz-related query and its refinements.
- Refinements include region, chart type, filters, etc.
- Ignore unrelated or off-topic messages entirely.
- A new full dataviz question **overrides** all prior ones.

### **Example:**

USER: Quelles sont les cout des magasins au canada?  
USER: en france  
USER: et les type de clients  
USER: en carte  
→ **Final Intent:** `Les coûts des magasins et types de clients en France, en carte`

USER: Quelles sont les 5 produits les plus rentables?  
USER: en barre  
→ **Final Intent:** `Les 5 produits les plus rentables en barre`

USER: Pourquoi les chats chassent des souris ?  
→ **Output:** `Not related to dataviz`

### Input:
"""


async def extract_user_intent(
        __user__: dict, __request__: Request, messages, modelId: str
) -> str:
    """
    Distil the final dataviz intent from the conversation's USER messages.

    Builds the intent-extraction prompt, appends every user message, and
    asks the LLM for either the refined intent or the literal marker
    "Not related to dataviz".

    Raises:
        ValueError: If the history contains no user messages.
    """
    history = [m["content"] for m in messages if m.get("role") == "user"]
    if not history:
        raise ValueError("No user messages found in the conversation history.")

    transcript = "\n".join(f"USER: {content}" for content in history)
    prompt = get_user_intent_prompt() + transcript

    print(f"extract user intent prompt: {prompt}")

    return await call_llm(
        __user__,
        __request__,
        [{"role": "user", "content": prompt}],
        modelId,
    )


def prepare_messages(
        messages: List[Dict[str, str]], prompt: Optional[str] = None
) -> List[Dict[str, str]]:
    """Return a shallow copy of `messages`, prefixed with a system prompt when given."""
    if prompt:
        return [{"role": "system", "content": prompt}, *messages]
    return list(messages)


def get_user(__user__: dict):
    """Look up the Open WebUI user record from the id in `__user__`."""
    user_id = __user__["id"]
    return Users.get_user_by_id(user_id)


async def call_llm(
        __user__: dict,
        __request__: Request,
        messages: List[Dict[str, str]],
        modelId: str,
        prompt: Optional[str] = None,
) -> Optional[str]:
    """
    Run a non-streaming chat completion and return the assistant's text.

    Args:
        __user__: Open WebUI user dict (used to resolve the user record).
        __request__: Incoming FastAPI request, forwarded to the completion call.
        messages: Chat history to send.
        modelId: Identifier of the model to query.
        prompt: Optional system prompt prepended to `messages`.

    Raises:
        ValueError: If `modelId` is empty.
    """
    if not modelId:
        raise ValueError("Model not specified in the body.")

    user = get_user(__user__)
    payload = {
        "model": modelId,
        "messages": prepare_messages(messages, prompt),
    }
    response = await generate_chat_completion(__request__, payload, user)

    # Defensive extraction: tolerate missing choices/message/content keys.
    first_choice = response.get("choices", [{}])[0]
    content = first_choice.get("message", {}).get("content", "").strip()
    print(f"LLM response: {content}")
    return content


async def call_llm_streaming(
        __user__: dict,
        __request__: Request,
        messages: List[Dict[str, str]],
        modelId: str,
        __event_emitter__,
        prompt: Optional[str] = None,
) -> None:
    """
    Run a streaming chat completion and relay each content delta to the UI.

    Re-posts the chat request with ``stream: True`` and parses the SSE
    response line by line, forwarding every ``delta.content`` chunk through
    ``append_message_in_chat``.

    Raises:
        ValueError: If `modelId` is empty or the Authorization header is missing.
        aiohttp.ClientResponseError: If the completion endpoint replies with
            an HTTP error status.
    """
    if not modelId:
        raise ValueError("Model not specified in the body.")

    full_messages = prepare_messages(messages, prompt)

    api_key = __request__.headers.get("authorization")
    if not api_key:
        raise ValueError("Missing Authorization header")

    # NOTE(review): this re-posts to the *current* request's URL, which
    # assumes the pipe is invoked on the chat-completions endpoint — confirm.
    url = str(__request__.url)
    headers = {"Authorization": api_key, "Content-Type": "application/json"}
    body = {"model": modelId, "messages": full_messages, "stream": True}

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=body) as resp:
            # Fail loudly on HTTP errors instead of silently streaming nothing.
            resp.raise_for_status()
            async for raw_line in resp.content:
                line = raw_line.decode("utf-8").strip()
                # Only SSE "data:" lines carry payloads; skip keep-alives etc.
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                try:
                    event_json = json.loads(data)
                    delta = event_json["choices"][0]["delta"]
                    if "content" in delta:
                        await append_message_in_chat(
                            __event_emitter__, delta["content"]
                        )
                except (json.JSONDecodeError, KeyError, IndexError) as e:
                    # Skip malformed chunks but keep consuming the stream.
                    print(f"\n[Error parsing stream chunk: {e}]")


async def append_message_in_chat(__event_emitter__, message):
    """Emit a delta event so the UI appends `message` to the current reply."""
    event = {
        "type": "chat:message:delta",
        "data": {"content": message},
    }
    await __event_emitter__(event)


async def emit_status_to_user(__event_emitter__, message):
    """Emit a full-message event that replaces the chat content with `message`."""
    event = {
        "type": "chat:message",
        "data": {"content": message},
    }
    await __event_emitter__(event)


class Pipe:
    """Open WebUI pipe that routes dataviz requests to the DigDash MCP tool.

    Non-dataviz messages are answered by a plain streaming LLM reply;
    dataviz intents are forwarded to the DataVizGeneratorJson tool and
    the generated visualization (summary + html) is returned.
    """

    class Valves(BaseModel):
        # Admin-level configuration.
        DIGDASH_MCP_TOOL_URL: str = Field(
            default="http://dev01-dev.lan.digdash.com:8086/sse",
            description="DigDash MCP URL for accessing Digdash API endpoints.",
        )
        MODEL_ID: str = Field(
            default="Meta-Llama-3_3-70B-Instruct",
            # Fixed: previous description ("URL of this server.") was a copy-paste error.
            description="Model used for intent extraction and fallback replies.",
        )

    class UserValves(BaseModel):
        # Per-user configuration.
        DIGDASH_API_KEY: str = Field(
            default="test",
            description="API key for authenticating requests to the Digdash API.",
        )

    def __init__(self):
        self.valves = self.Valves()

    async def pipe(
            self,
            body: dict,
            __user__: dict,
            __request__: Request,
            __event_emitter__,
            __event_call__,
            __metadata__,
    ):
        """
        Entry point invoked by Open WebUI for each chat turn.

        Returns either an error string, a plain LLM answer (for tasks and
        non-dataviz messages), or the DigDash visualization response.
        """
        DIGDASH_MCP_TOOL_URL = self.valves.DIGDASH_MCP_TOOL_URL
        digdashApiKey = __user__["valves"].DIGDASH_API_KEY

        if not DIGDASH_MCP_TOOL_URL:
            return "Error: The DigDash MCP Tool URL was not provided in the Valve"
        if not digdashApiKey:
            return "Error: The DigDash API key was not provided in the User Valve"

        modelId = self.valves.MODEL_ID
        if not modelId:
            return "Error: The DigDash function ModelId was not provided in the Valve"

        messages = body.get("messages", [])

        # Title, tags and follow-up generation arrive as background "tasks":
        # answer them with a plain (non-streaming) completion.
        if __metadata__.get("task"):
            return await call_llm(__user__, __request__, messages, modelId)

        await emit_status_to_user(
            __event_emitter__, "Analyse de la requête en cours..."
        )

        user_intent = await extract_user_intent(
            __user__,
            __request__,
            messages,
            modelId,
        )

        if user_intent == "Not related to dataviz":
            # Clear the status line, then answer as a normal chat reply.
            await emit_status_to_user(__event_emitter__, "")
            return await call_llm_streaming(
                __user__,
                __request__,
                messages,
                modelId,
                __event_emitter__,
                "Respond to the user's latest message",
            )

        await emit_status_to_user(
            __event_emitter__,
            f"Traitement de la requête '{user_intent}' par Digdash...",
        )

        result = await call_data_viz_generator_tool_async(
            digdashApiKey, DIGDASH_MCP_TOOL_URL, user_intent
        )

        print(f"result: {result}")

        # Surface transport/parsing failures from the MCP call instead of
        # silently falling through to the generic "missing summary" message.
        if "error" in result:
            return f"Error: {result['error']} - {result.get('details', '')}"

        response = result.get("summary", "Error: can not extract error explanation")

        if not result.get("isError"):
            response += f"\n{result.get('html', '')}"

        if "alternatives" in result:
            response += f"\n{result.get('alternatives', '')}"

        return response
