import os
import uuid
import numpy as np
import multiprocessing
from functools import lru_cache
import time

# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()

# Remove hardcoded credentials
# os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/karan/Downloads/gen-lang-client-0906226163-6940370bcba5.json"

from fastapi import FastAPI, UploadFile, Form, HTTPException, File, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import Optional, List
import pytesseract
from pdf2image import convert_from_bytes
from fuzzywuzzy import fuzz
import io
from PIL import Image, ImageDraw
import logging
import urllib.parse
import base64
import traceback
from pydantic import BaseModel

# EasyOCR is imported lazily by get_easyocr_reader() so that importing this
# module stays fast; this placeholder is rebound to the real module on first use.
easyocr = None

# Runtime configuration. Each value is read once at import time from the
# environment variable named in the os.getenv call, falling back to the
# default shown.
DPI = int(os.getenv("PDF_DPI", "140"))  # DPI passed to pdf2image when rasterising pages (original note: lowered from 300 to match react-pdf-highlighter-extended scaling)
PADDING = int(os.getenv("PADDING", "15"))  # Pixels added around a matched bounding box in upload_pdf
FUZZY_THRESHOLD = int(os.getenv("FUZZY_THRESHOLD", "75"))  # A combined fuzzy score must exceed this to count as a match
DEBUG_DIR = os.getenv("DEBUG_DIR", "debug_images")  # Directory that receives debug images when SAVE_DEBUG_IMAGES is on
# Page size that upload_pdf rescales OCR pixel coordinates to and reports as
# the boundingRect width/height. Units are not stated in this file —
# presumably the viewer's page dimensions; TODO confirm.
PDF_WIDTH = float(os.getenv("PDF_WIDTH", "991.6666666666666"))
PDF_HEIGHT = float(os.getenv("PDF_HEIGHT", "1403.3333333333333"))
SAVE_DEBUG_IMAGES = os.getenv("SAVE_DEBUG_IMAGES", "false").lower() == "true"  # Only the exact string "true" (any case) enables this
MAX_WORKERS = int(os.getenv("MAX_WORKERS", str(min(multiprocessing.cpu_count(), 2))))  # Defaults to min(cpu_count, 2); used as pdf2image thread_count and uvicorn workers
USE_EASYOCR = os.getenv("USE_EASYOCR", "true").lower() == "true"  # Set to anything other than "true" to disable the EasyOCR fallback
EASYOCR_TIMEOUT = int(os.getenv("EASYOCR_TIMEOUT", "60"))  # Seconds. NOTE(review): not referenced anywhere in the visible code — confirm whether a timeout is enforced elsewhere
TESSERACT_CONFIG = os.getenv("TESSERACT_CONFIG", "--psm 6 --oem 3 -c preserve_interword_spaces=1")  # Tesseract command-line flags used by extract_text_from_pdf
CACHE_SIZE = int(os.getenv("CACHE_SIZE", "10"))  # maxsize for the lru_cache-decorated functions in this module

# Configure the root logger: INFO with a compact format in production,
# DEBUG with source file/line information in every other environment.
if os.getenv("ENV", "development") == "production":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
else:
    # Set to DEBUG level for more detailed logs during development
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')

# Create logger
logger = logging.getLogger(__name__)

# Also write this module's records to app.log (relative to the working
# directory). The encoding is pinned to UTF-8 because the records contain OCR
# output, which can hold arbitrary Unicode; the platform default encoding
# (e.g. cp1252) would fail to encode some of those characters.
file_handler = logging.FileHandler('app.log', encoding='utf-8')
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)

def normalize_text(text):
    """
    Return a canonical form of ``text`` for fuzzy comparison.

    Every character that is neither a word character nor whitespace is
    dropped, the remainder is lower-cased, and runs of whitespace are
    collapsed to single spaces (leading/trailing whitespace disappears).
    """
    import re

    without_punctuation = re.sub(r'[^\w\s]', '', text)
    words = without_punctuation.lower().split()
    return ' '.join(words)

