# -*- coding: utf-8 -*-
"""
gRPC-enabled replacement for the classic ``_eloquence.py`` shim.

This module preserves the public surface that the provided NVDA Eloquence
driver expects, but routes all ECI calls to a 32-bit gRPC server that loads
ECI.DLL. Audio is streamed back as 16-bit mono 11025 Hz PCM and played via
nvwave.WavePlayer on the client (64-bit) side.

Expected by the driver (see eloquence.py):
- constants: hsz, pitch, fluctuation, rgh, bth, rate, vlm
- dicts/vars: langs, params, vparams, eciPath, lastindex
- functions: initialize, eciCheck, speak, index, synth, stop, pause, terminate,
             set_voice, getVParam, setVParam, setVariant, cmdProsody, process
- queue: synth_queue (driver enqueues lists of (callable, args_tuple))

Notes:
- Requires Python stubs generated from the same eci.proto: eci_pb2.py, eci_pb2_grpc.py
- Default server address: 127.0.0.1:18951

The environment variable ECI_DLL_PATH overrides the built-in DLL search
locations (an explicit absolute path passed to initialize() still wins).
"""

# A ``from __future__`` import must precede every other statement; only the
# module docstring and comments may come before it. Otherwise it is a SyntaxError.
from __future__ import annotations

import logging
import os
import queue
import sys
import threading
import time
from typing import Any, Optional, Tuple

# Add the bundled "eloquence" directory to the import path before the gRPC
# stubs are imported below.
# NOTE(review): presumably this is where eci_pb2*.py live — confirm.
sys.path.append(os.path.join(os.path.dirname(__file__), "eloquence"))

# NVDA modules
import nvwave
import config

# --- gRPC stubs ---
try:
    import grpc
    from eci_pb2 import (
        InitRequest, TtsRequest, AddText, InsertIndex, Synthesize,
        SetParamRequest, SetVParamRequest, GetVParamRequest, CopyVoiceRequest, Empty,
    )
    from eci_pb2_grpc import ECIStub
except Exception:
    # Defer the failure until initialize() so NVDA can still import this module
    # (and list the synth driver) when grpcio or the generated stubs are missing.
    grpc = None
    ECIStub = None

# ---- Public constants the driver imports ----
# Voice-parameter identifiers the driver passes to getVParam()/setVParam()/cmdProsody().
# NOTE(review): the names presumably abbreviate ECI voice parameters (head size,
# pitch baseline, pitch fluctuation, roughness, breathiness, speed, volume) — confirm.
hsz, pitch, fluctuation, rgh, bth, rate, vlm = 1, 2, 3, 4, 5, 6, 7

# ---- Tolerant language map the driver consults when enumerating voices ----
class _Langs(dict):
    """Map ECI language codes to ``(language_id, display_name)`` tuples, tolerantly.

    Lookup rules for ``langs[key]`` (case-insensitive, surrounding whitespace ignored):

    1. an exact code in ``_base`` (e.g. ``'enu'``, ``'ESM'``);
    2. otherwise the first ``_by_prefix`` entry the key starts with
       (e.g. ``'fr-CA'`` -> French, ``'zh_cn'`` -> American English);
    3. otherwise, including ``None`` and ``''``, American English.

    Non-string keys other than ``None`` raise ``KeyError``, so ``get()`` returns
    its default for them.

    The instance is pre-populated with ``_base``. Without that, ``in``, ``len()``
    and iteration (``keys()``/``values()``/``items()``) would see an empty dict,
    because only ``__getitem__``/``get`` are overridden.
    """

    _base = {
        'esm': (131073, 'Latin American Spanish'),
        'esp': (131072, 'Castilian Spanish'),
        'ptb': (458752, 'Brazilian Portuguese'),
        'frc': (196609, 'French Canadian'),
        'fra': (196608, 'French'),
        'fin': (589824, 'Finnish'),
        'deu': (262144, 'German'),
        'ita': (327680, 'Italian'),
        'enu': (65536,  'American English'),
        'eng': (65537,  'British English'),
    }
    _by_prefix = [
        ('en',  (65536, 'American English')),
        ('fr',  (196608, 'French')),
        ('de',  (262144, 'German')),
        ('it',  (327680, 'Italian')),
        ('es',  (131072, 'Castilian Spanish')),
        ('pt',  (458752, 'Brazilian Portuguese')),
        ('fi',  (589824, 'Finnish')),
        # Fallbacks for unknown/non-Eloquence codes
        ('zh',  (65536, 'American English')),
        ('chs', (65536, 'American English')),
        ('cht', (65536, 'American English')),
        ('ja',  (65536, 'American English')),
        ('ko',  (65536, 'American English')),
    ]
    _fallback = (65536, 'American English')

    def __init__(self):
        # Expose the real table to dict operations that bypass __getitem__.
        super().__init__(self._base)

    def __getitem__(self, key):
        if key is not None and not isinstance(key, str):
            raise KeyError(key)
        k = (key or '').strip().lower()
        if k in self._base:
            return self._base[k]
        for pref, val in self._by_prefix:
            if k.startswith(pref):
                return val
        return self._fallback

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default


