import cv2
import numpy as np
import easyocr
import torch
import time
from PIL import Image
import io

# Lazily-populated cache for the single EasyOCR reader shared by all callers
# (filled on the first call to get_ocr_reader()).
_reader = None

def get_ocr_reader():
    """
    Return the shared EasyOCR reader, building it on first use.

    The reader is created for Vietnamese and English, on the GPU when
    ``torch.cuda.is_available()`` reports CUDA support. Subsequent calls
    return the cached instance without reloading the model.
    """
    global _reader
    if _reader is not None:
        return _reader

    t0 = time.time()
    print("Warmup: Loading EasyOCR reader into RAM...")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ok}")
    _reader = easyocr.Reader(['vi', 'en'], gpu=cuda_ok)
    elapsed = time.time() - t0
    print(f"Warmup complete. Loaded in {elapsed:.2f} seconds.")
    return _reader

def _resize_to_width(img, target_width):
    """Resize ``img`` to ``target_width`` pixels wide, keeping its aspect ratio."""
    h, w = img.shape[:2]
    scale = target_width / w
    # max(1, ...) stops cv2.resize from rejecting a zero-height target size
    # for extremely wide, thin images.
    new_h = max(1, int(h * scale))
    return cv2.resize(img, (target_width, new_h), interpolation=cv2.INTER_AREA)


def _order_corners(pts):
    """
    Order 4 corner points as top-left, top-right, bottom-right, bottom-left.

    Points are sorted clockwise (image coordinates, y pointing down) by their
    angle around the centroid, then rotated so the point with the smallest
    x + y comes first. Unlike picking each corner independently with
    argmin/argmax of x+y and y-x, this always yields four distinct points,
    even for a quadrilateral rotated close to 45 degrees.
    """
    pts = np.asarray(pts, dtype="float32").reshape(4, 2)
    center = pts.mean(axis=0)
    angles = np.arctan2(pts[:, 1] - center[1], pts[:, 0] - center[0])
    clockwise = pts[np.argsort(angles)]
    start = int(np.argmin(clockwise.sum(axis=1)))
    return np.roll(clockwise, -start, axis=0)


def _find_card_quad(img, min_area_ratio=0.15):
    """
    Return the 4-vertex contour of a card-like region in ``img``, or None.

    Only the 5 largest external contours of the Canny edge map are examined.
    A candidate must simplify to 4 vertices, be convex, and cover more than
    ``min_area_ratio`` of the image area.
    """
    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edged = cv2.Canny(blurred, 50, 150)

    # [-2] is the contour list under both OpenCV 3.x (image, contours,
    # hierarchy) and OpenCV 4.x (contours, hierarchy).
    contours = cv2.findContours(edged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    contours = sorted(contours, key=cv2.contourArea, reverse=True)[:5]

    img_area = w * h
    for c in contours:
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.02 * peri, True)
        # CCCD is rectangular: 4 vertices, convex, and reasonably large.
        if (
            len(approx) == 4
            and cv2.isContourConvex(approx)
            and cv2.contourArea(c) > min_area_ratio * img_area
        ):
            return approx
    return None


def _warp_card(img, quad, max_output_width=1000):
    """
    Perspective-warp the quadrilateral ``quad`` of ``img`` into a flat image.

    The output size comes from the longest pair of opposite edges of the quad;
    the nominal CCCD aspect ratio (85.60mm x 53.98mm, ~1.586) is not enforced.
    Portrait results are rotated 90 degrees clockwise to landscape, and
    results wider than ``max_output_width`` are downscaled.

    Raises:
        ValueError: If the quadrilateral collapses to a zero-sized output.
    """
    rect = _order_corners(quad)
    tl, tr, br, bl = rect

    out_w = int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl)))
    out_h = int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl)))
    if out_w < 1 or out_h < 1:
        raise ValueError(f"Degenerate card quadrilateral ({out_w}x{out_h}).")

    dst = np.array([
        [0, 0],
        [out_w - 1, 0],
        [out_w - 1, out_h - 1],
        [0, out_h - 1],
    ], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(img, M, (out_w, out_h))

    # Ensure it is landscape format
    if out_h > out_w:
        print("Rotating warped image to landscape format...")
        warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)

    # Resize warped card to a standard size for OCR consistency
    if warped.shape[1] > max_output_width:
        warped = _resize_to_width(warped, max_output_width)
    return warped