def find_paragraph_coordinates(ocr_data, paragraph):
    """
    Locate ``paragraph`` in Tesseract word-level output.

    Args:
        ocr_data: Mapping with parallel 'text', 'left', 'top', 'width' and
            'height' sequences (one entry per OCR token).
        paragraph: The text to search for.

    Returns:
        ``(left, top, right, bottom)`` enclosing the best matching group of
        tokens, or None when no group scores above FUZZY_THRESHOLD.
    """
    logger.info(f"Searching for paragraph: {paragraph[:50]}...")
    wanted = normalize_text(paragraph)

    # Dump every non-blank OCR token to the debug log.
    logger.info("OCR Text blocks:")
    for position, token in enumerate(ocr_data['text']):
        if token.strip():
            logger.debug(f"Block {position}: {token}")

    # Walk the tokens in order; a vertical jump of more than twice the
    # current token's height closes the group being built.
    groups = []
    open_group = []
    previous_top = -1
    rows = zip(ocr_data['text'], ocr_data['left'], ocr_data['top'])
    for position, (token, token_left, token_top) in enumerate(rows):
        if not token.strip():
            continue

        if previous_top >= 0 and abs(token_top - previous_top) > ocr_data['height'][position] * 2.0:
            if open_group:
                joined = ' '.join(member[1] for member in open_group)
                logger.debug(f"Found block: {joined}")
                groups.append(open_group)
                open_group = []

        open_group.append((position, token, token_left, token_top))
        previous_top = token_top

    if open_group:
        groups.append(open_group)

    # Score each group with a weighted blend of three fuzzy ratios; the
    # partial ratio dominates.
    top_score = 0
    chosen = None
    for group in groups:
        candidate = normalize_text(' '.join(member[1] for member in group))
        full_ratio = fuzz.ratio(wanted, candidate)
        partial_ratio = fuzz.partial_ratio(wanted, candidate)
        sorted_ratio = fuzz.token_sort_ratio(wanted, candidate)
        combined = (full_ratio * 0.2 + partial_ratio * 0.6 + sorted_ratio * 0.2)

        logger.debug(f"Block: {candidate[:50]}... Score: {combined}")

        # Skip anything that neither beats the current best nor clears the threshold.
        if combined <= top_score or combined <= FUZZY_THRESHOLD:
            continue

        top_score = combined
        chosen = group
        logger.info(f"Found better match with score {combined}: {candidate[:100]}...")

    if not chosen:
        logger.warning(f"No match found above threshold {FUZZY_THRESHOLD}")
        return None

    # Union of the boxes of every token in the winning group.
    boxes = [
        (
            ocr_data['left'][member[0]],
            ocr_data['top'][member[0]],
            ocr_data['width'][member[0]],
            ocr_data['height'][member[0]],
        )
        for member in chosen
    ]
    return (
        min(x for x, _, _, _ in boxes),
        min(y for _, y, _, _ in boxes),
        max(x + w for x, _, w, _ in boxes),
        max(y + h for _, y, _, h in boxes),
    )

# Module-wide EasyOCR reader, created on first request and then reused
_easyocr_reader = None

def get_easyocr_reader():
    """
    Return the shared EasyOCR reader, building it on the first call.

    The easyocr package itself is imported here, on demand, and stored in the
    module-level ``easyocr`` name; subsequent calls reuse both the module and
    the reader.
    """
    global _easyocr_reader, easyocr

    # Fast path: the reader already exists.
    if _easyocr_reader is not None:
        return _easyocr_reader

    if easyocr is None:
        # Deferred import: only pay for EasyOCR when it is actually needed.
        logger.info("Importing EasyOCR module")
        import easyocr as easyocr_module
        easyocr = easyocr_module

    logger.info("Initializing EasyOCR reader (singleton)")
    _easyocr_reader = easyocr.Reader(['en'], recog_network='latin_g2', gpu=False)
    logger.info("EasyOCR reader initialized successfully")
    return _easyocr_reader