langs = _Langs()

# ---- Public module variables expected by the driver ----
class _SafeDict(dict):
    """dict whose ``d[key]`` yields 0 for absent keys (without inserting them)."""

    def __missing__(self, key):
        # Called by dict.__getitem__ only when *key* is absent.
        return 0


# Cached voice parameters, keyed by parameter id.
vparams: dict[int, int] = _SafeDict()
# Cached engine parameters; 9 starts as 65536 (ENU).
params: dict[int, int] = _SafeDict({9: 65536})
# Relative default; initialize() overwrites it with the resolved DLL path.
eciPath: str = "eloquence\\eci.dll"
# Most recent index reported by the server (None for negative values).
lastindex: Optional[int] = 0

# Public synth queue: the driver puts lists of (callable, args_tuple) here and
# process() drains it.
synth_queue: "queue.Queue[list[Tuple[Any, Tuple[Any, ...]]]]" = queue.Queue()

# ---- Audio / RPC state ----
# PCM format of the audio chunks the server streams back.
SAMPLE_RATE, CHANNELS, BITS = 11025, 1, 16

# Mutable connection/playback state (set by initialize(), _ensure_player(), terminate()).
_onIndexReached = None                           # index callback given to initialize()
_connected = False
_stub: Optional[ECIStub] = None
_player: Optional[nvwave.WavePlayer] = None      # created lazily by _ensure_player()
_stream_thr: Optional[threading.Thread] = None   # thread running _stream_worker()

# Outbound requests for the bidirectional Tts stream (not used by the driver
# directly); putting None ends the request iterator.
_stream_q: "queue.Queue[TtsRequest]" = queue.Queue()

def _ensure_player():
    """Create the module-wide nvwave.WavePlayer on first use; no-op afterwards.

    First tries the ``config.conf["audio"]`` layout (output device plus ducking
    flag). If anything in that attempt raises, it falls back to
    ``config.conf["speech"]["outputDevice"]`` with a buffered player.
    """
    global _player
    if _player:
        return
    try:
        audio_cfg = config.conf["audio"]
        _player = nvwave.WavePlayer(
            CHANNELS, SAMPLE_RATE, BITS,
            outputDevice=audio_cfg["outputDevice"],
            wantDucking=bool(audio_cfg.get("audioDuckingMode")),
        )
    except Exception:
        # Presumably an older NVDA configuration/API layout.
        legacy_device = config.conf.get("speech", {}).get("outputDevice")
        # NOTE(review): this changes MIN_BUFFER_MS on the WavePlayer *class*, i.e.
        # for every player in the process, not only this one — confirm intended.
        nvwave.WavePlayer.MIN_BUFFER_MS = 1500
        _player = nvwave.WavePlayer(
            CHANNELS, SAMPLE_RATE, BITS,
            outputDevice=legacy_device,
            buffered=True,
        )

def _stream_iter():
    """Yield queued TtsRequest messages for the Tts stream until a None sentinel arrives."""
    request = _stream_q.get()
    while request is not None:
        yield request
        request = _stream_q.get()

