#!/usr/bin/env python3
"""
Hyprland monitor manager — WYSIWYG curses TUI.

Keys (normal mode):
  Tab / Shift+Tab   cycle selected monitor
  h j k l           move monitor (50 px)
  H J K L           move monitor (10 px fine)
  u / i             rotate CCW / CW
  t / g             scale up / scale down (valid Hyprland steps)
  m                 toggle mirror (pick target) / un-mirror
  n / N             cycle display mode forward / backward
  s                 save to hypr/usr/monitors.lua
  Enter             save & quit
  q / Esc           quit (prompts if unsaved changes)

Mirror-pick mode:
  Tab / Shift+Tab   cycle target
  Enter             confirm
  Esc               cancel
"""

import curses
import json
import math
import os
import re
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# File written by `s` (save) and Enter (save & quit).
MONITORS_LUA = Path.home() / "Dotfiles/desktopenvs/hyprlua/hypr/usr/monitors.lua"
# Layout-pixel step for h/j/k/l (coarse) and H/J/K/L (fine) moves.
MOVE_STEP       = 50
MOVE_STEP_FINE  = 10
# Inclusive bounds for the scales offered by valid_scales().
MIN_SCALE       = 0.25
MAX_SCALE       = 4.0
_SCALE_MAX_DENOM = 6   # max denominator when enumerating valid Hyprland scales
# Minimum size, in terminal cells, of a monitor's box on the canvas.
MIN_BOX_W       = 14
MIN_BOX_H       = 4
# Width of the right-hand info panel (only drawn when the terminal is >= 80 cols).
INFO_W          = 32
STATUS_ROWS     = 2   # status + help rows at the bottom

# Display labels for transform values 0-7.  Bit 2 (value 4) marks the flipped
# variants ("⇔"); the low two bits count quarter turns, which is how
# rotate_cw/rotate_ccw and MonitorState.logical_width interpret them.
TRANSFORM_LABEL = {
    0: "↕ 0°",
    1: "↻ 90°",
    2: "↕ 180°",
    3: "↺ 90°",
    4: "⇔ 0°",
    5: "⇔↻ 90°",
    6: "⇔ 180°",
    7: "⇔↺ 90°",
}

# Parses availableModes entries like "1920x1080@59.94Hz" into
# (width, height, refresh) groups.
_MODE_RE = re.compile(r"(\d+)x(\d+)@([\d.]+)Hz")

# ---------------------------------------------------------------------------
# Rotation helpers
# ---------------------------------------------------------------------------

def rotate_cw(t: int) -> int:
    """Advance transform *t* by one quarter turn (this tool's "CW" direction).

    The flip bit (value 4) is preserved; only the low two rotation bits
    change, wrapping from 3 back to 0.
    """
    flip_bit = t & 4
    quarter_turns = (t + 1) % 4
    return flip_bit | quarter_turns

def rotate_ccw(t: int) -> int:
    """Step transform *t* back by one quarter turn (this tool's "CCW" direction).

    The flip bit (value 4) is preserved; only the low two rotation bits
    change, wrapping from 0 back to 3.
    """
    flip_bit = t & 4
    quarter_turns = (t - 1) % 4
    return flip_bit | quarter_turns

# ---------------------------------------------------------------------------
# Scale helpers
# ---------------------------------------------------------------------------

def valid_scales(width: int, height: int) -> List[float]:
    """List, in ascending order, the valid Hyprland scales for width x height.

    A reduced fraction p/q leaves both logical dimensions (width*q/p and
    height*q/p) integral exactly when p divides gcd(width, height).
    Denominators are capped at _SCALE_MAX_DENOM and values are kept within
    [MIN_SCALE, MAX_SCALE] so the list stays short enough to step through.
    Each value is rounded to 10 decimal places before deduplication.
    """
    common = math.gcd(width, height)
    numerators = [p for p in range(1, common + 1) if common % p == 0]
    found = {
        round(p / q, 10)
        for p in numerators
        for q in range(1, _SCALE_MAX_DENOM + 1)
        if math.gcd(p, q) == 1 and MIN_SCALE <= p / q <= MAX_SCALE
    }
    return sorted(found)

# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------

@dataclass
class MonitorState:
    """Editable snapshot of one output as reported by ``hyprctl monitors -j``."""

    name: str
    x: int
    y: int
    width: int
    height: int
    refresh_rate: float
    transform: int
    scale: float
    mirror_of: str                 # "" when the output is not mirroring
    available_modes: List[str]     # entries parsed with _MODE_RE
    mode_index: int                # index into available_modes of the current mode
    dirty: bool = False            # changed since the last save

    @property
    def logical_width(self) -> int:
        """Width after rotation: odd quarter turns swap width and height."""
        if (self.transform & 3) in (1, 3):
            return self.height
        return self.width

    @property
    def logical_height(self) -> int:
        """Height after rotation: odd quarter turns swap width and height."""
        if (self.transform & 3) in (1, 3):
            return self.width
        return self.height

    @property
    def mode_str(self) -> str:
        """Mode as ``WxH@R`` with up to two decimals of refresh rate.

        Trailing zeros are dropped, so integral rates stay compact
        (``1920x1080@60``). Fractional ones (``1920x1080@59.94``) remain
        distinct from their integral neighbours instead of collapsing onto
        the same rounded value.
        """
        rate = f"{self.refresh_rate:.2f}".rstrip("0").rstrip(".")
        return f"{self.width}x{self.height}@{rate}"

    @classmethod
    def from_json(cls, d: dict) -> "MonitorState":
        """Build a MonitorState from one entry of ``hyprctl monitors -j``.

        Missing keys fall back to 1920x1080@60, position 0,0, transform 0 and
        scale 1.0. ``mode_index`` points at the available mode with the same
        resolution whose refresh rate is *closest* to the current one (and
        within 1 Hz). It is 0 when nothing matches.
        """
        modes = d.get("availableModes", [])
        w = d.get("width", 1920)
        h = d.get("height", 1080)
        rr = d.get("refreshRate", 60.0)
        # Choose the closest refresh rather than the first within 1 Hz, so
        # e.g. 59.94 Hz and 60.00 Hz modes of one resolution are told apart.
        mode_index = 0
        best_delta = 1.0
        for i, m in enumerate(modes):
            mo = _MODE_RE.match(m)
            if not mo or int(mo.group(1)) != w or int(mo.group(2)) != h:
                continue
            delta = abs(float(mo.group(3)) - rr)
            if delta < best_delta:
                best_delta = delta
                mode_index = i
        # "none" or "" means not mirroring; a JSON null is treated the same
        # way so mirror_of is always a str.
        mirror = d.get("mirrorOf") or "none"
        return cls(
            name=d.get("name", ""),
            x=d.get("x", 0),
            y=d.get("y", 0),
            width=w,
            height=h,
            refresh_rate=rr,
            transform=d.get("transform", 0),
            scale=d.get("scale", 1.0),
            mirror_of="" if mirror == "none" else mirror,
            available_modes=modes,
            mode_index=mode_index,
        )

# ---------------------------------------------------------------------------
# hyprctl helpers
# ---------------------------------------------------------------------------

def fetch_monitors() -> List[MonitorState]:
    """Query ``hyprctl monitors -j`` and parse every entry into a MonitorState.

    Raises subprocess.CalledProcessError when hyprctl exits non-zero.
    """
    proc = subprocess.run(
        ["hyprctl", "monitors", "-j"],
        check=True,
        text=True,
        capture_output=True,
    )
    entries = json.loads(proc.stdout)
    return list(map(MonitorState.from_json, entries))


