#!/usr/bin/env python3
"""
dns-watch: Live DNS dashboard for Technitium DNS Server
https://github.com/runnyroosts/dns-watch
"""

import json
import os
import queue
import shlex
import subprocess
import tempfile
import threading
import time
import tkinter as tk
import urllib.parse
import urllib.request
from collections import deque
from datetime import datetime
from tkinter import ttk

# Config file location; the file holds the Technitium API token and is
# written with mode 0600 by save_config().
CONFIG_DIR  = os.path.expanduser("~/.config/dns-watch")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
MAX_ROWS    = 500   # default cap on table rows kept (overridable via config "max_rows")
REFRESH_MS  = 500   # UI thread: interval between queue drains, in milliseconds
POLL_SEC    = 0.5   # poll thread: sleep between database queries, in seconds

# (label, path) choices for the setup form's "Database location" selector.
# The entry with an empty path is the free-form "Custom path" option.
DB_PRESETS = [
    ("Raspberry Pi / Debian default",  "/etc/dns/apps/Query Logs (Sqlite)/querylogs.db"),
    ("Ubuntu / snap install",          "/snap/technitium-dns/current/etc/dns/apps/Query Logs (Sqlite)/querylogs.db"),
    ("Docker volume default",          "/opt/technitium/dns/apps/Query Logs (Sqlite)/querylogs.db"),
    ("Custom path",                    ""),
]

# Integer codes read from the query-log database -> display labels.
# Codes missing from these tables are rendered generically by parse_rows().
# NOTE(review): RESPONSE_LABEL presumably mirrors Technitium's response-type
# enum; any other code (e.g. 1) shows up as TYPE<n> and is counted as OTHER.
RESPONSE_LABEL = {2: 'UPSTREAM', 3: 'CACHED', 4: 'BLOCKED'}
# Standard DNS RCODE values 0-5.
RCODE_LABEL    = {0: 'NoError', 1: 'FormErr', 2: 'ServFail',
                  3: 'NxDomain', 4: 'NotImp', 5: 'Refused'}
# Standard DNS resource-record type numbers for the common query types.
QTYPE_LABEL    = {1: 'A', 2: 'NS', 5: 'CNAME', 6: 'SOA', 15: 'MX',
                  16: 'TXT', 28: 'AAAA', 33: 'SRV', 65: 'HTTPS', 255: 'ANY'}
# NOTE(review): both 0 and 1 map to 'UDP' here; confirm against Technitium's
# transport-protocol enum before relying on the Proto column.
PROTO_LABEL    = {0: 'UDP', 1: 'UDP', 2: 'TCP', 3: 'TLS', 4: 'HTTPS', 5: 'QUIC'}

# Dark colour palette shared by every window.
C_BG      = "#1e1e2e"
C_SURFACE = "#181825"
C_HEADER  = "#313244"
C_FG      = "#cdd6f4"
C_DIM     = "#6c7086"
C_GREEN   = "#a6e3a1"
C_BLUE    = "#89b4fa"
C_RED     = "#f38ba8"
C_YELLOW  = "#f9e2af"
C_ORANGE  = "#fab387"


# ── Config ────────────────────────────────────────────────────────────────────

def load_config(path=None):
    """Load the saved dns-watch configuration.

    Args:
        path: config file to read; defaults to CONFIG_FILE.

    Returns:
        The configuration dict, or None when the file is missing or
        unreadable, is not valid JSON, or does not hold a JSON object.
        Callers treat None as "run first-time setup".
    """
    if path is None:
        path = CONFIG_FILE
    try:
        with open(path, encoding='utf-8') as f:
            cfg = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, or corrupt content (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses): redo setup.
        return None
    # Anything but an object would crash later on cfg.get(...).
    return cfg if isinstance(cfg, dict) else None

def save_config(cfg, path=None):
    """Persist *cfg* as JSON so that only the current user can read it.

    The file holds the Technitium API token, so:
      * the containing directory is created if needed and locked to 0700;
      * the JSON goes to a temp file in that directory (mkstemp always creates
        it 0600) which is then atomically renamed over the target. The token
        is therefore never readable by other accounts, even if an older copy
        had looser permissions. A crash mid-write also cannot leave a
        truncated config behind.

    Args:
        cfg: JSON-serialisable configuration dict.
        path: destination file; defaults to CONFIG_FILE. Note that the
            directory containing *path* is chmod-ed to 0700.
    """
    if path is None:
        path = CONFIG_FILE
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    os.chmod(directory, 0o700)
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.config-', suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(cfg, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a stray temp file (possibly containing the token) behind.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