def preprocess_image(image_bytes: bytes, *, max_input_width: int = 1200,
                     max_output_width: int = 1000):
    """
    Decode image bytes, downscale wide images, and try to align the document.

    Steps:
      1. Decode ``image_bytes`` with ``cv2.imdecode`` as a 3-channel BGR image.
      2. If wider than ``max_input_width``, downscale it (aspect ratio kept).
      3. Best effort: find a large convex quadrilateral (the card) and
         perspective-warp it to a landscape image no wider than
         ``max_output_width``. Any failure in this step falls back to the
         image from step 2.

    Args:
        image_bytes: Encoded image data in any format OpenCV can decode.
        max_input_width: Width limit applied before card detection.
        max_output_width: Width limit applied to the warped card.

    Returns:
        tuple ``(image, warped)``: the processed BGR image and True when the
        perspective warp was applied, otherwise the (possibly resized)
        original image and False.

    Raises:
        ValueError: If ``image_bytes`` is empty or cannot be decoded.
    """
    # cv2.imdecode raises cv2.error (not None) on an empty buffer.
    if not image_bytes:
        raise ValueError("Cannot decode image. Invalid image format.")
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Cannot decode image. Invalid image format.")

    h, w = img.shape[:2]
    print(f"Original Image Size: {w}x{h}")

    # 1. Resize if image is too large (Speed Optimization)
    if w > max_input_width:
        img = _resize_to_width(img, max_input_width)
        h, w = img.shape[:2]
        print(f"Resized Image to: {w}x{h}")

    # 2. Document alignment (warp perspective). None of these steps modify
    # `img` in place, so it remains a valid fallback without copying.
    try:
        quad = _find_card_quad(img)
        if quad is not None:
            print("Detected card contours! Performing Warp Perspective...")
            return _warp_card(img, quad, max_output_width), True
    except Exception as e:  # best effort: alignment must never fail the request
        print(f"Perspective Warp failed: {e}. Falling back to original resized image.")

    # Fallback to the original resized image
    return img, False

def run_ocr_on_image(image_bytes: bytes):
    """
    Preprocess an image, run EasyOCR on it, and return the results.

    Args:
        image_bytes: Encoded image data in any format OpenCV can decode.

    Returns:
        tuple ``(results, processed_img_bytes, is_warped)``:
          - ``results``: the value returned by ``reader.readtext``. OCR runs on
            the processed (resized/warped) image, so any box coordinates refer
            to that image, not the original upload.
          - ``processed_img_bytes``: the processed image, JPEG-encoded, so a
            caller can display the exact image the results refer to.
          - ``is_warped``: True when the perspective warp was applied.

    Raises:
        ValueError: If the image bytes cannot be decoded (from preprocess_image).
        RuntimeError: If the processed image cannot be JPEG-encoded.
    """
    # 1. Preprocess & Align Image
    processed_img, is_warped = preprocess_image(image_bytes)

    # 2. Get the loaded EasyOCR reader (Singleton)
    reader = get_ocr_reader()

    # 3. Perform OCR
    # Convert OpenCV image (BGR) to RGB for EasyOCR/Pillow
    rgb_img = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)

    start_time = time.time()
    print("Running EasyOCR prediction...")
    results = reader.readtext(rgb_img)
    print(f"EasyOCR completed in {time.time() - start_time:.2f} seconds.")

    # 4. Encode the processed (BGR) image as JPEG so callers can overlay the
    # OCR boxes on the same image they were computed from.
    ok, encoded_img = cv2.imencode('.jpg', processed_img)
    if not ok:
        raise RuntimeError("Failed to JPEG-encode the processed image.")
    processed_img_bytes = encoded_img.tobytes()

    return results, processed_img_bytes, is_warped