def apply_monitor(m: MonitorState) -> Optional[str]:
    """Push *m* to the running compositor via ``hyprctl eval``.

    A mirroring output sends only its ``mirror`` target. Any other output
    sends mode, position, scale and transform together.

    Returns None on success, or a non-empty error message on failure.
    """
    def lua_str(s: str) -> str:
        # Single-quoted Lua literal with backslashes/quotes escaped, so an
        # unusual output name cannot terminate the string early.
        return "'" + s.replace("\\", "\\\\").replace("'", "\\'") + "'"

    if m.mirror_of:
        lua = f"hl.monitor({{output={lua_str(m.name)}, mirror={lua_str(m.mirror_of)}}})"
    else:
        lua = (
            f"hl.monitor({{"
            f"output={lua_str(m.name)}, "
            f"mode={lua_str(m.mode_str)}, "
            f"position='{m.x}x{m.y}', "
            f"scale={m.scale}, "
            f"transform={m.transform}"
            f"}})"
        )
    r = subprocess.run(["hyprctl", "eval", lua], capture_output=True, text=True, check=False)
    if r.returncode != 0:
        # Never return "" here: callers do `err or <success message>`, so an
        # empty string would be shown as success.
        return (r.stderr or r.stdout).strip() or f"hyprctl eval failed (exit {r.returncode})"
    return None

# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------

def save_monitors_lua(monitors: List[MonitorState], path: Path) -> None:
    """Write *monitors* as a Lua monitor config at *path*.

    Each monitor becomes one ``hl.monitor{...}`` call (mirror-only for
    mirroring outputs), followed by an ``hl.config`` block enabling
    ``xwayland.force_zero_scaling``. The text is written to a sibling
    ``.lua.tmp`` file and moved over *path* with os.replace, so *path* never
    holds a partially written file. The temp file is removed if writing or
    replacing fails, and the original error is re-raised.
    """
    lines = ["-- generated by monitor-manager -- do not edit by hand\n"]
    for m in monitors:
        if m.mirror_of:
            lines.append(
                f'hl.monitor({{\n'
                f'    output = "{m.name}",\n'
                f'    mirror = "{m.mirror_of}",\n'
                f'}})\n\n'
            )
        else:
            lines.append(
                f'hl.monitor({{\n'
                f'    output    = "{m.name}",\n'
                f'    mode      = "{m.mode_str}",\n'
                f'    position  = "{m.x}x{m.y}",\n'
                f'    scale     = {m.scale},\n'
                f'    transform = {m.transform},\n'
                f'}})\n\n'
            )
    lines.append(
        'hl.config({\n'
        '    xwayland = {\n'
        '        force_zero_scaling = true,\n'
        '    },\n'
        '})\n'
    )
    tmp = path.with_suffix(".lua.tmp")
    try:
        # Explicit UTF-8 so the output doesn't depend on the user's locale.
        tmp.write_text("".join(lines), encoding="utf-8")
        os.replace(tmp, path)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise

# ---------------------------------------------------------------------------
# Canvas math
# ---------------------------------------------------------------------------

def compute_scale(monitors: List[MonitorState], pane_cols: int, pane_rows: int) -> float:
    """Return the layout-pixel → terminal-cell factor that fits all monitors.

    ``to_screen`` maps y through an extra factor of 0.5, so the vertical fit
    must include that factor as well. Without it, the layout only ever fills
    about half the pane height. A 15 % margin is left on both axes. Returns
    1.0 when there is nothing to fit.
    """
    if not monitors:
        return 1.0
    max_x = max(m.x + m.logical_width  for m in monitors)
    max_y = max(m.y + m.logical_height for m in monitors)
    if max_x <= 0 or max_y <= 0:
        return 1.0
    sx = pane_cols / (max_x * 1.15)
    # Rows are drawn at half scale (see to_screen), hence the 0.5 here.
    sy = pane_rows / (max_y * 0.5 * 1.15)
    return min(sx, sy)


def to_screen(cx: int, cy: int, scale: float, margin_col: int = 1, margin_row: int = 1):
    """Map a layout-pixel point to a (row, col) terminal cell.

    Columns use *scale* directly. Rows use half of it, presumably to
    compensate for terminal cells being taller than they are wide. Both
    are truncated to int and then offset by their margin.
    """
    return (
        margin_row + int(cy * scale * 0.5),
        margin_col + int(cx * scale),
    )

# ---------------------------------------------------------------------------
# Safe addstr / addch wrappers
# ---------------------------------------------------------------------------

