# -*- coding: utf-8 -*-
"""
RPC client shim for the NVDA Eloquence driver.

- Exposes exactly what the provided driver (eloquence.py) expects:
  constants, langs, params/vparams, eciPath, lastindex,
  synth_queue (list of (callable, args_tuple)),
  and process()/initialize()/speak()/index()/synth()/stop()/pause()/terminate(),
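  set_voice/getVParam/setVParam/setVariant/cmdProsody, and eciCheck().

- Forwards text verbatim (including back-quote tags) to the 32-bit server,
  except that speak() intercepts back-quote pitch tags of the form `pNN
  followed (after optional whitespace) by a letter: the tag is dropped and the
  pitch is changed temporarily around that single letter instead.
  The server converts UTF-8 → ANSI and calls ECI correctly, so tags are not spoken.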
"""

from __future__ import annotations
import os, threading, queue, time, logging, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "eloquence"))
from typing import Optional, Any
import subprocess
import signal
import nvwave, config
import grpc
from eci_pb2 import (
    InitRequest, TtsRequest, AddText, InsertIndex, Synthesize,
    SetParamRequest, SetVParamRequest, GetVParamRequest, CopyVoiceRequest, Empty
)
from eci_pb2_grpc import ECIStub

# ---- Public constants expected by the driver ----
# Voice-parameter ids passed to getVParam()/setVParam() (speak() uses `pitch`).
hsz = 1
pitch = 2
fluctuation = 3
rgh = 4
bth = 5
rate = 6
vlm = 7

# Voice code -> (ECI language id, display name). The id is what set_voice()
# sends as param 9 and what initialize() sends as voice_lang_id.
langs = dict(
    esm=(131073, 'Latin American Spanish'),
    esp=(131072, 'Castilian Spanish'),
    ptb=(458752, 'Brazilian Portuguese'),
    frc=(196609, 'French Canadian'),
    fra=(196608, 'French'),
    fin=(589824, 'Finnish'),
    deu=(262144, 'German'),
    ita=(327680, 'Italian'),
    enu=(65536, 'American English'),
    eng=(65537, 'British English'),
)

class _SafeDict(dict):
    def __getitem__(self, k): return dict.get(self, k, 0)

# Cache of voice parameters (id -> value); ids never seen read as 0.
vparams = _SafeDict()
# Cache of non-voice parameters; param 9 holds the language id and is seeded
# with 65536 (American English in `langs`).
params = _SafeDict({9: 65536})
# Path of eci.dll; replaced with an absolute path by eciCheck()/initialize().
eciPath = "eloquence\\eci.dll"
# Last index value reported by the server (None when it reported a negative value).
lastindex = 0

# Public queue the driver fills with [(callable, args_tuple), ...];
# process() drains it.
synth_queue: "queue.Queue[list[tuple]]" = queue.Queue()

# ---- Local audio playback ----
SAMPLE_RATE=11025; CHANNELS=1; BITS=16
_player: nvwave.WavePlayer | None = None

# ---- RPC state ----
_stub: ECIStub | None = None
_server_proc: subprocess.Popen | None = None
_stream_q: "queue.Queue[TtsRequest]" = queue.Queue()
_stream_thr: threading.Thread | None = None
_connected=False
_onIndexReached=None

def _ensure_player():
    """Create the shared WavePlayer on first use; later calls do nothing.

    First tries NVDA's ``audio`` config section: its ``outputDevice`` and,
    for ducking, whether ``audioDuckingMode`` is truthy. If that raises for
    any reason, falls back to a buffered player on the device from
    ``config.conf["speech"]["outputDevice"]``.
    """
    global _player
    if _player:
        return
    try:
        audio = config.conf["audio"]
        _player = nvwave.WavePlayer(
            CHANNELS,
            SAMPLE_RATE,
            BITS,
            outputDevice=audio["outputDevice"],
            wantDucking=bool(audio.get("audioDuckingMode")),
        )
    except Exception:
        # NOTE(review): presumably for NVDA versions with a different config
        # layout / WavePlayer signature — confirm. Assigning MIN_BUFFER_MS
        # changes the class attribute, i.e. for every WavePlayer instance.
        fallback_device = config.conf.get("speech", {}).get("outputDevice")
        nvwave.WavePlayer.MIN_BUFFER_MS = 1500
        _player = nvwave.WavePlayer(
            CHANNELS, SAMPLE_RATE, BITS, outputDevice=fallback_device, buffered=True
        )

def _stream_iter():
    """Yield requests taken from ``_stream_q`` until a None sentinel is dequeued.

    Each ``get()`` blocks until a request is available; the generator is the
    request iterator of the bidirectional Tts call in _stream_worker.
    """
    request = _stream_q.get()
    while request is not None:
        yield request
        request = _stream_q.get()

def _stream_worker(stub: ECIStub):
    """Run the bidirectional Tts stream and dispatch the server's events.

    Requests come from ``_stream_q`` via _stream_iter(). For each event:
      * "audio"   -> PCM is fed to the shared player (a failing chunk is dropped);
      * "index"   -> ``lastindex`` is updated (None for negative values) and the
                     index callback registered by initialize() is invoked;
      * "stopped" -> the player is stopped;
      * "vparams" -> the reported values are merged into the ``vparams`` cache.
    Returns when the stream ends; stream errors are logged at debug level.
    """
    global lastindex
    _ensure_player()
    try:
        for ev in stub.Tts(_stream_iter()):
            which = ev.WhichOneof("event")
            if which == "audio":
                try:
                    _player.feed(bytes(ev.audio.pcm))
                except Exception:
                    # Drop this chunk. Guard idle() as well: an exception escaping
                    # here would otherwise terminate the whole stream loop.
                    try:
                        _player.idle()
                    except Exception:
                        logging.debug("ECI player idle() failed", exc_info=True)
                    time.sleep(0.02)
            elif which == "index":
                value = ev.index.value
                lastindex = int(value) if value >= 0 else None
                cb = _onIndexReached  # read once: terminate()/initialize() may swap it
                if cb:
                    try:
                        cb(lastindex)
                    except Exception:
                        logging.debug("ECI index callback failed", exc_info=True)
            elif which == "stopped":
                try:
                    _player.stop()
                except Exception:
                    pass
            elif which == "vparams":
                for k, v in ev.vparams.data.items():
                    try:
                        vparams[int(k)] = int(v)
                    except Exception:
                        pass
    except Exception:
        # Stream errors are expected e.g. when the server is shut down; NVDA
        # may reconnect. Keep it quiet, but leave a trace for debugging.
        logging.debug("ECI TTS stream ended", exc_info=True)

def _find_eci_dll(user_hint: Optional[str]) -> str:
    if user_hint and os.path.isabs(user_hint) and os.path.exists(user_hint): return user_hint
    env = os.environ.get("ECI_DLL_PATH")
    if env and os.path.exists(env): return env
    here = os.path.dirname(__file__)
    for p in (os.path.join(here,"eloquence","eci.dll"), os.path.join(here,"eci.dll")):
        if os.path.exists(p): return p
    return user_hint or "eloquence\\eci.dll"

# ---- Public API expected by the driver ----

def eciCheck() -> bool:
    """Report whether eci.dll can be found, so the driver can advertise availability.

    Side effect: ``eciPath`` is set to the absolute form of the located (or
    default) path.
    """
    global eciPath
    resolved = os.path.abspath(_find_eci_dll(None))
    eciPath = resolved
    return os.path.exists(resolved)

def initialize(indexCallback=None, host="127.0.0.1:18951", eci_path=None, connect_timeout=5.0):
    """Connect to the server, initialize ECI, and start the stream thread.

    Launches ``eloquence/eci_grpc_server.exe`` (next to this module) unless a
    previously launched instance is still running, waits until the gRPC
    channel to ``host`` is ready, sends ``Init`` with the resolved DLL path and
    the configured language, primes the ``vparams`` cache and starts the
    background stream worker. When already connected, only the index callback
    is replaced.

    :param indexCallback: called from the worker thread with each reached index
        (None for negative indexes).
    :param host: ``host:port`` the server listens on.
    :param eci_path: optional location hint for eci.dll (see _find_eci_dll).
    :param connect_timeout: seconds to wait for the server to accept connections.
    :raises RuntimeError: if the server cannot be launched, is not reachable
        within ``connect_timeout``, or reports that ECI init failed.
    """
    global _stub, _stream_thr, _connected, _onIndexReached, eciPath, _server_proc
    _onIndexReached = indexCallback
    if _connected:
        return

    dll_path = os.path.abspath(_find_eci_dll(eci_path))
    eciPath = dll_path

    # Launch the 32-bit server if not already running
    if _server_proc is None or _server_proc.poll() is not None:
        exe_path = os.path.join(os.path.dirname(__file__), "eloquence", "eci_grpc_server.exe")
        try:
            # STARTUPINFO hides the console window
            si = subprocess.STARTUPINFO()
            si.dwFlags |= subprocess.STARTF_USESHOWWINDOW
            _server_proc = subprocess.Popen(
                [exe_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                startupinfo=si,
                creationflags=subprocess.CREATE_NO_WINDOW,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to start ECI gRPC server: {e}") from e
        time.sleep(0.3)  # short grace period so the first connect attempt usually succeeds

    ch = grpc.insecure_channel(host, options=[("grpc.max_receive_message_length", 20*1024*1024)])
    try:
        # A fixed sleep alone races the server binding its port (slow start,
        # AV scanning the exe, ...); Init would then fail with UNAVAILABLE.
        grpc.channel_ready_future(ch).result(timeout=connect_timeout)
    except grpc.FutureTimeoutError as e:
        ch.close()
        raise RuntimeError(
            f"ECI gRPC server at {host} not reachable within {connect_timeout}s"
        ) from e
    _stub = ECIStub(ch)

    # Language: the configured eci voice code if it is known, else American English.
    desired_lang = 65536
    try:
        voice = config.conf.get('speech', {}).get('eci', {}).get('voice', '')
        if voice and voice in langs:
            desired_lang = langs[voice][0]
    except Exception:
        pass

    rep = _stub.Init(InitRequest(eci_path=dll_path, voice_lang_id=int(desired_lang)))
    if not rep.ok:
        raise RuntimeError("ECI init failed: "+rep.message)

    # Prime vparams cache (driver reads these during init)
    for p in (1, 2, 3, 4, 5, 6, 7, 9):
        try:
            vparams[p] = int(_stub.GetVParam(GetVParamRequest(param=p)).value)
        except Exception:
            if p == 9:
                # The language is known locally even if the server can't report it.
                vparams[9] = int(desired_lang)
                params[9] = int(desired_lang)

    _stream_thr = threading.Thread(target=_stream_worker, args=(_stub,), daemon=True)
    _stream_thr.start()
    _connected = True
def speak(text: str):
    """Queue ``text`` for synthesis on the TTS request stream.

    Non-str input is converted with ``str()``. Text is forwarded verbatim
    (back-quote annotations included), except for sequences matching ``pat``
    below: a back-quote ``p`` tag with an integer value followed by a single
    letter (e.g. `p130A). For each such sequence the tag is dropped, the
    pitch voice parameter is temporarily set to the tag's value (an absolute
    value, not an offset), the letter is queued on its own, and pitch is then
    restored to the baseline read once via getVParam(pitch). Both pitch
    changes use temporary=True, so the local ``vparams`` cache is untouched.

    NOTE(review): setVParam() is a synchronous unary RPC, whereas the letter
    goes through the asynchronous ``_stream_q`` request stream. The bump and
    the restore are therefore both issued without being sequenced against the
    server receiving the letter, so the pitch change is not guaranteed to
    apply to it (it may be lost or land on other text). Sequencing it in-band
    would need protocol/server support — TODO confirm against the server.
    """
    if not isinstance(text, str):
        text = str(text)

    # Fast path: no back-quote "p" tag anywhere, so forward the text verbatim.
    if "`p" not in text:
        _stream_q.put(TtsRequest(add_text=AddText(text=text)))
        return

    import re
    # Match back-quote pitch followed by a single capital letter (A–Z plus common accented range)
    # Examples: `p130A, `p90É, `p70Z
    # NOTE(review): U+00C0–U+017F also contains lowercase letters (à–ÿ and the
    # lowercase half of Latin Extended-A) plus × and ÷, so this is not limited
    # to capitals. Whitespace between the number and the letter (\s*) is part
    # of the match and is therefore dropped from the output.
    pat = re.compile(r"`p(-?\d+)\s*([A-Z\u00C0-\u017F])")

    # We’ll stream in small chunks: any leading plain text, then scoped-cap char, then resume.
    pos = 0
    baseline_pitch = None

    for m in pat.finditer(text):
        # Emit any plain text before this match verbatim (keep all tags intact)
        if m.start() > pos:
            _stream_q.put(TtsRequest(add_text=AddText(text=text[pos:m.start()])))

        # Parse desired pitch and the capital char
        try:
            bump = int(m.group(1))
        except Exception:
            bump = None
        cap_ch = m.group(2)

        # Snapshot baseline lazily (one RPC per call, only when a tag is present)
        if baseline_pitch is None:
            baseline_pitch = getVParam(pitch)

        # Temporarily set pitch to the tag's value for this character
        if bump is not None:
            try:
                setVParam(pitch, bump, temporary=True)
            except Exception:
                pass

        # Send just the capital character (we intentionally do NOT send the `pNN tag)
        _stream_q.put(TtsRequest(add_text=AddText(text=cap_ch)))

        # Restore pitch so it doesn't leak (see the ordering NOTE in the docstring)
        if baseline_pitch is not None:
            try:
                setVParam(pitch, int(baseline_pitch), temporary=True)
            except Exception:
                pass

        # Advance past the matched sequence
        pos = m.end()

    # Emit any trailing text after the last match verbatim
    if pos < len(text):
        _stream_q.put(TtsRequest(add_text=AddText(text=text[pos:])))

def index(x: int):
    """Queue an InsertIndex request carrying ``int(x)`` on the TTS request stream."""
    request = TtsRequest(insert_index=InsertIndex(index=int(x)))
    _stream_q.put(request)

def synth():
    """Queue a Synthesize request on the TTS stream, after anything queued before it."""
    request = TtsRequest(synth=Synthesize())
    _stream_q.put(request)

def stop():
    """Silence speech as quickly as possible (best effort; errors are ignored).

    1. Discards TTS requests still waiting in the local ``_stream_q`` (text,
       index markers and synth commands queued before this call), so they are
       not streamed to the server after the stop. A pending shutdown sentinel
       (None, queued by terminate()) is put back so the stream still ends.
    2. Sends ``Stop`` to the server when connected.
    3. Stops the local player immediately instead of waiting for the server's
       "stopped" event to travel back through the stream.
    """
    shutdown_pending = False
    while True:
        try:
            pending = _stream_q.get_nowait()
        except queue.Empty:
            break
        if pending is None:
            shutdown_pending = True
    if shutdown_pending:
        _stream_q.put(None)

    if _stub:
        try:
            _stub.Stop(Empty())
        except Exception:
            pass

    player = _player
    if player:
        try:
            player.stop()
        except Exception:
            pass

def pause(switch: bool):
    """Pause (truthy ``switch``) or resume local playback; errors are ignored."""
    player = _player
    if not player:
        return
    try:
        player.pause(bool(switch))
    except Exception:
        pass

def terminate():
    """Shut everything down so that a later initialize() starts from scratch.

    Ends the request stream (None sentinel), closes the audio player, stops
    the server process (escalating to kill() if it does not exit within 2 s),
    and resets the connection state. Without that reset, ``_connected`` would
    stay True and initialize() would return early after a terminate(). The
    index callback is dropped so late index events from the dying stream are
    not delivered to a terminated driver. All steps are best effort.
    """
    global _server_proc, _player, _stub, _stream_thr, _connected, _onIndexReached
    _onIndexReached = None
    try:
        _stream_q.put(None)
    except Exception:
        pass

    player, _player = _player, None
    if player is not None:
        try:
            player.close()
        except Exception:
            pass

    proc, _server_proc = _server_proc, None
    if proc is not None and proc.poll() is None:
        try:
            proc.send_signal(signal.SIGTERM)
            proc.wait(timeout=2)
        except Exception:
            # Did not exit in time (or signalling failed): force it.
            try:
                proc.kill()
            except Exception:
                pass

    _stub = None
    _stream_thr = None
    _connected = False

def set_voice(vl):
    """Select the ECI language.

    ``vl`` is either an integer language id (anything ``int()`` accepts) or a
    key of ``langs`` such as ``"enu"`` (surrounding whitespace and case are
    ignored). Unrecognised values are ignored. When connected, the id is sent
    as global parameter 9 (RPC errors ignored); it is then cached in both
    ``params[9]`` and ``vparams[9]``.
    """
    try:
        lang_id = int(vl)
    except Exception:
        try:
            lang_id = int(langs[str(vl).strip().lower()][0])
        except Exception:
            return
    if _stub:
        try:
            _stub.SetParam(SetParamRequest(param=9, value=lang_id))
        except Exception:
            pass
    params[9] = lang_id
    vparams[9] = lang_id

def getVParam(pr) -> int:
    """Return voice parameter ``pr``.

    When connected, the value is fetched from the server and written to the
    ``vparams`` cache first. On RPC failure, or when not connected, the cached
    value (0 if never seen) is returned.
    """
    key = int(pr)
    stub = _stub
    if stub:
        try:
            vparams[key] = int(stub.GetVParam(GetVParamRequest(param=key)).value)
        except Exception:
            pass
    return int(vparams.get(key, 0))

def setVParam(pr, vl, temporary: bool = False):
    """Set voice parameter ``pr`` to ``int(vl)``.

    Values rejected by ``int()`` are ignored. When connected, the value is
    sent along with the ``temporary`` flag (RPC errors ignored). Only
    non-temporary changes are recorded in the local ``vparams`` cache.
    """
    param_id = int(pr)
    try:
        value = int(vl)
    except Exception:
        return
    if _stub:
        try:
            _stub.SetVParam(
                SetVParamRequest(param=param_id, value=value, temporary=bool(temporary))
            )
        except Exception:
            pass
    if temporary:
        return
    vparams[param_id] = value

def setVariant(v):
    """Send a CopyVoice request for variant ``int(v)``.

    Presumably this selects a preset voice on the server (confirm).
    Non-integer input is ignored, as are RPC errors. Does nothing when not
    connected.
    """
    try:
        variant = int(v)
    except Exception:
        return
    if not _stub:
        return
    try:
        _stub.CopyVoice(CopyVoiceRequest(variant=variant))
    except Exception:
        pass

def cmdProsody(pr, multiplier):
    """Temporarily scale voice parameter ``pr`` by ``multiplier``.

    The current value comes from getVParam(). A falsy multiplier, or a product
    that cannot be computed, re-applies the current value unchanged. The
    change is sent with temporary=True.
    """
    current = getVParam(pr)
    scaled = current
    if multiplier:
        try:
            scaled = int(current * multiplier)
        except Exception:
            pass
    setVParam(pr, scaled, temporary=True)

def process():
    """Execute every batch of operations currently waiting in ``synth_queue``.

    A batch is a list/tuple of ``(callable, args)`` pairs; anything else is
    skipped. ``args`` may be a tuple (unpacked), None (no arguments) or a
    single value (passed as the only argument). Non-callable entries are
    skipped; failing entries are logged and processing continues.

    Returns True if at least one batch was taken from the queue.
    """
    ran_any = False
    while True:
        try:
            batch = synth_queue.get_nowait()
        except queue.Empty:
            return ran_any
        ran_any = True
        if not isinstance(batch, (list, tuple)):
            continue
        for op in batch:
            try:
                func, payload = op
                if not callable(func):
                    continue
                if isinstance(payload, tuple):
                    func(*payload)
                elif payload is None:
                    func()
                else:
                    func(payload)
            except Exception:
                logging.exception("synth op failed: %r", op)