def _stream_worker(stub: ECIStub):
    """Drive the bidirectional ``Tts`` stream until it ends.

    Runs on the daemon thread started by ``initialize``. Requests come from
    ``_stream_iter()``. Each server event is dispatched on its ``event`` oneof:

    * ``audio``   - PCM bytes fed to the shared WavePlayer. A chunk that fails to
                    play is dropped after idling the player briefly.
    * ``index``   - updates ``lastindex`` (negative values become None) and calls
                    the index callback.
    * ``stopped`` - stops local playback.
    * ``vparams`` - merges the reported values into the ``vparams`` cache.

    Any exception that ends the stream, including failure to create the player,
    is logged rather than silently discarded.
    """
    global lastindex
    try:
        # Inside the try: a player-creation failure must not kill the thread unlogged.
        _ensure_player()
        for ev in stub.Tts(_stream_iter()):
            which = ev.WhichOneof("event")
            if which == "audio":
                player = _player
                if player is None:
                    # No player available; nothing to play this chunk into.
                    continue
                try:
                    player.feed(bytes(ev.audio.pcm))
                except Exception:
                    logging.debug("ECI: dropping audio chunk that failed to play", exc_info=True)
                    # idle() can fail too; don't let it tear down the stream.
                    try:
                        player.idle()
                    except Exception:
                        pass
                    time.sleep(0.02)
            elif which == "index":
                lastindex = int(ev.index.value) if ev.index.value >= 0 else None
                cb = _onIndexReached
                if cb:
                    try:
                        cb(lastindex)
                    except Exception:
                        logging.exception("ECI: index callback failed for index %r", lastindex)
            elif which == "stopped":
                player = _player
                if player is not None:
                    try:
                        player.stop()
                    except Exception:
                        pass
            elif which == "vparams":
                for k, v in ev.vparams.data.items():
                    try:
                        vparams[int(k)] = int(v)
                    except Exception:
                        pass
    except Exception:
        # Leave playback quiet, but record why speech stopped working.
        logging.exception("ECI: Tts stream ended with an error")

def _find_eci_dll(user_hint: Optional[str]) -> str:
    # 1) explicit hint
    if user_hint and os.path.isabs(user_hint) and os.path.exists(user_hint):
        return user_hint
    # 2) environment override
    env = os.environ.get("ECI_DLL_PATH")
    if env and os.path.exists(env):
        return env
    # 3) typical NVDA addon layouts relative to this file
    here = os.path.dirname(__file__)
    candidates = [
        os.path.join(here, "eloquence", "eci.dll"),
        os.path.join(here, "eci.dll"),
        # common installs
        r"C:\Program Files (x86)\Eloquence\eloquence\eci.dll",
        r"C:\Eloquence\eloquence\eci.dll",
    ]
    for p in candidates:
        if os.path.exists(p):
            return p
    # fall back to whatever was passed; server will report a clear error
    return user_hint or "eloquence\\eci.dll"

def _coerce_voice_id(vl) -> Optional[int]:
    if vl is None:
        return None
    if isinstance(vl, (tuple, list)) and vl:
        try: return int(vl[0])
        except Exception: pass
    try:
        return int(vl)
    except Exception:
        pass
    if isinstance(vl, str):
        key = vl.strip().lower()
        if key in langs._base:
            return int(langs._base[key][0])
    return None

# ---- Public API ----

def eciCheck() -> bool:
    """Report whether an ECI.DLL can be found on disk, so NVDA can enable the driver.

    This is deliberately only a file-existence check. No connection to the gRPC
    server is attempted, because NVDA calls this a lot.
    """
    candidate = _find_eci_dll(None)
    return os.path.exists(candidate)