@lru_cache(maxsize=CACHE_SIZE)
def process_image_with_easyocr(image_hash):
    """
    Run EasyOCR on the currently staged image and cache the result.

    The image itself is not an argument because lru_cache needs hashable
    arguments: callers store the image in the module-level
    ``_current_image_for_easyocr`` before calling and pass only its hash,
    which serves as the cache key.

    Args:
        image_hash: A hash of the image data to use as cache key

    Returns:
        The EasyOCR results, or None if no image is staged or processing failed
    """
    staged_image = _current_image_for_easyocr

    # Compare against None explicitly: the staged image is a NumPy array, and
    # truth-testing a multi-element array (`not array`) raises ValueError.
    if staged_image is None:
        return None

    try:
        reader = get_easyocr_reader()
        return reader.readtext(staged_image)
    except Exception as e:
        # NOTE: the None returned here is cached by lru_cache like any other
        # result, so a failed hash keeps returning None until it is evicted.
        logger.error(f"Error in EasyOCR processing: {str(e)}")
        return None

# Global variable to hold the current image being processed
_current_image_for_easyocr = None

def find_paragraph_coordinates_easyocr(image, paragraph):
    """
    Find paragraph coordinates using EasyOCR.

    Each EasyOCR detection is scored against ``paragraph`` with a weighted
    blend of fuzzy ratios; the best detection above FUZZY_THRESHOLD wins.

    Args:
        image: The page image (converted with ``np.array``).
        paragraph: The text to search for.

    Returns:
        ``(left, top, right, bottom)`` of the best matching detection, or None
        when EasyOCR is disabled, nothing matches, or processing fails.
    """
    global _current_image_for_easyocr

    if not USE_EASYOCR:
        logger.info("EasyOCR is disabled by configuration")
        return None

    logger.info("Starting EasyOCR processing")
    try:
        # Convert image to numpy array
        logger.info("Converting image to numpy array")
        image_np = np.array(image)  # Convert PIL image to numpy array
        logger.info(f"Image converted to numpy array with shape: {image_np.shape}")

        # Create a hash of the image data for caching
        image_hash = hash(image_np.tobytes())

        # NOTE: no timeout is actually enforced here; the call below blocks
        # until EasyOCR returns.
        start_time = time.time()
        logger.info("Starting EasyOCR text detection and recognition with timeout")

        # Stage the image for the cached function, and always un-stage it —
        # even when OCR raises — so the array is not kept alive in the global.
        _current_image_for_easyocr = image_np
        try:
            result = process_image_with_easyocr(image_hash)
        finally:
            _current_image_for_easyocr = None

        processing_time = time.time() - start_time
        logger.info(f"EasyOCR processing completed in {processing_time:.2f} seconds")

        if result is None or len(result) == 0:
            logger.warning("EasyOCR did not detect any text in the image")
            return None

        # Log only the first 3 detections to avoid excessive logging
        for number, detection in enumerate(result[:3], start=1):
            logger.info(f"Sample detection {number}: {detection[1][:50]}...")

        paragraph_lower = normalize_text(paragraph)
        best_score = 0
        best_block = None

        logger.info("Starting text matching with fuzzy search")
        for detection in result:
            block_text = normalize_text(detection[1])
            ratio = fuzz.ratio(paragraph_lower, block_text)
            partial = fuzz.partial_ratio(paragraph_lower, block_text)
            token_sort = fuzz.token_sort_ratio(paragraph_lower, block_text)

            score = (ratio * 0.2 + partial * 0.6 + token_sort * 0.2)  # Adjusted weights

            # Log detailed matching scores for debugging
            logger.debug(f"Text: '{block_text[:30]}...' - Ratio: {ratio}, Partial: {partial}, Token Sort: {token_sort}, Combined: {score}")

            if score > best_score and score > FUZZY_THRESHOLD:
                best_score = score
                best_block = detection[0]
                logger.info(f"Found better match with score {score}: {block_text[:100]}...")

        # Identity check rather than truthiness: best_block is a sequence of
        # corner points, whose truth value depends on its container type.
        if best_block is None:
            logger.warning(f"No match found above threshold {FUZZY_THRESHOLD} using EasyOCR")
            return None

        # Axis-aligned box around the detection's corner points.
        logger.info("Match found, calculating coordinates")
        xs = [point[0] for point in best_block]
        ys = [point[1] for point in best_block]
        left, right = min(xs), max(xs)
        top, bottom = min(ys), max(ys)

        logger.info(f"Calculated coordinates: left={left}, top={top}, right={right}, bottom={bottom}")
        return (left, top, right, bottom)
    except Exception as e:
        logger.error(f"Error in EasyOCR processing: {str(e)}\n{traceback.format_exc()}")
        return None

