#!/usr/bin/env python3
# /var/www/html/facial_api/build_encodings.py
"""
Pipeline de encodings faciais (Baralho do Crime -> MySQL face_encodings)
- Lê registros de baralho_crime com foto_path válido
- Gera encodings (128-d) com face_recognition (HOG por padrão; opcional CNN)
- Persistência idempotente: UNIQUE(pessoa_id, source_image_path, face_index)
- Compacta encodings (npy -> zlib) para ganho de espaço

Execução:
  python3 build_encodings.py              # HOG
  python3 build_encodings.py --cnn        # CNN (requer dlib com CUDA ou bem otimizado)
  python3 build_encodings.py --limit 50   # processa N pendentes
  python3 build_encodings.py --rebuild    # refaz encodings de todos
"""

import os, sys, zlib, io, argparse, logging
import numpy as np
from PIL import Image
import face_recognition as fr
import mysql.connector
from mysql.connector import errorcode

# ---------- CONFIG ----------
DB_CONFIG = {
    "user": "root",
    "password": "152535ff",
    "host": "127.0.0.1",
    "database": "facial",
    "autocommit": True,
}
IMG_BASE = "/var/www/html/facial"  # base padrão das fotos
LOG_PATH = "/var/log/face_encoding.log"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(LOG_PATH), logging.StreamHandler()]
)

# ---------- DB Helpers ----------
def db():
    return mysql.connector.connect(**DB_CONFIG)

def fetch_targets(conn, limit=None, rebuild=False):
    """
    Se rebuild=False: pega apenas quem não tem encoding ainda (join anti).
    Se rebuild=True : pega todos com foto_path válido.
    """
    cur = conn.cursor(dictionary=True)
    if rebuild:
        sql = """
          SELECT b.id AS pessoa_id, b.foto_path
          FROM baralho_crime b
          WHERE b.foto_path IS NOT NULL AND b.foto_path <> ''
        """
        if limit:
            sql += " LIMIT %s"
            cur.execute(sql, (limit,))
        else:
            cur.execute(sql)
    else:
        sql = """
          SELECT b.id AS pessoa_id, b.foto_path
          FROM baralho_crime b
          LEFT JOIN face_encodings e
            ON e.pessoa_id = b.id AND e.source_image_path = b.foto_path
          WHERE b.foto_path IS NOT NULL AND b.foto_path <> ''
            AND e.id IS NULL
        """
        if limit:
            sql += " LIMIT %s"
            cur.execute(sql, (limit,))
        else:
            cur.execute(sql)
    rows = cur.fetchall()
    cur.close()
    return rows

def upsert_encoding(conn, pessoa_id, source_image_path, face_index, enc_vec, method, version="face_recognition_1.3"):
    """
    Salva encoding np.float64 (128) como .npy comprimido (zlib) em MEDIUMBLOB.
    """
    # Serializa .npy em memória
    bio = io.BytesIO()
    np.save(bio, enc_vec.astype(np.float32))  # float32 reduz 50% de espaço; suficiente
    raw = bio.getvalue()
    comp = zlib.compress(raw, level=9)

    sql = """
    INSERT INTO face_encodings
      (pessoa_id, source_image_path, face_index, encoding_dim, encoding, method, model_version, updated_at)
    VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
    ON DUPLICATE KEY UPDATE
      encoding_dim=VALUES(encoding_dim),
      encoding=VALUES(encoding),
      method=VALUES(method),
      model_version=VALUES(model_version),
      updated_at=NOW()
    """
    cur = conn.cursor()
    cur.execute(sql, (pessoa_id, source_image_path, face_index, len(enc_vec), comp, method, version))
    cur.close()

# ---------- Image / Face ----------
def load_image_any(path):
    # Robustez: abre via PIL, garante RGB
    with Image.open(path) as im:
        im = im.convert("RGB")
        return np.array(im)

def select_best_face(face_locations):
    """
    Escolhe a face com maior área (heurística simples).
    face_locations: [(top, right, bottom, left), ...]
    """
    if not face_locations:
        return None, None
    areas = []
    for i, (t, r, b, l) in enumerate(face_locations):
        areas.append(((b - t) * (r - l), i))
    areas.sort(reverse=True)
    _, idx = areas[0]
    return face_locations[idx], idx

# ---------- Main ----------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cnn", action="store_true", help="usar modelo CNN (mais pesado)")
    parser.add_argument("--limit", type=int, default=None, help="limitar N entradas")
    parser.add_argument("--rebuild", action="store_true", help="recalcular encodings de todos")
    args = parser.parse_args()

    model = "cnn" if args.cnn else "hog"
    logging.info("Modelo de detecção: %s", model)

    conn = db()
    rows = fetch_targets(conn, limit=args.limit, rebuild=args.rebuild)
    logging.info("Alvos para processamento: %d", len(rows))

    processed, skipped, failed = 0, 0, 0

    for r in rows:
        pessoa_id = int(r["pessoa_id"])
        src = r["foto_path"].strip()

        # Normaliza caminho: se vier sem base, prefixa IMG_BASE
        if not os.path.isabs(src):
            src = os.path.join(IMG_BASE, src)

        if not os.path.exists(src):
            logging.warning("Imagem não encontrada: %s (pessoa_id=%s)", src, pessoa_id)
            skipped += 1
            continue

        try:
            img = load_image_any(src)
            locs = fr.face_locations(img, model=model)
            if not locs:
                logging.warning("Nenhuma face detectada em: %s", src)
                skipped += 1
                continue

            # Seleciona melhor face e extrai encoding
            best_loc, best_idx = select_best_face(locs)
            encs = fr.face_encodings(img, known_face_locations=[best_loc])
            if not encs:
                logging.warning("Falhou em gerar encoding: %s", src)
                skipped += 1
                continue

            enc_vec = encs[0]
            upsert_encoding(conn, pessoa_id, src, best_idx, enc_vec, method=model)
            processed += 1
            logging.info("OK pessoa_id=%s face_index=%s src=%s", pessoa_id, best_idx, src)

        except Exception as e:
            logging.error("Erro pessoa_id=%s src=%s -> %s", pessoa_id, src, e)
            failed += 1

    logging.info("Resumo -> processados:%d | pulados:%d | falhas:%d", processed, skipped, failed)
    conn.close()

if __name__ == "__main__":
    main()