def initialize(indexCallback=None, host: str = "127.0.0.1:18951", eci_path: Optional[str] = None):
    """Connect to the server, initialize ECI, and start the streaming thread.

    If already connected, only the index callback is replaced.

    Args:
        indexCallback: Called from the stream thread with each index the server
            reports (None for negative values). May be None.
        host: ``host:port`` of the gRPC server.
        eci_path: Optional path to ECI.DLL, resolved by ``_find_eci_dll``.

    Raises:
        RuntimeError: The gRPC stubs could not be imported, or the server
            reported that ECI initialization failed.
        Exception: Errors from the ``Init`` RPC (e.g. the server is unreachable)
            propagate after the channel has been closed.
    """
    global _stub, _stream_thr, _onIndexReached, _connected, eciPath

    if _connected:
        _onIndexReached = indexCallback
        return

    if grpc is None or ECIStub is None:
        raise RuntimeError("gRPC stubs not available. Ensure grpcio/grpcio-tools are installed and eci_pb2*.py are present.")

    _onIndexReached = indexCallback

    # Resolve DLL path & remember it for the driver (used to enumerate *.syn alongside the DLL)
    dll_path = _find_eci_dll(eci_path)
    eciPath = dll_path

    # Pick the configured voice if it maps to a known language id; otherwise ENU.
    # The saved voice may be a language code ('enu') or a numeric id string
    # ('65536'), and _coerce_voice_id accepts both. Unknown ids are never sent to Init.
    # NOTE(review): confirm which format the driver stores in speech.eci.voice.
    desired_lang = None
    try:
        eci_cfg = config.conf.get('speech', {}).get('eci', {})
        vid = _coerce_voice_id(eci_cfg.get('voice', ''))
        known_ids = {lang_id for lang_id, _name in langs._base.values()}
        if vid in known_ids:
            desired_lang = vid
    except Exception:
        desired_lang = None
    if desired_lang is None:
        desired_lang = 65536  # ENU

    channel = grpc.insecure_channel(host, options=[("grpc.max_receive_message_length", 20 * 1024 * 1024)])
    stub = ECIStub(channel)
    try:
        rep = stub.Init(InitRequest(eci_path=dll_path, voice_lang_id=int(desired_lang)))
        if not rep.ok:
            raise RuntimeError(f"ECI init failed: {rep.message}")
    except Exception:
        # Don't leak the channel, or leave _stub pointing at it, on a failed init.
        _stub = None
        channel.close()
        raise
    _stub = stub

    # The engine now runs desired_lang; keep the driver-visible param in sync.
    params[9] = int(desired_lang)

    # Cache vparams that the driver reads during init
    for p in (1, 2, 3, 4, 5, 6, 7, 9):
        try:
            vparams[p] = int(stub.GetVParam(GetVParamRequest(param=p)).value)
        except Exception:
            if p == 9:
                vparams[9] = int(desired_lang)

    # Discard requests left over from a previous session. A stale None sentinel
    # would end the new stream immediately; stale text would be spoken.
    while True:
        try:
            _stream_q.get_nowait()
        except queue.Empty:
            break

    # Start the bidi stream thread
    _stream_thr = threading.Thread(target=_stream_worker, args=(stub,), daemon=True)
    _stream_thr.start()
    _connected = True

def speak(text: str):
    """Queue an AddText request carrying *text* (str()-converted if needed) on the Tts stream."""
    payload = text if isinstance(text, str) else str(text)
    _stream_q.put(TtsRequest(add_text=AddText(text=payload)))

def index(x: int):
    """Queue an InsertIndex request for marker *x* on the Tts stream."""
    request = TtsRequest(insert_index=InsertIndex(index=int(x)))
    _stream_q.put(request)

def synth():
    """Queue a Synthesize request on the Tts stream."""
    request = TtsRequest(synth=Synthesize())
    _stream_q.put(request)

def stop():
    """Cancel pending and in-progress speech.

    Three steps, all best-effort:
    1. Discard requests still waiting in the outbound stream queue; otherwise
       they would be sent, and spoken, after the stop. A None end-of-stream
       sentinel is preserved.
    2. Ask the server to stop.
    3. Stop local playback immediately, rather than waiting for the server's
       "stopped" event to reach the stream worker.
    """
    saw_sentinel = False
    while True:
        try:
            pending = _stream_q.get_nowait()
        except queue.Empty:
            break
        if pending is None:
            saw_sentinel = True
    if saw_sentinel:
        _stream_q.put(None)

    if _stub:
        try:
            _stub.Stop(Empty())
        except Exception:
            logging.debug("ECI: Stop RPC failed", exc_info=True)

    player = _player
    if player:
        try:
            player.stop()
        except Exception:
            pass

def pause(switch: bool):
    """Forward *switch* (as a bool) to the local player's pause(); no-op without a player.

    Playback happens client-side, so no RPC is involved. Errors are ignored.
    """
    player = _player
    if not player:
        return
    try:
        player.pause(bool(switch))
    except Exception:
        pass