# Application instance; the route decorators further down register on it.
app = FastAPI()

# CORS policy: any origin, any method, any header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended, or list the trusted origins
# explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

# Error payload carrying a single `detail` string.
# Not referenced elsewhere in the visible part of this file.
class UploadError(BaseModel):
    detail: str

# Model with a single `content` string.
# Not referenced elsewhere in the visible part of this file.
class PDFData(BaseModel):
    content: str

# Response envelope used by the /readpdf endpoint (declared as its
# response_model): `status` is an HTTP-style code carried in the body,
# `data` holds the payload on success and `message` the explanation on failure.
class StandardResponse(BaseModel):
    status: int
    data: Optional[dict] = None
    message: Optional[str] = None

# Model with a single `content` string.
# Not referenced elsewhere in the visible part of this file.
class PDFResponse(BaseModel):
    content: str

@app.post("/upload-pdf")
async def upload_pdf(
    request: Request,
    file: Optional[UploadFile] = File(default=None),
    paragraphs: Optional[str] = Form(default=None)
):
    """
    Locate a paragraph inside an uploaded PDF and return its bounding box.

    Pages are rasterised at DPI and searched with Tesseract first; if no page
    matches, every page is searched again with EasyOCR. The matched box is
    padded by PADDING pixels, clamped to the page image and rescaled to
    PDF_WIDTH x PDF_HEIGHT.

    Args:
        request: The raw request, used to inspect the content type and to
            re-read the form when the declared parameters came back empty.
        file: The PDF upload.
        paragraphs: The text to search for.

    Returns:
        A dict envelope ``{"status", "message", "data"}``. ``status`` is 200
        with a single "area" comment on success, the validation error's own
        status code (422) for rejected input, and 500 when the paragraph is
        not found or an unexpected error occurs.
    """
    try:
        content_type = request.headers.get('content-type', '')
        logger.info(f"Request content type: {content_type}")

        # Handle different content types
        if 'multipart/form-data' in content_type:
            # Form data is already parsed by FastAPI; fall back to the raw
            # form only when one of the declared fields is missing.
            if not file or not paragraphs:
                form = await request.form()
                file = file or form.get('file')
                paragraphs = paragraphs or form.get('paragraphs')
        elif 'application/x-www-form-urlencoded' in content_type:
            form = await request.form()
            file = file or form.get('file')
            paragraphs = paragraphs or form.get('paragraphs')

            # Handle file data if it's provided as a string
            if isinstance(file, str):
                try:
                    file_data = file
                    file = UploadFile(
                        filename="uploaded.pdf",
                        file=io.BytesIO(file_data.encode()),
                        content_type="application/pdf"
                    )
                except Exception as e:
                    logger.error(f"Error converting file data: {str(e)}")
                    raise HTTPException(status_code=422, detail="Invalid file data")
        else:
            raise HTTPException(status_code=422,
                              detail="Unsupported content type. Use multipart/form-data")

        # Validate inputs
        if not file:
            raise HTTPException(status_code=422, detail="No file provided")
        if not paragraphs:
            raise HTTPException(status_code=422, detail="No search text provided")

        cleaned_paragraph = paragraphs.strip()
        if not cleaned_paragraph:
            raise HTTPException(status_code=422, detail="Search text cannot be empty")

        logger.info(f"Processing file: {file.filename}")
        logger.info(f"Search text: {cleaned_paragraph[:100]}...")

        pdf_bytes = await file.read()
        images = convert_from_bytes(pdf_bytes, dpi=DPI)

        # First pass: Tesseract. The flags come from TESSERACT_CONFIG so this
        # endpoint and extract_text_from_pdf share one configuration (the
        # default equals the flags previously hard-coded here).
        coords = None
        for page_number, image in enumerate(images):
            ocr_data = pytesseract.image_to_data(
                image=image,
                output_type=pytesseract.Output.DICT,
                config=TESSERACT_CONFIG
            )

            coords = find_paragraph_coordinates(ocr_data, cleaned_paragraph)
            if coords:
                logger.info(f"Text found on page {page_number + 1} using Tesseract")
                break

        # Second pass: EasyOCR, only when Tesseract matched nothing.
        if not coords:
            logger.info("Text not found using Tesseract, trying EasyOCR as fallback")
            for page_number, image in enumerate(images):
                coords = find_paragraph_coordinates_easyocr(image, cleaned_paragraph)
                if coords:
                    logger.info(f"Text found on page {page_number + 1} using EasyOCR")
                    break

        if not coords:
            logger.warning("Paragraph not found in any page using both Tesseract and EasyOCR")
            return {
                "status": 500,
                "message": "Paragraph not found",
                "data": {}
            }

        # From here on `image` and `page_number` refer to the matching page.
        # Add padding to the coordinates, clamped to the page image.
        left = max(coords[0] - PADDING, 0)
        top = max(coords[1] - PADDING, 0)
        right = min(coords[2] + PADDING, image.width)
        bottom = min(coords[3] + PADDING, image.height)

        # Scale pixel coordinates to the configured PDF dimensions
        scale_x = PDF_WIDTH / image.width
        scale_y = PDF_HEIGHT / image.height

        x1 = round(left * scale_x, 2)
        y1 = round(top * scale_y, 2)
        x2 = round(right * scale_x, 2)
        y2 = round(bottom * scale_y, 2)

        # Debug artefacts stay None unless SAVE_DEBUG_IMAGES is enabled.
        debug_image_path = None
        content_path = None
        img_base64 = None

        if SAVE_DEBUG_IMAGES:
            os.makedirs(DEBUG_DIR, exist_ok=True)

            debug_image_path = os.path.join(DEBUG_DIR, f"page_{page_number + 1}.png")
            content_path = os.path.join(DEBUG_DIR, f"content_{page_number + 1}.png")

            # Full page with the match outlined in red
            marked_image = image.copy()
            draw = ImageDraw.Draw(marked_image)
            draw.rectangle([left, top, right, bottom], outline="red", width=2)
            marked_image.save(debug_image_path)

            # Cropped match, shrunk in place to fit within 800x800
            content_image = image.crop((left, top, right, bottom))
            content_image.thumbnail((800, 800), Image.LANCZOS)

            # PNG is lossless, so there is no `quality` setting to apply;
            # `optimize` is what reduces the file size.
            content_image.save(content_path, "PNG", optimize=True)

            buffered = io.BytesIO()
            content_image.save(buffered, format="PNG", optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode()

            logger.info(f"Debug images saved: {debug_image_path}, {content_path}")

        # Generate a unique 15-character ID
        unique_id = uuid.uuid4().hex[:15]

        return {
            "status": 200,
            "message": "PDF file processed successfully",
            "data": {
                "comments": [{
                    "content": {
                        "image": content_path,
                        "base64": f"data:image/png;base64,{img_base64}" if img_base64 else None,
                        "text": cleaned_paragraph
                    },
                    "type": "area",
                    "position": {
                        "boundingRect": {
                            "x1": x1,
                            "y1": y1,
                            "x2": x2,
                            "y2": y2,
                            "width": PDF_WIDTH,  # Use full PDF width
                            "height": PDF_HEIGHT,  # Use full PDF height
                            "pageNumber": page_number + 1
                        },
                        "rects": [],
                    },
                    "comment": "AI Title",
                    "_id": unique_id,
                    "debug_image": debug_image_path
                }]
            }
        }
    except HTTPException as e:
        # Validation failures raised above must not fall into the generic
        # handler, which would report them as a 500. Keep the envelope shape
        # clients already parse, but carry the real status code and detail.
        logger.warning(f"Request rejected ({e.status_code}): {e.detail}")
        return {
            "status": e.status_code,
            "message": e.detail,
            "data": {}
        }
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}\n{traceback.format_exc()}")
        return {
            "status": 500,
            "message": f"Error processing request: {str(e)}",
            "data": {}
        }

