import subprocess
import numpy as np
import cv2
from ultralytics import YOLO
from screeninfo import get_monitors
from pathlib import Path
import pygame
import math

# --- Global Variables ---
# Module-level application state shared by the functions below.
model = None  # YOLO model instance, assigned by init_model()
camera = None  # cv2.VideoCapture, assigned by prepare_camera() / cleared by release_camera()
screen = None  # pygame display surface, assigned by init_fullscreen_display()
screen_W = 0  # display width in pixels
screen_H = 0  # display height in pixels
classes_to_detect = []  # model class indices chosen by the user; empty means "all classes"

def init_model(model_name):
    """Load YOLO weights into the global ``model``.

    Args:
        model_name: weights file name, resolved relative to the directory
            that contains this script (an absolute path is used as-is, per
            ``pathlib`` joining rules).
    """
    global model
    weights_file = Path(__file__).parent.joinpath(model_name)
    print(f"Loading model from {weights_file}...")
    model = YOLO(weights_file)

def init_fullscreen_display():
    """Initialize Pygame and open a fullscreen window.

    Sets the globals ``screen`` (the display surface returned by
    ``pygame.display.set_mode``) and ``screen_W`` / ``screen_H`` (its size).
    The size is taken from the first monitor reported by screeninfo, falling
    back to 1920x1080 when no monitor can be queried. If an ``icon.png``
    exists next to this script and can be loaded, it becomes the window icon.
    """
    global screen, screen_W, screen_H
    # Local import so this block does not depend on changes to the file header.
    from screeninfo import ScreenInfoError

    pygame.init()

    # Load and set the window icon. It is purely cosmetic, so a broken file
    # only produces a warning instead of aborting startup.
    icon_path = Path(__file__).parent / "icon.png"
    if icon_path.exists():
        try:
            pygame.display.set_icon(pygame.image.load(icon_path))
        except pygame.error as err:
            print(f"Warning: Could not load window icon {icon_path}: {err}")

    # Use the first monitor screeninfo enumerates. get_monitors() raises
    # ScreenInfoError when no enumeration backend works (e.g. headless), and
    # an empty result makes [0] raise IndexError; both use the fallback size.
    try:
        monitor = get_monitors()[0]
        screen_W, screen_H = monitor.width, monitor.height
    except (IndexError, ScreenInfoError):
        print("Warning: Could not detect monitor resolution. Falling back to 1920x1080.")
        screen_W, screen_H = 1920, 1080

    screen = pygame.display.set_mode((screen_W, screen_H), pygame.FULLSCREEN)
    pygame.display.set_caption("WebCam Detection")

def release_fullscreen_display():
    """Tear down Pygame (and with it the fullscreen window).

    Counterpart of ``init_fullscreen_display``; simply delegates to
    ``pygame.quit()``.
    """
    pygame.quit()

def _print_class_menu(all_classes, num_columns=4):
    """Print ``all_classes`` as a numbered table, filled column by column."""
    # Width of the widest "index: name" entry (10 when there are no classes).
    max_len = max((len(f"{i}: {name}") for i, name in enumerate(all_classes)), default=10)
    col_width = max_len + 4  # Add padding between columns

    num_rows = math.ceil(len(all_classes) / num_columns)

    # Column-major layout: consecutive indices run down a column, then wrap.
    for row in range(num_rows):
        line_items = [
            f"{idx}: {all_classes[idx]}".ljust(col_width)
            for col in range(num_columns)
            if (idx := row + col * num_rows) < len(all_classes)
        ]
        print("".join(line_items))


def _resolve_default_classes(default_classes, all_classes):
    """Return the default class names present in ``all_classes``, warning about missing ones.

    Returns an empty list (meaning "detect all classes") when none are present.
    """
    known = set(all_classes)
    available = [name for name in default_classes if name in known]
    missing = [name for name in default_classes if name not in known]
    if missing:
        print(f"Warning: default class(es) not found in this model: {', '.join(missing)}")
    if not available:
        print("None of the default classes are available. Detecting all classes.")
        return []
    print(f"Default selection chosen: {', '.join(available)}")
    return available


