"""
mini_sprite_vector_demo.py – embed ➜ store ➜ search using local BLIP embeddings

This demo mirrors *embeddings_demo.py* but persists the vectors in MongoDB Atlas
and performs similarity queries with the Atlas Vector Search operator.  Each
vector is the projected multimodal **CLS token** extracted by BLIP via LAVIS.
The accompanying text input is the image filename (minus extension).
"""

import glob, os, datetime as dt
from pathlib import Path
from typing import List

from dotenv import load_dotenv
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import pymongo
from pymongo.collection import Collection
from bson.binary import Binary, BinaryVectorDtype  # PyMongo ≥ 4.11

# Pull configuration from the environment (.env is honoured via python-dotenv).
load_dotenv()

MONGODB_URI = os.getenv("MONGODB_URI")  # mongodb+srv://…
MONGODB_DB = os.getenv("MONGODB_DB", "detection_demo")

COLL_NAME = os.getenv("REFERENCE_COLLECTION", "image_detection")
INDEX_NAME = os.getenv("VECTOR_INDEX_NAME", "image_detection_index")

# Run on the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------- 1. Transformers setup ---------------------------
# Use a widely-available base checkpoint; override via env if desired
CKPT = os.getenv("CLIP_CHECKPOINT", "openai/clip-vit-base-patch32")

processor = CLIPProcessor.from_pretrained(CKPT, use_fast=False)
# Full CLIP model: image tower + projection head (get_image_features).
vision = CLIPModel.from_pretrained(CKPT).to(device).eval()

# get_image_features() returns the *projected* embedding, whose size is
# ``projection_dim`` (512 for CLIP base) — NOT ``vision_config.hidden_size``
# (768 for ViT-B/32).  Using hidden_size here would create the Atlas index
# with the wrong dimensionality and break every $vectorSearch query.
DIMENSIONS = vision.config.projection_dim  # 512 for CLIP base


def as_bson(v) -> Binary:
    """Pack a 1-D float sequence as a BSON float32 vector (PyMongo ≥ 4.11)."""
    return Binary.from_vector(v, BinaryVectorDtype.FLOAT32)

# Connect and make sure the target collection exists before touching it.
mdb = pymongo.MongoClient(MONGODB_URI)[MONGODB_DB]

existing_collections = mdb.list_collection_names()
if COLL_NAME not in existing_collections:
    mdb.command("create", COLL_NAME)

coll: Collection = mdb[COLL_NAME]

# ─────────────────────────────── 2. Ensure vector index ──────────────────────────
# Skip creation when a search index with this name is already present.
known_indexes = {ix["name"] for ix in coll.list_search_indexes()}
if INDEX_NAME not in known_indexes:
    definition = {
        "fields": [
            {
                "path": "embedding",
                "type": "vector",
                "numDimensions": DIMENSIONS,
                "similarity": "cosine",
            }
        ]
    }
    coll.create_search_index(
        {"name": INDEX_NAME, "type": "vectorSearch", "definition": definition}
    )
    print(f"✅ Created Atlas Vector Search index '{INDEX_NAME}'")


def ingest_dir(folder: str, label: str = "positive") -> None:
    """Embed every new .png/.jpg/.jpeg image in *folder* and store it in Mongo.

    Images whose filename already exists in the collection are skipped, so the
    function is safe to re-run.  Each stored document carries the filename,
    absolute path, BSON float32 embedding, a UTC timestamp and *label*.
    """
    # The original charclass glob ``*.[pj][np]g`` also matched the unintended
    # ``.jng`` / ``.ppg`` and missed ``.jpeg`` — enumerate extensions instead.
    patterns = ("*.png", "*.jpg", "*.jpeg")
    paths = sorted(
        p for pattern in patterns for p in glob.glob(os.path.join(folder, pattern))
    )

    # Skip files already in the collection
    seen = {
        d["filename"]
        for d in coll.find(
            {"filename": {"$in": [Path(p).name for p in paths]}}, {"filename": 1}
        )
    }
    todo = [p for p in paths if Path(p).name not in seen]
    if not todo:
        return

    print(f"Embedding {len(todo)} new images from '{folder}' …")

    docs = []
    for path in todo:
        img = Image.open(path).convert("RGB")

        # Processor handles resize / normalise
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            vec = vision.get_image_features(**inputs)
            # L2-normalise so cosine similarity in Atlas behaves as expected
            vec = torch.nn.functional.normalize(vec, dim=-1)[0]

        fname = Path(path).name
        docs.append(
            {
                "filename": fname,
                "filepath": str(Path(path).resolve()),
                # Binary.from_vector wants a plain float sequence, not a tensor
                "embedding": as_bson(vec.cpu().tolist()),
                # utcnow() is deprecated and naive; store an aware UTC timestamp
                "created_at": dt.datetime.now(dt.timezone.utc),
                "label": label,
            }
        )

    coll.insert_many(docs)
    print(f"✅ Ingested {len(docs)} images.")


def find_similar(
    img_path: str, k: int = 3, comment: str = "", min_score: float = 0.82
):
    """Embed *img_path* and print up to ``k`` nearest stored images.

    Only hits with a vectorSearchScore above *min_score* are printed; lower
    the threshold to see weaker matches.  *comment* is echoed in the header.
    """
    img = Image.open(img_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        vec = vision.get_image_features(**inputs)
        vec = torch.nn.functional.normalize(vec, dim=-1)[0]

    hits = coll.aggregate(
        [
            {
                "$vectorSearch": {
                    "index": INDEX_NAME,
                    "path": "embedding",
                    # Binary.from_vector wants a plain float sequence
                    "queryVector": as_bson(vec.cpu().tolist()),
                    # Atlas requires numCandidates >= limit; the old hard-coded
                    # 5 broke any k > 5.  Over-fetch for better ANN recall.
                    "numCandidates": max(10 * k, 50),
                    "limit": k,
                }
            },
            {
                "$project": {
                    "_id": 0,
                    "filename": 1,
                    "label": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
    )

    print(f"\n🔍  Top-{k} for {Path(img_path).name} - {comment}")
    for h in hits:
        if h["score"] > min_score:
            print(f"{h['score']:.3f}  {h['label']:<8}  {h['filename']}")


if __name__ == "__main__":
    ingest_dir("../golden_set", label="marijuana")
    ingest_dir("../negative_set", label="plant/flower")

    find_similar(
        "../img/mimosa.png", comment="This is a strain of weed with purple tones"
    )
    find_similar("../img/marigold.png", comment="Marigolds are yellow flowers")

    find_similar("../img/dog1.png", comment="This is a dog")