@lru_cache(maxsize=CACHE_SIZE)
def extract_text_from_pdf(pdf_bytes_hash, pdf_bytes):
    """
    Extract text from a PDF file using OCR.

    Each page is rasterised at DPI, downscaled if either side exceeds 2000
    pixels, and read with Tesseract. When Tesseract returns only whitespace
    and USE_EASYOCR is enabled, the page is retried with EasyOCR. Results are
    memoised by lru_cache, keyed on both arguments.

    Args:
        pdf_bytes_hash: Hash of the PDF bytes for caching
        pdf_bytes: The PDF file as bytes

    Returns:
        A string containing the extracted text from all pages, each page
        introduced by a "Page N:" line; pages without text are recorded as
        "[No text detected]".

    Raises:
        ValueError: If any step fails; the original exception is chained.
    """
    global _current_image_for_easyocr

    try:
        logger.info("Starting PDF text extraction process")
        logger.info(f"PDF size: {len(pdf_bytes)} bytes")

        # Convert PDF to images
        logger.info("Converting PDF to images using pdf2image")
        start_time = time.time()
        images = convert_from_bytes(pdf_bytes, dpi=DPI, thread_count=MAX_WORKERS)
        conversion_time = time.time() - start_time
        logger.info(f"PDF conversion complete. Found {len(images)} pages. Took {conversion_time:.2f} seconds")

        all_text = []
        for page_number, image in enumerate(images):
            logger.info(f"Processing page {page_number + 1} of {len(images)}")

            # Resize large images to reduce memory usage and processing time
            if image.width > 2000 or image.height > 2000:
                scale_factor = min(2000 / image.width, 2000 / image.height)
                new_width = int(image.width * scale_factor)
                new_height = int(image.height * scale_factor)
                logger.info(f"Resizing large image from {image.width}x{image.height} to {new_width}x{new_height}")
                image = image.resize((new_width, new_height), Image.LANCZOS)

            logger.info(f"Image dimensions: {image.width}x{image.height}")

            # Use Tesseract OCR to extract text
            logger.info("Starting Tesseract OCR processing")
            tesseract_start = time.time()
            page_text = pytesseract.image_to_string(
                image=image,
                config=TESSERACT_CONFIG
            )
            tesseract_time = time.time() - tesseract_start
            logger.info(f"Tesseract OCR completed in {tesseract_time:.2f} seconds")

            if page_text.strip():
                logger.info(f"Tesseract found text on page {page_number + 1}. Text length: {len(page_text)} characters")
                # Log a sample of the text for debugging
                text_sample = page_text[:100] + "..." if len(page_text) > 100 else page_text
                logger.info(f"Text sample: {text_sample}")
                all_text.append(f"Page {page_number + 1}:\n{page_text}")
            elif USE_EASYOCR:
                # If Tesseract fails and EasyOCR is enabled, try EasyOCR as fallback
                logger.info(f"Tesseract returned empty text for page {page_number + 1}, trying EasyOCR as fallback")

                # Convert image to numpy array
                image_np = np.array(image)

                # Create a hash of the image data for caching
                image_hash = hash(image_np.tobytes())

                # Stage the image for the cached function and always un-stage
                # it, even if OCR raises, so the array is not kept alive in
                # the module-level global.
                easyocr_start = time.time()
                _current_image_for_easyocr = image_np
                try:
                    result = process_image_with_easyocr(image_hash)
                finally:
                    _current_image_for_easyocr = None

                easyocr_time = time.time() - easyocr_start

                if result:
                    logger.info(f"EasyOCR processing completed in {easyocr_time:.2f} seconds")
                    logger.info(f"EasyOCR found {len(result)} text regions")

                    # Log only the first 3 detections
                    for number, detection in enumerate(result[:3], start=1):
                        logger.info(f"Sample detection {number}: {detection[1][:50]}...")

                    page_text = "\n".join([text[1] for text in result])
                    logger.info(f"Combined EasyOCR text length: {len(page_text)} characters")
                    all_text.append(f"Page {page_number + 1}:\n{page_text}")
                else:
                    logger.warning(f"No text detected on page {page_number + 1} using either Tesseract or EasyOCR")
                    all_text.append(f"Page {page_number + 1}:\n[No text detected]")
            else:
                logger.warning(f"No text detected on page {page_number + 1} using Tesseract and EasyOCR is disabled")
                all_text.append(f"Page {page_number + 1}:\n[No text detected]")

        total_text = "\n\n".join(all_text)
        logger.info(f"Text extraction complete. Total text length: {len(total_text)} characters")
        return total_text
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {str(e)}\n{traceback.format_exc()}")
        # Chain the cause so the original traceback survives the re-raise.
        raise ValueError(f"Failed to extract text from PDF: {str(e)}") from e