# ── DB query ──────────────────────────────────────────────────────────────────

def run_ssh(cfg, cmd, timeout=10):
    """Execute *cmd* on cfg['host'] as cfg['ssh_user'] using key-based ssh.

    Only the remote command's stdout (whitespace-stripped) is returned; the
    exit status and stderr are ignored. subprocess.TimeoutExpired propagates
    if the whole call exceeds *timeout* seconds.
    """
    target = '{}@{}'.format(cfg['ssh_user'], cfg['host'])
    ssh_opts = [
        '-o', 'BatchMode=yes',                      # never prompt for a password
        '-o', 'StrictHostKeyChecking=accept-new',   # trust new host keys, reject changed ones
        '-o', 'ConnectTimeout=5',
    ]
    completed = subprocess.run(
        ['ssh', *ssh_opts, target, cmd],
        capture_output=True, text=True, timeout=timeout,
    )
    return completed.stdout.strip()

def query_db(cfg, sql):
    db = cfg['db_path']
    if cfg.get('mode') == 'local':
        result = subprocess.run(
            ['sudo', 'sqlite3', db, sql],
            capture_output=True, text=True, timeout=5
        )
        return result.stdout.strip()
    return run_ssh(cfg, f"sudo sqlite3 '{db}' '{sql}'")

def parse_rows(output):
    rows = []
    for line in output.splitlines():
        p = line.split('|')
        if len(p) < 7:
            continue
        ts, client, proto, rtype, rcode, qname, qtype = p[:7]
        try:
            rt, rc, qt, pr = int(rtype), int(rcode), int(qtype), int(proto)
        except ValueError:
            continue
        rows.append(dict(
            timestamp=ts,
            time=ts.split('.')[0].split(' ')[-1],
            client=client,
            proto=PROTO_LABEL.get(pr, str(pr)),
            status=RESPONSE_LABEL.get(rt, f'TYPE{rt}'),
            rcode=RCODE_LABEL.get(rc, f'RC{rc}'),
            domain=qname,
            qtype=QTYPE_LABEL.get(qt, f'T{qt}'),
        ))
    return rows


# ── Shared setup form (used by both SetupWindow and SettingsDialog) ───────────