def get_user_class_selection(model_names):
    """Display a terminal menu and ask the user which classes to detect.

    The menu numbers are positions in the alphabetically sorted list of class
    names, not the model's own class indices.

    Args:
        model_names: mapping whose values are the model's class names
            (``model.names``).

    Returns:
        list[str]: selected class names in the order entered, without
        duplicates. An empty list means "detect all classes" (chosen by
        pressing Enter, by stdin being closed, or when none of the default
        classes exist in the model).
    """
    all_classes = sorted(model_names.values())
    print("--- Please select classes to detect ---")
    _print_class_menu(all_classes)
    print("---------------------------------------")

    default_classes = ["Person", "Glasses", "Tie", "Monitor/TV"]
    prompt = (
        "Enter numbers (e.g., 0,15,22), "
        f"'d' for default ({', '.join(default_classes)}), "
        "or press Enter for all: "
    )

    while True:
        try:
            answer = input(prompt).strip()
        except EOFError:
            # stdin closed / non-interactive: treat like an empty answer.
            print("\nNo input available. Detecting all classes.")
            return []

        if answer.lower() == 'd':
            return _resolve_default_classes(default_classes, all_classes)

        if not answer:
            print("No selection made. Detecting all classes.")
            return []

        try:
            # Blank entries (e.g. from a trailing comma) are ignored; int()
            # itself tolerates surrounding whitespace.
            selected_indices = [int(tok) for tok in answer.split(',') if tok.strip()]
        except ValueError:
            print("Invalid input. Please enter numbers separated by commas.")
            continue

        out_of_range = [i for i in selected_indices if not 0 <= i < len(all_classes)]
        if out_of_range:
            for i in out_of_range:
                print(f"Error: Index {i} is out of range (0-{len(all_classes)-1}).")
            continue

        # dict.fromkeys drops duplicates while keeping the order typed.
        selected_names = list(dict.fromkeys(all_classes[i] for i in selected_indices))
        if not selected_names:
            print("No valid classes selected. Please try again.")
            continue

        print(f"\nYou have selected: {', '.join(selected_names)}\n")
        return selected_names

def set_selected_class_indices():
    """Ask the user for classes and store their model indices globally.

    The names returned by ``get_user_class_selection`` are mapped back to the
    keys of ``model.names`` and written to the global ``classes_to_detect``
    (an empty list when the user chose "all").
    """
    global classes_to_detect
    chosen_names = get_user_class_selection(model.names)

    index_by_name = {}
    for class_index, class_label in model.names.items():
        index_by_name[class_label] = class_index

    classes_to_detect = [
        index_by_name[label] for label in chosen_names if label in index_by_name
    ]

def gstreamer_pipeline(sensor_id=0, capture_width=1920, capture_height=1080, display_width=960, display_height=540, framerate=30, flip_method=0):
    """Build the GStreamer launch string for a Jetson CSI camera.

    ``nvarguscamerasrc`` captures ``capture_width`` x ``capture_height`` at
    ``framerate`` fps from sensor ``sensor_id``. ``nvvidconv`` (with
    ``flip_method``) then outputs ``display_width`` x ``display_height`` BGRx
    frames, which ``videoconvert`` turns into BGR for an ``appsink``.

    Returns:
        str: the pipeline description, stages separated by ``" ! "``.
    """
    stages = [
        f"nvarguscamerasrc sensor-id={sensor_id}",
        f"video/x-raw(memory:NVMM), width={capture_width}, height={capture_height}, framerate={framerate}/1",
        f"nvvidconv flip-method={flip_method}",
        f"video/x-raw, width={display_width}, height={display_height}, format=BGRx",
        "videoconvert",
        "video/x-raw, format=BGR",
        "appsink",
    ]
    return " ! ".join(stages)

def prepare_camera(flip_method=0):
    """Open the global ``camera`` unless it is already open.

    Tries the v4l2loopback device ``/dev/video10`` first (used together with
    run_webcam.sh). If that fails, it falls back to a direct GStreamer
    pipeline for the CSI camera. When neither source opens, ``camera`` is
    reset to ``None``.

    Args:
        flip_method: forwarded to ``gstreamer_pipeline`` for the CSI fallback.
    """
    global camera
    if camera and camera.isOpened():
        return

    # First choice: the virtual loopback device.
    virtual_device = "/dev/video10"
    print("Attempting to open virtual camera /dev/video10...")
    camera = cv2.VideoCapture(virtual_device, cv2.CAP_V4L2)
    if camera.isOpened():
        print("✅ Successfully opened virtual camera /dev/video10.")
        return

    # Second choice: talk to the CSI sensor directly through GStreamer
    # (used when this script is run on its own).
    print("⚠️ Could not open /dev/video10. Falling back to direct CSI camera access.")
    csi_pipeline = gstreamer_pipeline(flip_method=flip_method)
    print(f"Using GStreamer Pipeline: {csi_pipeline}")
    camera = cv2.VideoCapture(csi_pipeline, cv2.CAP_GSTREAMER)

    if camera.isOpened():
        print("✅ CSI Camera setup completed via direct pipeline.")
        return

    print("❌ Error: Unable to open camera with any method.")
    camera = None

def release_camera():
    """Release the global ``camera`` if one is set, then reset it to ``None``."""
    global camera
    if not camera:
        return
    print("Releasing camera resources...")
    camera.release()
    camera = None

def get_frame():
    """Read one frame from the global ``camera``.

    Returns:
        The frame returned by ``camera.read()``, or ``None`` when there is no
        open camera or the read reports failure (an error is printed then).
    """
    cam = camera
    if not cam or not cam.isOpened():
        return None
    ok, image = cam.read()
    if ok:
        return image
    print("❌ Error reading frame from camera.")
    return None

