from fastapi import FastAPI, UploadFile, Form, HTTPException, File, Request
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
import pytesseract
from pdf2image import convert_from_bytes
from fuzzywuzzy import fuzz
import io
from PIL import Image, ImageDraw
import logging
import os
import base64
import traceback
from pydantic import BaseModel

# Constants
DPI = 140  # Render resolution; matches react-pdf-highlighter-extended scaling
PADDING = 15  # Extra pixels added around a matched region
FUZZY_THRESHOLD = 75  # Minimum combined fuzzy score required to accept a match
DEBUG_DIR = "debug_images"  # Directory where debug images are written
PDF_WIDTH = 991.6666666666666  # PDF page width that highlight coordinates are scaled to
PDF_HEIGHT = 1403.3333333333333  # PDF page height that highlight coordinates are scaled to

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def normalize_text(text):
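    """Lowercase the text and collapse runs of whitespace so OCR output and the query compare cleanly."""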
    return ' '.join(text.lower().split())

def find_paragraph_coordinates(ocr_data, paragraph):
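    """Locate `paragraph` in Tesseract image_to_data output.

    Groups OCR words into blocks by vertical position, fuzzy-matches each block
    against the query, and returns (left, top, right, bottom) pixel coordinates
    of the best-scoring block, or None if nothing clears FUZZY_THRESHOLD.
    """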
    logger.info(f"Searching for paragraph: {paragraph[:50]}...")
    paragraph_lower = normalize_text(paragraph)
    
    # Improve OCR data processing
    blocks = []
    current_block = []
    last_top = -1
    
    # Debug OCR output
    logger.info("OCR Text blocks:")
    for i, text in enumerate(ocr_data['text']):
        if text.strip():
            logger.debug(f"Block {i}: {text}")
    
    # Group text into logical blocks based on positioning with looser tolerance
    for i, (text, left, top) in enumerate(zip(ocr_data['text'], ocr_data['left'], ocr_data['top'])):
        if not text.strip():
            continue
            
        if last_top >= 0 and abs(top - last_top) > ocr_data['height'][i] * 2.0:  # Increased tolerance
            if current_block:
                block_text = ' '.join(item[1] for item in current_block)
                logger.debug(f"Found block: {block_text}")
                blocks.append(current_block)
                current_block = []
        
        current_block.append((i, text, left, top))
        last_top = top
    
    if current_block:
        blocks.append(current_block)

    # Find best matching block using improved scoring
    best_score = 0
    best_block = None
    
    for block in blocks:
        block_text = normalize_text(' '.join(item[1] for item in block))
        # Use multiple matching techniques with adjusted weights
        ratio = fuzz.ratio(paragraph_lower, block_text)
        partial = fuzz.partial_ratio(paragraph_lower, block_text)
        token_sort = fuzz.token_sort_ratio(paragraph_lower, block_text)
        
        # Combined score with adjusted weightage
        score = (ratio * 0.3 + partial * 0.5 + token_sort * 0.2)  # Increased partial ratio weight
        
        logger.debug(f"Block: {block_text[:50]}... Score: {score}")
        
        if score > best_score and score > FUZZY_THRESHOLD:
            best_score = score
            best_block = block
            logger.info(f"Found better match with score {score}: {block_text[:100]}...")

    if not best_block:
        logger.warning(f"No match found above threshold {FUZZY_THRESHOLD}")
        return None

    # Calculate precise coordinates
    indices = [item[0] for item in best_block]
    left = min(ocr_data['left'][i] for i in indices)
    top = min(ocr_data['top'][i] for i in indices)
    right = max(ocr_data['left'][i] + ocr_data['width'][i] for i in indices)
    bottom = max(ocr_data['top'][i] + ocr_data['height'][i] for i in indices)

    return (left, top, right, bottom)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (note: browsers ignore credentials when the origin is a wildcard)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

class UploadError(BaseModel):
    detail: str

@app.post("/upload-pdf")
async def upload_pdf(
    request: Request,
    file: Optional[UploadFile] = File(default=None),
    paragraphs: Optional[str] = Form(default=None)
):
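    """Accept a PDF and a search paragraph, OCR each page, and return highlight
    coordinates plus a cropped preview image for the best fuzzy match."""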
    try:
        content_type = request.headers.get('content-type', '')
        logger.info(f"Request content type: {content_type}")
        
        # Handle different content types
        if 'multipart/form-data' in content_type:
            # Form data is already parsed by FastAPI
            if not file or not paragraphs:
                form = await request.form()
                file = file or form.get('file')
                paragraphs = paragraphs or form.get('paragraphs')
        elif 'application/x-www-form-urlencoded' in content_type:
            # Get form data that FastAPI already parsed
            form = await request.form()
            file = file or form.get('file')
            paragraphs = paragraphs or form.get('paragraphs')
            
            # Handle file data if it's provided as a string
            if isinstance(file, str):
                try:
                    file_data = file
                    file = UploadFile(
                        filename="uploaded.pdf",
                        file=io.BytesIO(file_data.encode()),
                        content_type="application/pdf"
                    )
                except Exception as e:
                    logger.error(f"Error converting file data: {str(e)}")
                    raise HTTPException(status_code=422, detail="Invalid file data")
        else:
            raise HTTPException(status_code=422,
                              detail="Unsupported content type. Use multipart/form-data or application/x-www-form-urlencoded")

        # Validate inputs
        if not file:
            raise HTTPException(status_code=422, detail="No file provided")
        if not paragraphs:
            raise HTTPException(status_code=422, detail="No search text provided")

        # Rest of the validation and processing
        cleaned_paragraph = paragraphs.strip()
        if not cleaned_paragraph:
            raise HTTPException(status_code=422, detail="Search text cannot be empty")

        logger.info(f"Processing file: {file.filename}")
        logger.info(f"Search text: {cleaned_paragraph[:100]}...")

        pdf_bytes = await file.read()
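        # Rasterize each PDF page at the configured DPI so Tesseract can OCR it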
        images = convert_from_bytes(pdf_bytes, dpi=DPI)

        for page_number, image in enumerate(images):
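            # Run Tesseract on the page image; --psm 6 assumes a single uniform
            # block of text and --oem 3 uses the default engine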
            ocr_data = pytesseract.image_to_data(
                image=image,
                output_type=pytesseract.Output.DICT,
                config='--psm 6 --oem 3 -c preserve_interword_spaces=1'
            )
            
            coords = find_paragraph_coordinates(ocr_data, cleaned_paragraph)
            if coords:
                logger.info(f"Text found on page {page_number + 1}")
                # Add padding to the coordinates
                left = max(coords[0] - PADDING, 0)
                top = max(coords[1] - PADDING, 0)
                right = min(coords[2] + PADDING, image.width)
                bottom = min(coords[3] + PADDING, image.height)

                # Scale coordinates to match exact PDF dimensions
                scale_x = PDF_WIDTH / image.width
                scale_y = PDF_HEIGHT / image.height

                # Apply scaling precisely
                x1 = round(left * scale_x, 2)
                y1 = round(top * scale_y, 2)
                x2 = round(right * scale_x, 2)
                y2 = round(bottom * scale_y, 2)

                # Create the debug directory if it does not exist
                os.makedirs(DEBUG_DIR, exist_ok=True)

                # Create a copy of the image for highlighting
                marked_image = image.copy()
                draw = ImageDraw.Draw(marked_image)
                draw.rectangle([left, top, right, bottom], outline="red", width=2)
                
                # Generate file paths
                debug_image_path = os.path.join(DEBUG_DIR, f"page_{page_number + 1}.png")
                content_path = os.path.join(DEBUG_DIR, f"content_{page_number + 1}.png")

                # Save debug images
                marked_image.save(debug_image_path)
                content_image = image.crop((left, top, right, bottom))
                
                # Downscale the crop before saving to keep the payload small
                max_size = (800, 800)  # Maximum dimensions
                content_image.thumbnail(max_size, Image.LANCZOS)

                # Save with optimized compression (PNG is lossless; a JPEG-style quality setting has no effect here)
                content_image.save(content_path, "PNG", optimize=True)

                # Convert the compressed image to base64 for embedding in the response
                buffered = io.BytesIO()
                content_image.save(buffered, format="PNG", optimize=True)
                img_base64 = base64.b64encode(buffered.getvalue()).decode()

                logger.info(f"Debug images saved: {debug_image_path}, {content_path}")
                
                # Now return the response with base64 image included
                return [{
                    "content": {
                        "image": content_path,
                        "base64": f"data:image/png;base64,{img_base64}"
                    },
                    "type": "area",
                    "position": {
                        "boundingRect": {
                            "x1": x1,
                            "y1": y1,
                            "x2": x2,
                            "width": PDF_WIDTH,  # Use full PDF width
                            "height": PDF_HEIGHT,  # Use full PDF height
                            "pageNumber": page_number + 1
                        },
                        "rects": [],
                    },
                    "comment": cleaned_paragraph,
                    "id": file.filename.split('.')[0],
                    "debug_image": debug_image_path
                }]

        logger.warning("Paragraph not found in any page")
        return {"error": "Paragraph not found"}
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. 422s) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
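
# Example client call (a sketch, assuming the server is running locally on port 8000
# and a file named "sample.pdf" exists; adjust the path and search text as needed):
#
#   import requests
#   with open("sample.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/upload-pdf",
#           files={"file": ("sample.pdf", f, "application/pdf")},
#           data={"paragraphs": "Text of the paragraph to locate"},
#       )
#   print(resp.json())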