class SetupForm:
    """Mixin that builds the connection/setup form inside any tk container.

    Subclasses must set `self.result` (to None) before calling
    `_build_form()`. They must also provide the `_finish` / `_cancel`
    callbacks (one ends an event loop, the other tears down a dialog).
    After a successful Connect, `self.result` is the config dict that was
    just written with save_config(), and `_finish` runs 800 ms later.
    """

    def _form_label(self, frame, text, row):
        """Grid a left-aligned label into column 0 of *row*."""
        tk.Label(frame, text=text, fg=C_FG, bg=C_BG,
                 font=('monospace', 10), anchor='w').grid(
            row=row, column=0, sticky='w', padx=(16, 8), pady=4)

    def _form_entry(self, frame, row, default='', show=None):
        """Grid a themed Entry into column 1 of *row*, pre-filled with *default*.

        *show*, when given, masks typed characters (e.g. '*' for passwords).
        Returns the Entry widget.
        """
        e = tk.Entry(frame, bg=C_SURFACE, fg=C_FG, insertbackground=C_FG,
                     font=('monospace', 10), width=36,
                     relief='flat', highlightthickness=1,
                     highlightcolor=C_BLUE, highlightbackground=C_DIM,
                     show=show or '')
        e.insert(0, default)
        e.grid(row=row, column=1, padx=(0, 16), pady=4)
        return e

    def _build_form(self, container, cfg):
        """Create every form widget inside *container*.

        *cfg* is a previously saved config (may be empty) used to pre-fill
        the non-secret fields; both password fields always start empty.
        """
        tk.Label(container, text='dns-watch Setup', fg=C_BLUE, bg=C_BG,
                 font=('monospace', 14, 'bold')).pack(pady=(16, 4))
        tk.Label(container, text='Configure your Technitium DNS server connection.',
                 fg=C_DIM, bg=C_BG, font=('monospace', 9)).pack(pady=(0, 12))

        # local vs remote toggle
        mode_frame = tk.Frame(container, bg=C_BG)
        mode_frame.pack(pady=(0, 8))
        self._mode = tk.StringVar(value=cfg.get('mode', 'remote'))
        for val, label in [('remote', 'Remote (SSH)'), ('local', 'Local (same machine)')]:
            tk.Radiobutton(
                mode_frame, text=label, variable=self._mode, value=val,
                bg=C_BG, fg=C_FG, selectcolor=C_SURFACE,
                activebackground=C_BG, activeforeground=C_FG,
                font=('monospace', 10), command=self._on_mode_change,
            ).pack(side=tk.LEFT, padx=12)

        frame = tk.Frame(container, bg=C_BG)
        frame.pack(fill=tk.X)

        self._form_label(frame, 'Host IP / hostname', 0)
        self._f_host = self._form_entry(frame, 0, cfg.get('host', ''))

        self._form_label(frame, 'SSH username', 1)
        self._f_ssh_user = self._form_entry(frame, 1, cfg.get('ssh_user', ''))

        self._form_label(frame, 'SSH password', 2)
        self._f_ssh_pass = self._form_entry(frame, 2, show='*')
        tk.Label(frame, text='(used once to set up key auth, not stored)',
                 fg=C_DIM, bg=C_BG, font=('monospace', 8)).grid(
            row=3, column=1, sticky='w', padx=(0, 16))

        self._form_label(frame, 'Technitium admin user', 4)
        self._f_tn_user = self._form_entry(frame, 4, cfg.get('tn_user', 'admin'))

        self._form_label(frame, 'Technitium admin password', 5)
        self._f_tn_pass = self._form_entry(frame, 5, show='*')

        # DB path selector: start on the preset matching the saved path. A
        # saved path that matches no preset is shown as the custom option
        # instead of being mislabelled as the first preset.
        self._form_label(frame, 'Database location', 6)
        saved_db = cfg.get('db_path', DB_PRESETS[0][1])
        preset_label = next((name for name, path in DB_PRESETS
                             if path and path == saved_db), None)
        if preset_label is None:
            preset_label = next((name for name, path in DB_PRESETS if not path),
                                DB_PRESETS[0][0])
        self._preset_var = tk.StringVar(value=preset_label)
        preset_menu = ttk.Combobox(frame, textvariable=self._preset_var,
                                   values=[p[0] for p in DB_PRESETS],
                                   state='readonly', width=35,
                                   font=('monospace', 10))
        preset_menu.grid(row=6, column=1, padx=(0, 16), pady=4)
        preset_menu.bind('<<ComboboxSelected>>', self._on_preset)

        self._form_label(frame, 'Custom db path', 7)
        self._f_db_path = self._form_entry(frame, 7, saved_db)

        self._form_label(frame, 'Max rows (scroll history)', 8)
        self._f_max_rows = self._form_entry(frame, 8, str(cfg.get('max_rows', MAX_ROWS)))

        # fields that only apply to remote mode (disabled in local mode)
        self._ssh_fields = [self._f_host, self._f_ssh_user, self._f_ssh_pass]

        self._status_lbl = tk.Label(container, text='', fg=C_YELLOW, bg=C_BG,
                                     font=('monospace', 9), wraplength=420)
        self._status_lbl.pack(pady=(8, 4))

        btn_frame = tk.Frame(container, bg=C_BG)
        btn_frame.pack(pady=(4, 16))
        self._btn_connect = tk.Button(
            btn_frame, text='Connect', bg=C_BLUE, fg=C_BG,
            font=('monospace', 10, 'bold'), relief='flat',
            padx=20, pady=6, command=self._run_setup)
        self._btn_connect.pack(side=tk.LEFT, padx=8)
        tk.Button(btn_frame, text='Cancel', bg=C_SURFACE, fg=C_FG,
                  font=('monospace', 10), relief='flat',
                  padx=20, pady=6, command=self._cancel).pack(side=tk.LEFT, padx=8)

        self._on_mode_change()

    def _on_mode_change(self, *_):
        """Enable the SSH fields in remote mode and disable them in local mode."""
        remote = self._mode.get() == 'remote'
        for w in self._ssh_fields:
            w.config(state=tk.NORMAL if remote else tk.DISABLED)

    def _on_preset(self, _event):
        """Copy the chosen preset's path into the db path field.

        Choosing the custom preset clears the field.
        """
        for name, path in DB_PRESETS:
            if name == self._preset_var.get():
                self._f_db_path.delete(0, tk.END)
                self._f_db_path.insert(0, path)
                break

    def _set_status(self, msg, color=C_YELLOW):
        """Show *msg* in the status line and repaint it right away.

        update() pumps the Tk event loop, so the message is visible before
        the blocking step that follows. It also processes pending clicks,
        which is why _run_setup disables the Connect button while running.
        """
        self._status_lbl.config(text=msg, fg=color)
        self._status_lbl.update()

    def _run_setup(self):
        """Connect-button handler: validate input, run setup, save the result."""
        mode     = self._mode.get()
        tn_user  = self._f_tn_user.get().strip()
        tn_pass  = self._f_tn_pass.get()
        db_path  = self._f_db_path.get().strip()
        try:
            max_rows = int(self._f_max_rows.get())
        except ValueError:
            max_rows = MAX_ROWS
        max_rows = max(1, max_rows)

        if not all([tn_user, tn_pass, db_path]):
            self._set_status('Technitium user/password and db path are required.', C_RED)
            return

        # Block a second click from re-entering setup mid-run (see _set_status).
        self._btn_connect.config(state=tk.DISABLED)
        finished = False
        try:
            if mode == 'local':
                result = self._setup_local(tn_user, tn_pass, db_path, max_rows)
            else:
                result = self._setup_remote(tn_user, tn_pass, db_path, max_rows)
            finished = result is not None and self._complete(result)
        finally:
            if not finished:
                self._btn_connect.config(state=tk.NORMAL)

    def _setup_local(self, tn_user, tn_pass, db_path, max_rows):
        """Local mode: log in to the Technitium API on localhost:5380.

        Returns the new config dict, or None after reporting a failure.
        """
        self._set_status('Fetching Technitium API token locally…')
        try:
            # urllib keeps the password in-process, not in any curl argv
            qs = urllib.parse.urlencode(
                {'user': tn_user, 'pass': tn_pass, 'includeInfo': 'true'})
            with urllib.request.urlopen(
                    f'http://localhost:5380/api/user/login?{qs}', timeout=10) as resp:
                token = json.loads(resp.read().decode()).get('token')
        except Exception as e:
            self._set_status(f'Token fetch failed: {e}', C_RED)
            return None
        if not token:
            self._set_status('Could not get Technitium token. Check credentials.', C_RED)
            return None
        return dict(mode='local', tn_user=tn_user, token=token,
                    db_path=db_path, max_rows=max_rows)

    def _ensure_local_key(self):
        """Return the contents of ~/.ssh/id_ed25519.pub.

        The key pair is generated (with no passphrase) if the private key is
        missing. Returns None, after reporting the problem, if ssh-keygen
        fails or the public key cannot be read.
        """
        key_path = os.path.expanduser('~/.ssh/id_ed25519')
        if not os.path.exists(key_path):
            try:
                os.makedirs(os.path.dirname(key_path), mode=0o700, exist_ok=True)
                gen = subprocess.run(
                    ['ssh-keygen', '-t', 'ed25519', '-N', '', '-f', key_path],
                    capture_output=True, text=True)
            except OSError as e:
                self._set_status(f'ssh-keygen failed: {e}', C_RED)
                return None
            if gen.returncode != 0:
                self._set_status(f'ssh-keygen failed: {gen.stderr.strip()}', C_RED)
                return None
        try:
            with open(f'{key_path}.pub') as f:
                return f.read().strip()
        except OSError as e:
            self._set_status(f'Could not read SSH public key: {e}', C_RED)
            return None

    def _setup_remote(self, tn_user, tn_pass, db_path, max_rows):
        """Remote mode: bootstrap key-based SSH and sudo, then fetch the token.

        The SSH password is used once, via paramiko, to do three things:
          * append the local public key to the server's authorized_keys;
          * install a NOPASSWD sudoers rule for /usr/bin/sqlite3;
          * run the Technitium login API with curl on the server.
        Returns the new config dict, or None after reporting a failure.
        """
        host     = self._f_host.get().strip()
        ssh_user = self._f_ssh_user.get().strip()
        ssh_pass = self._f_ssh_pass.get()

        if not all([host, ssh_user, ssh_pass]):
            self._set_status('Host, SSH username and password are required for remote mode.', C_RED)
            return None

        try:
            import paramiko
        except ImportError:
            self._set_status('paramiko not installed. Run: sudo apt install python3-paramiko', C_RED)
            return None

        self._set_status('Generating SSH key if needed…')
        pub_key = self._ensure_local_key()
        if pub_key is None:
            return None

        self._set_status(f'Connecting to {ssh_user}@{host}…')
        client = paramiko.SSHClient()
        # accept an unknown host key on first contact
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(host, username=ssh_user, password=ssh_pass, timeout=10)
        except Exception as e:
            client.close()
            self._set_status(f'SSH connection failed: {e}', C_RED)
            return None

        try:
            self._set_status('Deploying SSH key…')
            key = shlex.quote(pub_key)
            _, stdout, _ = client.exec_command(
                'mkdir -p ~/.ssh && chmod 700 ~/.ssh && '
                f'grep -qF {key} ~/.ssh/authorized_keys 2>/dev/null || '
                f'echo {key} >> ~/.ssh/authorized_keys && chmod 600 ~/.ssh/authorized_keys'
            )
            if stdout.channel.recv_exit_status() != 0:
                self._set_status('Could not install the SSH key on the server.', C_RED)
                return None

            self._set_status('Adding sqlite3 sudo permission…')
            sudoers = '/etc/sudoers.d/dns-watch-sqlite3'
            rule = f'{ssh_user} ALL=(ALL) NOPASSWD: /usr/bin/sqlite3'
            # shell-quote the rule and the whole script; 0440 leaves the rule
            # read-only for root, the usual mode for sudoers files
            script = f'echo {shlex.quote(rule)} > {sudoers} && chmod 0440 {sudoers}'
            # feed the SSH password to sudo via stdin, not the command argv
            stdin, stdout, _ = client.exec_command('sudo -S bash -c ' + shlex.quote(script))
            stdin.write(ssh_pass + '\n')
            stdin.channel.shutdown_write()
            if stdout.channel.recv_exit_status() != 0:
                self._set_status('Could not add the sudo rule (wrong password or no sudo rights?).', C_RED)
                return None

            self._set_status('Fetching Technitium API token…')
            # pass the URL (with creds) to curl via stdin config, not argv
            qs = urllib.parse.urlencode(
                {'user': tn_user, 'pass': tn_pass, 'includeInfo': 'true'})
            stdin, stdout, _ = client.exec_command('curl -s -K -')
            stdin.write(f'url = "http://localhost:5380/api/user/login?{qs}"\n')
            stdin.channel.shutdown_write()
            token = json.loads(stdout.read().decode()).get('token')
        except Exception as e:
            self._set_status(f'Remote setup failed: {e}', C_RED)
            return None
        finally:
            client.close()

        if not token:
            self._set_status('Could not get Technitium token. Check credentials.', C_RED)
            return None
        return dict(mode='remote', host=host, ssh_user=ssh_user,
                    tn_user=tn_user, token=token, db_path=db_path,
                    max_rows=max_rows)

    def _complete(self, result):
        """Save *result*, expose it as self.result and schedule _finish.

        Returns True on success, or False (after reporting) if the config
        could not be written.
        """
        try:
            save_config(result)
        except OSError as e:
            self._set_status(f'Could not save config: {e}', C_RED)
            return False
        self.result = result
        self._set_status('Setup complete!', C_GREEN)
        self._status_lbl.after(800, self._finish)
        return True