def get_class_color(class_name):
    """Return a stable pseudo-random color for ``class_name``.

    The seed is a CRC-32 of the UTF-8 encoded name. Python's built-in
    ``hash()`` of a ``str`` is salted per process, so a CRC-32 is used instead
    to give each class the same color on every run. A private ``RandomState``
    is used so the global NumPy RNG state is not modified.

    Returns:
        tuple[int, int, int]: three channel values in the range [0, 255).
    """
    import zlib

    seed = zlib.crc32(str(class_name).encode("utf-8"))  # unsigned 32-bit, valid seed
    rng = np.random.RandomState(seed)
    return tuple(rng.randint(0, 255, 3).tolist())

def predict_and_detect(img, conf=0.5, rectangle_thickness=2, text_thickness=1):
    """Run YOLO on ``img`` and draw a labelled box for every detection.

    When the global ``classes_to_detect`` is non-empty, only those class
    indices are requested from the model; otherwise no class filter is passed.

    Args:
        img: image passed to ``model.predict``; boxes and labels are drawn on
            it in place with OpenCV.
        conf: confidence threshold forwarded to ``model.predict``.
        rectangle_thickness: line thickness of the bounding boxes.
        text_thickness: stroke thickness of the label text.

    Returns:
        tuple: ``(img, results)`` - the same, now annotated, image object and
        the results returned by ``model.predict``.
    """
    predict_kwargs = {"conf": conf, "verbose": False}
    if classes_to_detect:
        predict_kwargs["classes"] = classes_to_detect
    results = model.predict(img, **predict_kwargs)

    # Label styling is identical for every box, so set it up once.
    font = cv2.FONT_HERSHEY_PLAIN
    font_scale = 1

    for result in results:
        for box in result.boxes:
            class_name = result.names[int(box.cls[0])]
            color = get_class_color(class_name)

            # Read the box corners once rather than re-indexing for every use.
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            cv2.rectangle(img, (x1, y1), (x2, y2), color, rectangle_thickness)

            # Prepare text and background for the label
            text = f"{class_name}"
            (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, text_thickness)

            # Put the label 10 px above the box, but never above the top of the
            # image: for boxes touching the top edge it would otherwise be
            # drawn at negative y and not be visible at all.
            label_x = x1
            label_y = max(y1 - 10, text_height + baseline)

            # Draw a filled rectangle behind the text for better readability
            cv2.rectangle(img, (label_x, label_y - text_height - baseline),
                          (label_x + text_width, label_y + baseline), color, -1)
            # Draw the text label
            cv2.putText(img, text, (label_x, label_y), font, font_scale,
                        (255, 255, 255), text_thickness, cv2.LINE_AA)
    return img, results

def _show_frame(bgr_image):
    """Convert an OpenCV BGR frame to a Pygame surface and draw it fullscreen."""
    # Convert from OpenCV BGR to Pygame RGB
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # OpenCV arrays are indexed [y][x] while pygame.surfarray expects [x][y],
    # so swapping the first two axes displays the frame upright and
    # unmirrored (labels stay readable).
    surface = pygame.surfarray.make_surface(rgb_image.swapaxes(0, 1))
    screen.blit(pygame.transform.scale(surface, (screen_W, screen_H)), (0, 0))
    pygame.display.update()


def _run_detection_loop():
    """Show annotated camera frames until the user quits or a frame cannot be read."""
    while True:
        try:
            # Drain the event queue; quit on 'q' key or window close.
            quit_requested = False
            for event in pygame.event.get():
                if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_q):
                    quit_requested = True
            if quit_requested:
                break

            frame = get_frame()
            if frame is None:
                print("Could not retrieve frame, ending loop.")
                break

            processed_image, _ = predict_and_detect(frame, conf=0.5)
            _show_frame(processed_image)

        except Exception as ex:
            # Report and stop; main() releases camera and display in its finally.
            print(f"An unexpected error occurred in the main loop: {ex}")
            break


def main():
    """Main application entry point.

    Loads the model, asks for the classes to detect, then shows live
    detections fullscreen until 'q' is pressed, the window is closed, or a
    frame cannot be read. Camera and display are released on every exit path.
    """
    # 1. Initialize the model first to get class names
    init_model("yolo11n_object365.pt")

    # 2. Get user's class selection via the terminal
    set_selected_class_indices()

    # 3. Initialize Pygame display and camera only after selection is done.
    # try/finally guarantees cleanup even on Ctrl+C: KeyboardInterrupt is not
    # an Exception subclass, so the loop's own handler does not catch it.
    try:
        init_fullscreen_display()
        prepare_camera(flip_method=0)

        if not camera or not camera.isOpened():
            print("Critical Error: No camera available. Exiting.")
            return

        _run_detection_loop()
    finally:
        # Cleanup resources (release_camera is a no-op when no camera is set)
        release_camera()
        release_fullscreen_display()


if __name__ == "__main__":
    main()