def safe_addstr(win, row, col, text, attr=0):
    """Write *text* at (row, col), clipped so it never reaches the last column.

    Out-of-window positions are ignored, and curses errors are swallowed, so
    callers can draw without checking the terminal size first.
    """
    try:
        height, width = win.getmaxyx()
        inside = 0 <= row < height and 0 <= col < width
        if not inside:
            return
        room = width - col - 1
        if room > 0:
            win.addstr(row, col, text[:room], attr)
    except curses.error:
        pass


def safe_addch(win, row, col, ch, attr=0):
    """Put a single character at (row, col) if that cell is inside *win*.

    Unlike safe_addstr the last column is allowed. Curses errors (for
    example when writing the bottom-right cell) are swallowed.
    """
    try:
        height, width = win.getmaxyx()
        if 0 <= row < height and 0 <= col < width:
            win.addch(row, col, ch, attr)
    except curses.error:
        pass

# ---------------------------------------------------------------------------
# Box drawing
# ---------------------------------------------------------------------------

def draw_box(win, row, col, h, w, attr=0):
    """Draw an h-by-w single-line box whose top-left corner is (row, col).

    Boxes smaller than 2x2 are skipped. Every cell goes through safe_addch,
    so partially off-screen boxes are clipped rather than raising.
    """
    if min(h, w) < 2:
        return
    top, bottom = row, row + h - 1
    left, right = col, col + w - 1
    corners = (
        (top, left, "┌"),
        (top, right, "┐"),
        (bottom, left, "└"),
        (bottom, right, "┘"),
    )
    for r, c, glyph in corners:
        safe_addch(win, r, c, glyph, attr)
    for c in range(left + 1, right):
        safe_addch(win, top, c, "─", attr)
        safe_addch(win, bottom, c, "─", attr)
    for r in range(top + 1, bottom):
        safe_addch(win, r, left, "│", attr)
        safe_addch(win, r, right, "│", attr)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