def terminate():
    """End the Tts stream and release client-side resources.

    Puts the None sentinel that ends ``_stream_iter()`` (the request side of the
    stream) and closes the player. It then clears module state, so a later
    initialize() starts over with a fresh stub and player instead of reusing
    the closed player.
    """
    global _connected, _player, _stub, _stream_thr
    _stream_q.put(None)  # end the request generator

    player, _player = _player, None
    if player:
        try:
            player.close()
        except Exception:
            logging.debug("ECI: error closing player", exc_info=True)

    _stub = None
    _stream_thr = None
    _connected = False

def set_voice(vl):
    """Switch parameter 9 (the language/voice id, ENU by default) to the voice *vl*.

    *vl* may be anything ``_coerce_voice_id`` understands; unrecognised values
    are ignored. The server call is best-effort: the local ``params`` and
    ``vparams`` caches are updated even if it fails.
    """
    voice_id = _coerce_voice_id(vl)
    if voice_id is None:
        return
    if _stub:
        try:
            _stub.SetParam(SetParamRequest(param=9, value=voice_id))
        except Exception:
            pass
    params[9] = vparams[9] = voice_id

def getVParam(pr) -> int:
    """Return voice parameter *pr*, refreshing the local cache from the server when connected.

    If there is no stub or the RPC fails, the cached value is returned (0 if the
    parameter was never seen).
    """
    key = int(pr)
    if _stub:
        try:
            vparams[key] = int(_stub.GetVParam(GetVParamRequest(param=key)).value)
        except Exception:
            pass
    return int(vparams.get(key, 0))

def setVParam(pr, vl, temporary: bool = False):
    """Set voice parameter *pr* to *vl* on the server (best-effort).

    Values that int() rejects are ignored entirely. Only non-temporary changes
    are recorded in the local ``vparams`` cache.
    """
    key = int(pr)
    try:
        value = int(vl)
    except Exception:
        return
    if _stub:
        try:
            _stub.SetVParam(SetVParamRequest(param=key, value=value, temporary=bool(temporary)))
        except Exception:
            pass
    if temporary:
        return
    vparams[key] = value
def setVariant(v):
    """Send a CopyVoice request for variant *v* (best-effort).

    Does nothing if *v* is not int()-convertible or there is no stub.
    NOTE(review): the variant is presumably an ECI preset voice number — confirm.
    """
    try:
        variant = int(v)
    except Exception:
        return
    if not _stub:
        return
    try:
        _stub.CopyVoice(CopyVoiceRequest(variant=variant))
    except Exception:
        pass

def cmdProsody(pr, multiplier):
    """Temporarily scale voice parameter *pr* by *multiplier*, relative to its current value.

    A falsy *multiplier* (None, 0), or one that cannot be multiplied, leaves the
    current value unchanged; that value is still re-applied as a temporary setting.
    """
    base = getVParam(pr)
    scaled = base
    if multiplier:
        try:
            scaled = int(base * multiplier)
        except Exception:
            scaled = base
    setVParam(pr, scaled, temporary=True)

# ---- Queue bridging for the NVDA driver ----

def process():
    """Run every batch of driver operations currently waiting in ``synth_queue``.

    Each queued item should be a list/tuple of ``(callable, args)`` pairs. The
    ``args`` value is handled as follows:
    * a tuple is splatted into the call;
    * ``None`` means no arguments;
    * any other value is passed as the single argument.

    Items of any other shape are discarded and non-callable entries are skipped.
    A malformed pair, or an exception raised by one operation, is logged without
    aborting the rest of the batch.

    Returns:
        True if at least one item was taken from the queue, else False.
    """
    took_any = False
    while True:
        try:
            batch = synth_queue.get_nowait()
        except queue.Empty:
            return took_any
        took_any = True
        if not isinstance(batch, (list, tuple)):
            # Unknown shape; ignore gracefully.
            continue
        for op in batch:
            try:
                func, call_args = op
                if not callable(func):
                    continue
                if isinstance(call_args, tuple):
                    func(*call_args)
                elif call_args is None:
                    func()
                else:
                    # Some drivers pass a single argument directly.
                    func(call_args)
            except Exception:
                # Keep streaming even if one op misbehaves.
                logging.exception("Error executing synth op: %r", op)