# ── Setup window (first run, standalone Tk root) ─────────────────────────────

class SetupWindow(SetupForm, tk.Tk):
    """First-run setup: a standalone Tk root hosting the setup form.

    The constructor blocks in mainloop() until the user connects or cancels.
    Afterwards `result` holds the saved config dict, or None if cancelled.
    The window is left alive; the caller destroys it.
    """

    def __init__(self, existing_config=None):
        tk.Tk.__init__(self)
        self.title('dns-watch Setup')
        self.configure(bg=C_BG)
        self.resizable(False, False)
        self.result = None
        self._build_form(self, existing_config or {})
        # Closing via the window manager would otherwise destroy the root
        # behind the caller's back, and its later destroy() would then raise
        # TclError. Treat it like Cancel instead.
        self.protocol('WM_DELETE_WINDOW', self._cancel)
        self.mainloop()          # block here until _finish / _cancel ends the loop

    def _cancel(self):
        """Leave the event loop without setting a result."""
        self.quit()

    def _finish(self):
        """Leave the event loop after setup has stored `result`."""
        self.quit()


# ── Settings dialog (from the running app, modal Toplevel) ───────────────────

class SettingsDialog(SetupForm, tk.Toplevel):
    """Modal settings dialog shown from the running dashboard.

    The constructor blocks until the dialog is closed. Afterwards `result`
    holds the newly saved config dict, or None if the user cancelled.
    """

    def __init__(self, parent, existing_config=None):
        tk.Toplevel.__init__(self, parent)
        self.title('dns-watch Settings')
        self.configure(bg=C_BG)
        self.resizable(False, False)
        self.result = None
        self._build_form(self, existing_config or {})
        # Keep the dialog stacked above its parent so the modal grab can't be
        # hidden behind the main window (which would look like a freeze).
        self.transient(parent)
        self.protocol('WM_DELETE_WINDOW', self._cancel)
        # On X11, grabbing a window that is not mapped yet fails with
        # "grab failed: window not viewable", so wait until it is shown.
        self.wait_visibility()
        self.grab_set()
        self.focus_set()
        self.wait_window()

    def _cancel(self):
        """Close the dialog without a result."""
        self.destroy()

    def _finish(self):
        """Close the dialog after setup has stored `result`."""
        self.destroy()