@app.post("/readpdf", response_model=StandardResponse, status_code=200)
async def read_pdf(file: UploadFile = File(...)):
    """
    Endpoint to upload a PDF file and extract its text content.

    Args:
        file: The PDF file to upload

    Returns:
        StandardResponse whose ``status`` field is 200 with the extracted text
        under ``data["content"]``, 400 when the upload is not a ``.pdf`` file,
        is empty, or cannot be processed, and 500 for unexpected errors.
        Failures are reported through the envelope, not raised.
    """
    try:
        # An upload may arrive without a filename; treat that as "not a PDF"
        # instead of letting None.lower() surface as a 500.
        filename = file.filename or ""
        if not filename.lower().endswith('.pdf'):
            return StandardResponse(
                status=400,
                message="File must be a PDF"
            )

        logger.info(f"Processing PDF file: {filename}")

        # Read the file
        pdf_bytes = await file.read()

        # Reject empty uploads up front with a clear message.
        if not pdf_bytes:
            return StandardResponse(
                status=400,
                message="Uploaded PDF is empty"
            )

        # Create a hash of the PDF bytes for caching
        pdf_bytes_hash = hash(pdf_bytes)

        # Extract text from the PDF using the cached function
        content = extract_text_from_pdf(pdf_bytes_hash, pdf_bytes)

        return StandardResponse(
            status=200,
            data={"content": content}
        )
    except ValueError as e:
        # extract_text_from_pdf reports every processing failure as ValueError.
        logger.error(f"Error processing PDF: {str(e)}")
        return StandardResponse(
            status=400,
            message=str(e)
        )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
        return StandardResponse(
            status=500,
            message=f"An unexpected error occurred: {str(e)}"
        )

if __name__ == "__main__":
    import uvicorn

    # Get port from environment variable or use default
    port = int(os.getenv("PORT", "8000"))

    # uvicorn can only start several workers when the application is given as
    # an import string ("module:app"); handed an app object together with
    # workers > 1 it logs a warning and exits. Derive the import string from
    # this script's file name, which is importable because the script's
    # directory is on sys.path when it is run directly.
    if MAX_WORKERS > 1:
        module_name = os.path.splitext(os.path.basename(__file__))[0]
        app_target = f"{module_name}:app"
    else:
        app_target = app

    uvicorn.run(
        app_target,
        host="0.0.0.0",
        port=port,
        workers=MAX_WORKERS,  # Limit to available CPU cores
        log_level=os.getenv("LOG_LEVEL", "warning"),
        limit_concurrency=int(os.getenv("LIMIT_CONCURRENCY", "20")),  # Limit concurrent requests
        timeout_keep_alive=int(os.getenv("TIMEOUT_KEEP_ALIVE", "5")),  # Reduce idle connections
    )