class App:
    """Interactive editor state plus the curses event loop.

    Every edit is applied to the running compositor immediately through
    ``apply_monitor``. ``_save`` additionally writes ``MONITORS_LUA``.
    """

    def __init__(self, stdscr):
        self.stdscr = stdscr
        self.monitors: List[MonitorState] = []
        self.selected_idx: int = 0
        self.mode: str = "normal"         # "normal" | "mirror_pick"
        self.mirror_source_idx: int = 0
        self.mirror_target_idx: int = 1
        self.dirty: bool = False          # True while there are unsaved edits
        self.status_msg: str = ""
        self._scale: float = 0.0          # cached canvas scale
        self._scale_pane: tuple = (0, 0)  # pane size used for cached scale
        self._load_monitors()
        self._init_colors()

    def _load_monitors(self):
        """(Re)query hyprctl and keep the selection index in range."""
        self.monitors = fetch_monitors()
        if self.selected_idx >= len(self.monitors):
            self.selected_idx = 0

    def _init_colors(self):
        curses.start_color()
        curses.use_default_colors()
        # 1 = selected (cyan bold)
        curses.init_pair(1, curses.COLOR_CYAN,   -1)
        # 2 = normal (white)
        curses.init_pair(2, curses.COLOR_WHITE,  -1)
        # 3 = mirror target (yellow)
        curses.init_pair(3, curses.COLOR_YELLOW, -1)
        # 4 = mirrored / dim
        curses.init_pair(4, curses.COLOR_BLACK + 8 if curses.COLORS >= 16 else curses.COLOR_WHITE, -1)
        # 5 = status bar (reversed)
        curses.init_pair(5, -1, -1)
        # 6 = help (green)
        curses.init_pair(6, curses.COLOR_GREEN,  -1)

    def _get_scale(self, pane_cols: int, pane_rows: int) -> float:
        """Return cached scale; recompute only on resize or when a monitor escapes the viewport."""
        pane = (pane_cols, pane_rows)
        if self._scale > 0 and self._scale_pane == pane:
            # Check every monitor still fits inside the current viewport
            inner_cols = pane_cols - 2
            inner_rows = pane_rows - 2
            all_fit = True
            for m in self.monitors:
                brow, bcol = to_screen(m.x + m.logical_width, m.y + m.logical_height, self._scale)
                if bcol > inner_cols or brow > inner_rows:
                    all_fit = False
                    break
            if all_fit:
                return self._scale
        self._scale = compute_scale(self.monitors, pane_cols - 2, pane_rows - 2)
        self._scale_pane = pane
        return self._scale

    # -----------------------------------------------------------------------
    # Event loop
    # -----------------------------------------------------------------------

    def run(self):
        """Poll keys every 100 ms and redraw after each handled event."""
        curses.curs_set(0)
        self.stdscr.timeout(100)
        self.draw()
        while True:
            ch = self.stdscr.getch()
            if ch == curses.KEY_RESIZE:
                self._scale_pane = (0, 0)  # force scale recompute on resize
                self.draw()
                continue
            if ch == -1:
                continue
            if self.mode == "normal":
                result = self.handle_key_normal(ch)
                if result == "quit":
                    break
            elif self.mode == "mirror_pick":
                self.handle_key_mirror(ch)
            self.draw()

    # -----------------------------------------------------------------------
    # Key handlers
    # -----------------------------------------------------------------------

    def handle_key_normal(self, ch) -> Optional[str]:
        """Handle one key in normal mode; return "quit" to leave the loop."""
        mon = self.monitors[self.selected_idx] if self.monitors else None

        # Tab / Shift+Tab — cycle monitor
        if ch == ord("\t"):
            self.selected_idx = (self.selected_idx + 1) % max(1, len(self.monitors))
            return None
        if ch == curses.KEY_BTAB:
            self.selected_idx = (self.selected_idx - 1) % max(1, len(self.monitors))
            return None

        if mon is None:
            return None

        # Movement — coarse
        if ch == ord("h"):
            self.move_monitor(-MOVE_STEP, 0)
        elif ch == ord("l"):
            self.move_monitor(MOVE_STEP, 0)
        elif ch == ord("k"):
            self.move_monitor(0, -MOVE_STEP)
        elif ch == ord("j"):
            self.move_monitor(0, MOVE_STEP)
        # Movement — fine
        elif ch == ord("H"):
            self.move_monitor(-MOVE_STEP_FINE, 0)
        elif ch == ord("L"):
            self.move_monitor(MOVE_STEP_FINE, 0)
        elif ch == ord("K"):
            self.move_monitor(0, -MOVE_STEP_FINE)
        elif ch == ord("J"):
            self.move_monitor(0, MOVE_STEP_FINE)
        # Rotation
        elif ch == ord("u"):
            self.rotate_monitor(-1)
        elif ch == ord("i"):
            self.rotate_monitor(+1)
        # Scale
        elif ch == ord("t"):
            self.scale_monitor(+1)
        elif ch == ord("g"):
            self.scale_monitor(-1)
        # Mirror
        elif ch == ord("m"):
            if mon.mirror_of:
                mon.mirror_of = ""
                err = apply_monitor(mon)
                mon.dirty = True
                self.dirty = True
                self.status_msg = err or f"Un-mirrored {mon.name}"
            elif len(self.monitors) < 2:
                self.status_msg = "Need 2+ monitors to mirror"
            else:
                self.mirror_source_idx = self.selected_idx
                self.mirror_target_idx = (self.selected_idx + 1) % len(self.monitors)
                self.mode = "mirror_pick"
        # Mode cycling
        elif ch == ord("n"):
            self.cycle_mode(+1)
        elif ch == ord("N"):
            self.cycle_mode(-1)
        # Save
        elif ch == ord("s"):
            self._save()
        # Save & quit
        elif ch in (curses.KEY_ENTER, ord("\n"), ord("\r")):
            # Only leave when the save succeeded; otherwise stay so the error
            # in the status line is visible and the edits are not lost.
            if self._save():
                return "quit"
        # Quit
        elif ch in (ord("q"), 27):  # q or Esc
            if self.dirty:
                action = self.prompt_save_quit()
                if action == "cancel":
                    return None
                if action == "save" and not self._save():
                    return None  # save failed: keep running, error shown
            return "quit"
        return None

    def handle_key_mirror(self, ch):
        """Handle one key while choosing which monitor to mirror."""
        n = len(self.monitors)

        def next_target(delta: int):
            t = (self.mirror_target_idx + delta) % n
            # skip source
            if t == self.mirror_source_idx:
                t = (t + delta) % n
            self.mirror_target_idx = t

        if ch == ord("\t"):
            next_target(+1)
        elif ch == curses.KEY_BTAB:
            next_target(-1)
        elif ch in (curses.KEY_ENTER, ord("\n"), ord("\r")):
            self.set_mirror(self.mirror_source_idx, self.mirror_target_idx)
            self.mode = "normal"
        elif ch == 27:  # Esc
            self.mode = "normal"
            self.status_msg = "Mirror cancelled"

    # -----------------------------------------------------------------------
    # Actions
    # -----------------------------------------------------------------------

    def move_monitor(self, dx: int, dy: int):
        mon = self.monitors[self.selected_idx]
        mon.x = max(0, mon.x + dx)
        mon.y = max(0, mon.y + dy)
        err = apply_monitor(mon)
        mon.dirty = True
        self.dirty = True
        self.status_msg = err or f"Moved {mon.name} to {mon.x},{mon.y}"

    def rotate_monitor(self, direction: int):
        mon = self.monitors[self.selected_idx]
        if direction > 0:
            mon.transform = rotate_cw(mon.transform)
        else:
            mon.transform = rotate_ccw(mon.transform)
        err = apply_monitor(mon)
        mon.dirty = True
        self.dirty = True
        self.status_msg = err or f"Rotated {mon.name} → {TRANSFORM_LABEL[mon.transform]}"

    def scale_monitor(self, direction: int):
        """Step to the next valid scale above (direction > 0) or below."""
        mon = self.monitors[self.selected_idx]
        scales = valid_scales(mon.width, mon.height)
        cur = round(mon.scale, 10)
        if direction > 0:
            candidates = [s for s in scales if s > cur + 1e-9]
            if not candidates:
                self.status_msg = f"Scale already at max ({mon.scale}x)"
                return
            new_scale = candidates[0]
        else:
            candidates = [s for s in scales if s < cur - 1e-9]
            if not candidates:
                self.status_msg = f"Scale already at min ({mon.scale}x)"
                return
            new_scale = candidates[-1]
        mon.scale = new_scale
        err = apply_monitor(mon)
        mon.dirty = True
        self.dirty = True
        self.status_msg = err or f"{mon.name} scale → {new_scale}x"

    def cycle_mode(self, delta: int):
        mon = self.monitors[self.selected_idx]
        if not mon.available_modes:
            self.status_msg = "No mode list available"
            return
        mon.mode_index = (mon.mode_index + delta) % len(mon.available_modes)
        mo = _MODE_RE.match(mon.available_modes[mon.mode_index])
        if mo:
            mon.width  = int(mo.group(1))
            mon.height = int(mo.group(2))
            mon.refresh_rate = float(mo.group(3))
        err = apply_monitor(mon)
        mon.dirty = True
        self.dirty = True
        self.status_msg = err or f"{mon.name} → {mon.mode_str}"

    def set_mirror(self, src_idx: int, tgt_idx: int):
        src = self.monitors[src_idx]
        tgt = self.monitors[tgt_idx]
        src.mirror_of = tgt.name
        err = apply_monitor(src)
        src.dirty = True
        self.dirty = True
        self.status_msg = err or f"Mirroring {src.name} → {tgt.name}"

    def _save(self) -> bool:
        """Write the layout to MONITORS_LUA, then ask Hyprland to reload.

        Returns True when the file was written. The reload is best-effort:
        failing to launch hyprctl is reported in the status line but does
        not make the save count as failed, because the file is already on
        disk and the dirty flags are already cleared.
        """
        try:
            save_monitors_lua(self.monitors, MONITORS_LUA)
        except Exception as e:  # report any write failure in the status line
            self.status_msg = f"Save failed: {e}"
            return False
        for m in self.monitors:
            m.dirty = False
        self.dirty = False
        self.status_msg = f"Saved to {MONITORS_LUA.name}"
        try:
            subprocess.run(["hyprctl", "reload"], capture_output=True, check=False)
        except OSError as e:
            self.status_msg += f" (reload failed: {e})"
        return True

    # -----------------------------------------------------------------------
    # Prompt
    # -----------------------------------------------------------------------

    def prompt_save_quit(self) -> str:
        """Block until the user answers; return "save", "nosave" or "cancel"."""
        rows, cols = self.stdscr.getmaxyx()
        prompt = "Unsaved changes.  [s]ave & quit   [n]o save   [c]ancel"
        safe_addstr(self.stdscr, rows - 1, 0, " " * (cols - 1), curses.A_REVERSE)
        safe_addstr(self.stdscr, rows - 1, 0, prompt[:cols - 1], curses.A_REVERSE)
        self.stdscr.timeout(-1)  # wait indefinitely for the answer
        try:
            while True:
                ch = self.stdscr.getch()
                if ch in (ord("s"), ord("S")):
                    return "save"
                if ch in (ord("n"), ord("N")):
                    return "nosave"
                if ch in (ord("c"), ord("C"), 27):
                    return "cancel"
        finally:
            self.stdscr.timeout(100)  # restore the event loop's poll interval

    # -----------------------------------------------------------------------
    # Drawing
    # -----------------------------------------------------------------------

    def draw(self):
        """Redraw canvas, optional info panel, status line and help line."""
        self.stdscr.erase()
        rows, cols = self.stdscr.getmaxyx()

        use_info = cols >= 80
        canvas_w = (cols - INFO_W - 1) if use_info else cols
        canvas_h = max(4, rows - STATUS_ROWS)

        self.draw_canvas(canvas_w, canvas_h)
        if use_info:
            self.draw_info(canvas_w, canvas_h)
            # vertical separator
            for r in range(canvas_h):
                safe_addch(self.stdscr, r, canvas_w, "│", curses.color_pair(2))

        self.draw_status(rows - 2, cols)
        self.draw_help(rows - 1, cols)
        self.stdscr.noutrefresh()
        curses.doupdate()

    def draw_canvas(self, pane_cols: int, pane_rows: int):
        """Draw one labelled box per monitor at its scaled layout position."""
        scale = self._get_scale(pane_cols, pane_rows)

        for idx, mon in enumerate(self.monitors):
            brow, bcol = to_screen(mon.x, mon.y, scale)
            bw = max(MIN_BOX_W, int(mon.logical_width  * scale))
            bh = max(MIN_BOX_H, int(mon.logical_height * scale * 0.5))

            # Clamp to pane
            if brow >= pane_rows or bcol >= pane_cols:
                continue

            # Choose color
            if self.mode == "mirror_pick" and idx == self.mirror_target_idx:
                attr = curses.color_pair(3) | curses.A_BOLD
            elif idx == self.selected_idx:
                attr = curses.color_pair(1) | curses.A_BOLD
            elif mon.mirror_of:
                attr = curses.color_pair(4)
            else:
                attr = curses.color_pair(2)

            draw_box(self.stdscr, brow, bcol, bh, bw, attr)

            # Labels inside box
            inner_w = bw - 2
            if inner_w < 1:
                continue

            def label_line(row_offset: int, text: str):
                r = brow + row_offset
                c = bcol + 1
                if r >= pane_rows or r <= brow or r >= brow + bh - 1:
                    return
                safe_addstr(self.stdscr, r, c, text[:inner_w], attr)

            # Line 1: name (+ mirror target) + dirty marker; mirrored
            # monitors keep the marker too, so unsaved edits stay visible.
            dirty_mark = " *" if mon.dirty else ""
            if mon.mirror_of:
                name_label = f"{mon.name} →{mon.mirror_of}{dirty_mark}"
            else:
                name_label = mon.name + dirty_mark
            label_line(1, name_label)

            # Line 2: mode
            if bh > 3:
                label_line(2, mon.mode_str)

            # Line 3: rotation
            if bh > 4:
                label_line(3, TRANSFORM_LABEL.get(mon.transform, ""))

    def draw_info(self, start_col: int, pane_rows: int):
        """List each monitor's mode, position, scale and rotation on the right."""
        col = start_col + 1
        row = 0
        safe_addstr(self.stdscr, row, col, "Monitors:", curses.color_pair(6) | curses.A_BOLD)
        row += 1
        for idx, mon in enumerate(self.monitors):
            if row >= pane_rows:
                break
            prefix = "> " if idx == self.selected_idx else "  "
            attr = curses.color_pair(1) | curses.A_BOLD if idx == self.selected_idx else curses.color_pair(2)
            safe_addstr(self.stdscr, row, col, f"{prefix}{mon.name}", attr)
            row += 1
            if row < pane_rows:
                safe_addstr(self.stdscr, row, col + 2, mon.mode_str, curses.color_pair(2))
                row += 1
            if row < pane_rows:
                safe_addstr(self.stdscr, row, col + 2, f"pos {mon.x},{mon.y}", curses.color_pair(2))
                row += 1
            if row < pane_rows:
                safe_addstr(self.stdscr, row, col + 2, f"scale {mon.scale}", curses.color_pair(2))
                row += 1
            if row < pane_rows:
                rot_str = TRANSFORM_LABEL.get(mon.transform, "")
                mirror_str = f"  →{mon.mirror_of}" if mon.mirror_of else ""
                safe_addstr(self.stdscr, row, col + 2, rot_str + mirror_str, curses.color_pair(2))
                row += 1
            row += 1  # blank between monitors

    def draw_status(self, row: int, cols: int):
        """Draw the reverse-video status line: mode tag plus last message."""
        if self.mode == "mirror_pick":
            src = self.monitors[self.mirror_source_idx].name if self.monitors else "?"
            tgt = self.monitors[self.mirror_target_idx].name if self.monitors else "?"
            mode_tag = f"[MIRROR] {src} → {tgt}"
        else:
            mode_tag = "[NORMAL]"
            if self.dirty:
                mode_tag += " *"

        msg = f"{mode_tag}  {self.status_msg}"
        safe_addstr(self.stdscr, row, 0, " " * (cols - 1), curses.A_REVERSE)
        safe_addstr(self.stdscr, row, 0, msg[:cols - 1], curses.A_REVERSE)

    def draw_help(self, row: int, cols: int):
        """Draw the key-binding hint line for the current mode."""
        if self.mode == "mirror_pick":
            text = "Tab:cycle-target  Enter:confirm  Esc:cancel"
        else:
            text = "Tab:next  hjkl:move  HJKL:fine  u/i:rot  t/g:scale  m:mirror  n/N:mode  s:save  Enter:save+quit  q:quit"
        safe_addstr(self.stdscr, row, 0, text[:cols - 1], curses.color_pair(6))

# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """Run the TUI under curses.wrapper and report hyprctl problems cleanly.

    Exits with status 1 when hyprctl fails, is missing, or returns output
    that is not valid JSON. Ctrl-C exits quietly.
    """
    # ncurses waits ESCDELAY ms (1000 by default) after Esc to see whether it
    # starts an escape sequence. Shorten it so Esc (quit / cancel) responds
    # promptly. It must be set before curses initialises; an explicit user
    # setting is respected.
    os.environ.setdefault("ESCDELAY", "25")
    try:
        curses.wrapper(lambda stdscr: App(stdscr).run())
    except subprocess.CalledProcessError as e:
        print(f"Error: hyprctl failed — is Hyprland running?\n{e}", file=sys.stderr)
        sys.exit(1)
    except FileNotFoundError as e:
        print(f"Error: hyprctl not found — is Hyprland installed?\n{e}", file=sys.stderr)
        sys.exit(1)
    except json.JSONDecodeError as e:
        print(f"Error: could not parse `hyprctl monitors -j` output\n{e}", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()