# ── Main dashboard ────────────────────────────────────────────────────────────

class DnsWatch(tk.Tk):
    """Main dashboard window: a live, colour-coded table of DNS queries.

    A daemon thread (_poll_worker) polls the query-log database and pushes
    parsed rows and status messages onto a queue. _refresh drains that queue
    on the Tk thread every REFRESH_MS ms, so only the UI thread touches
    widgets.
    """

    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        host = cfg.get('host', 'localhost')
        self.title(f"DNS Watch: {host}")
        self.geometry('1300x720')
        self.configure(bg=C_BG)
        self.minsize(900, 500)

        self._q           = queue.Queue()
        self._rows        = deque()     # row dicts currently shown, oldest first
        self._row_iids    = deque()     # matching Treeview item ids, same order
        self._max_rows    = int(cfg.get('max_rows', MAX_ROWS))
        self._counts      = {'UPSTREAM': 0, 'CACHED': 0, 'BLOCKED': 0, 'OTHER': 0}
        self._autoscroll  = tk.BooleanVar(value=True)
        self._paused      = False
        self._conn_status = ('Connecting…', C_DIM)   # last (text, colour) from the poller
        # Start at midnight so today's earlier queries are loaded on launch.
        # NOTE(review): datetime.now() is local time; confirm the log's
        # timestamps use the same zone.
        self._last_ts     = datetime.now().strftime('%Y-%m-%d 00:00:00.0')
        self._font_size   = tk.IntVar(value=int(cfg.get('font_size', 10)))

        self._build_ui()
        self._start_poll()
        self._refresh()

    def _build_ui(self):
        """Create the stats/controls bar and the query table."""
        bar = tk.Frame(self, bg=C_HEADER, pady=6, padx=12)
        bar.pack(fill=tk.X)

        self._lbl_total    = self._stat(bar, 'Total',    '0', C_FG)
        self._lbl_upstream = self._stat(bar, 'Upstream', '0', C_GREEN)
        self._lbl_cached   = self._stat(bar, 'Cached',   '0', C_BLUE)
        self._lbl_blocked  = self._stat(bar, 'Blocked',  '0', C_RED)
        self._lbl_other    = self._stat(bar, 'Other',    '0', C_ORANGE)

        # right side: status … Auto-scroll  Font[slider]  ⏸ Pause  ⚙ Settings
        # (packed right-to-left, so the last packed widget sits leftmost)
        tk.Button(bar, text='⚙ Settings', bg=C_SURFACE, fg=C_DIM,
                  font=('monospace', 9), relief='flat', padx=8,
                  command=self._open_settings).pack(side=tk.RIGHT, padx=8)

        self._btn_pause = tk.Button(
            bar, text='⏸ Pause', bg=C_SURFACE, fg=C_DIM,
            font=('monospace', 9), relief='flat', padx=8,
            command=self._toggle_pause)
        self._btn_pause.pack(side=tk.RIGHT, padx=4)

        font_frame = tk.Frame(bar, bg=C_HEADER)
        font_frame.pack(side=tk.RIGHT, padx=8)
        tk.Label(font_frame, text='Font', fg=C_DIM, bg=C_HEADER,
                 font=('monospace', 8)).pack(side=tk.LEFT)
        tk.Scale(font_frame, from_=8, to=28, orient=tk.HORIZONTAL,
                 variable=self._font_size, command=self._on_font_change,
                 bg=C_HEADER, fg=C_FG, troughcolor=C_SURFACE,
                 highlightthickness=0, length=100, showvalue=True,
                 font=('monospace', 8)).pack(side=tk.LEFT)

        tk.Checkbutton(bar, text='Auto-scroll', variable=self._autoscroll,
                       bg=C_HEADER, fg=C_DIM, activebackground=C_HEADER,
                       selectcolor=C_SURFACE, font=('monospace', 9),
                       ).pack(side=tk.RIGHT, padx=8)

        self._lbl_status = tk.Label(bar, text='Connecting…', fg=C_DIM,
                                     bg=C_HEADER, font=('monospace', 9))
        self._lbl_status.pack(side=tk.RIGHT, padx=8)

        frame = tk.Frame(self, bg=C_BG)
        frame.pack(fill=tk.BOTH, expand=True, padx=8, pady=(4, 8))

        self._style = ttk.Style(self)
        self._style.theme_use('default')
        self._apply_font(self._font_size.get())
        self._style.map('Dns.Treeview', background=[('selected', C_HEADER)])

        cols = ('time', 'client', 'domain', 'type', 'proto', 'status')
        self._tree = ttk.Treeview(frame, columns=cols, show='headings',
                                   style='Dns.Treeview', selectmode='browse')
        for col, label, width, anchor in [
            ('time',   'Time',     90,  tk.CENTER),
            ('client', 'Client',  140,  tk.W),
            ('domain', 'Domain',  620,  tk.W),
            ('type',   'Type',     60,  tk.CENTER),
            ('proto',  'Proto',    70,  tk.CENTER),
            ('status', 'Status',  110,  tk.CENTER),
        ]:
            self._tree.heading(col, text=label)
            self._tree.column(col, width=width, anchor=anchor, stretch=(col == 'domain'))

        for tag, color in [('UPSTREAM', C_GREEN), ('CACHED', C_BLUE),
                            ('BLOCKED', C_RED), ('OTHER', C_ORANGE)]:
            self._tree.tag_configure(tag, foreground=color)

        vsb = ttk.Scrollbar(frame, orient=tk.VERTICAL, command=self._tree.yview)
        self._tree.configure(yscrollcommand=vsb.set)
        vsb.pack(side=tk.RIGHT, fill=tk.Y)
        self._tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        # clicking anywhere in the table stops auto-scroll so the user can read
        self._tree.bind('<Button-1>', lambda _: self._autoscroll.set(False))

    def _stat(self, parent, label, value, color):
        """Add a captioned counter to *parent*; return its value Label."""
        f = tk.Frame(parent, bg=C_HEADER, padx=14)
        f.pack(side=tk.LEFT)
        tk.Label(f, text=label, fg=C_DIM, bg=C_HEADER, font=('monospace', 8)).pack()
        lbl = tk.Label(f, text=value, fg=color, bg=C_HEADER,
                       font=('monospace', 13, 'bold'))
        lbl.pack()
        return lbl

    def _apply_font(self, size):
        """Restyle the table rows and headings for a *size*-point font."""
        size = int(size)
        self._style.configure('Dns.Treeview',
            background=C_SURFACE, foreground=C_FG,
            fieldbackground=C_SURFACE, rowheight=max(16, size + 10),
            font=('monospace', size),
        )
        self._style.configure('Dns.Treeview.Heading',
            background=C_HEADER, foreground=C_FG,
            font=('monospace', size, 'bold'), relief='flat',
        )

    def _on_font_change(self, value):
        """Scale callback (*value* is a string): restyle and persist the size."""
        size = int(float(value))
        self._apply_font(size)
        self._cfg['font_size'] = size
        save_config(self._cfg)

    def _toggle_pause(self):
        """Flip pause; on resume, show the poller's most recent status again."""
        self._paused = not self._paused
        if self._paused:
            self._btn_pause.config(text='▶ Resume', fg=C_YELLOW)
            self._lbl_status.config(text='paused', fg=C_YELLOW)
        else:
            self._btn_pause.config(text='⏸ Pause', fg=C_DIM)
            text, color = self._conn_status
            self._lbl_status.config(text=text, fg=color)

    def _open_settings(self):
        """Show the modal settings dialog and apply any newly saved config."""
        dlg = SettingsDialog(self, self._cfg)
        if not dlg.result:
            return
        cfg = dict(dlg.result)
        # The setup form doesn't know about font_size and has just saved a
        # config without it. Carry the current size over and save again so
        # it survives a restart.
        cfg['font_size'] = int(self._font_size.get())
        save_config(cfg)
        self._cfg = cfg   # single attribute swap; the poll thread reads it each loop
        self._max_rows = int(cfg.get('max_rows', MAX_ROWS))
        self._trim_rows()   # apply a lowered limit right away
        self.title(f"DNS Watch: {cfg.get('host', 'localhost')}")

    def _start_poll(self):
        """Start the background poller as a daemon thread (dies with the app)."""
        threading.Thread(target=self._poll_worker, daemon=True).start()

    def _poll_worker(self):
        """Background loop: fetch rows newer than _last_ts every POLL_SEC.

        Never touches Tk directly; rows and status updates go through
        self._q. A status is posted only when the poll outcome changes.
        """
        self._q.put({'_status': 'live'})
        failing = False
        while True:
            # _last_ts comes from the database itself; double any quote so it
            # can never terminate the SQL string literal early.
            since = self._last_ts.replace("'", "''")
            sql = (
                "SELECT timestamp,client_ip,protocol,response_type,rcode,qname,qtype "
                f"FROM dns_logs WHERE timestamp > '{since}' "
                "ORDER BY timestamp ASC LIMIT 500"
            )
            try:
                output = query_db(self._cfg, sql)
            except Exception as e:
                output = ''
                if not failing:
                    self._q.put({'_status': f'error: {type(e).__name__}', '_color': C_RED})
                failing = True
            else:
                if failing:
                    self._q.put({'_status': 'live'})
                failing = False
            if output:
                rows = parse_rows(output)
                if rows:
                    self._last_ts = rows[-1]['timestamp']
                    for r in rows:
                        self._q.put(r)
            time.sleep(POLL_SEC)

    def _refresh(self):
        """Drain the worker queue on the Tk thread, then reschedule itself.

        NOTE: while paused, incoming rows are consumed and discarded (the
        poller keeps advancing), so queries arriving during a pause are
        never shown or counted.
        """
        changed = False
        try:
            while True:
                item = self._q.get_nowait()
                if '_status' in item:
                    self._conn_status = (item['_status'], item.get('_color', C_GREEN))
                    if not self._paused:
                        text, color = self._conn_status
                        self._lbl_status.config(text=text, fg=color)
                    continue
                if not self._paused:
                    self._add_row(item)
                    changed = True
        except queue.Empty:
            pass

        if changed:
            self._update_stats()
            if self._autoscroll.get() and self._row_iids:
                self._tree.see(self._row_iids[-1])

        self.after(REFRESH_MS, self._refresh)

    def _add_row(self, r):
        """Append row dict *r* to the table, count it, and enforce _max_rows."""
        tag = r['status'] if r['status'] in ('UPSTREAM', 'CACHED', 'BLOCKED') else 'OTHER'
        iid = self._tree.insert('', tk.END, values=(
            r['time'], r['client'], r['domain'], r['qtype'], r['proto'], r['status'],
        ), tags=(tag,))
        self._counts[tag] += 1
        self._rows.append(r)
        self._row_iids.append(iid)
        self._trim_rows()

    def _trim_rows(self):
        """Delete the oldest rows until at most _max_rows remain (O(1) each)."""
        while len(self._rows) > self._max_rows:
            self._tree.delete(self._row_iids.popleft())
            self._rows.popleft()

    def _update_stats(self):
        """Refresh the counters in the top bar (cumulative since launch)."""
        total = sum(self._counts.values())
        self._lbl_total.config(   text=str(total))
        self._lbl_upstream.config(text=str(self._counts['UPSTREAM']))
        self._lbl_cached.config(  text=str(self._counts['CACHED']))
        self._lbl_blocked.config( text=str(self._counts['BLOCKED']))
        self._lbl_other.config(   text=str(self._counts['OTHER']))


# ── Entry point ───────────────────────────────────────────────────────────────

if __name__ == '__main__':
    cfg = load_config()

    if not cfg:
        # No usable config yet: run first-time setup in its own Tk root.
        win = SetupWindow()      # runs its own mainloop until setup is done
        cfg = win.result
        try:
            win.destroy()
        except tk.TclError:
            # Already torn down (e.g. closed via the window manager).
            pass
        if not cfg:
            raise SystemExit

    DnsWatch(cfg).mainloop()
