diff --git a/src/main.py b/src/main.py index addf962..658e035 100644 --- a/src/main.py +++ b/src/main.py @@ -3,13 +3,13 @@ import asyncio import requests from pathlib import Path -from fastapi import FastAPI, Form, HTTPException, Query, Request -from fastapi.responses import HTMLResponse, JSONResponse, FileResponse -from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -import logging -from dotenv import load_dotenv - +from fastapi import FastAPI, Form, HTTPException, Query, Request +from fastapi.responses import HTMLResponse, JSONResponse, FileResponse +from fastapi.templating import Jinja2Templates +from fastapi.staticfiles import StaticFiles +import logging +from dotenv import load_dotenv + from src.auto_annotate_images import auto_annotate_images from src.download_images import download_images from src.search_images import search_images @@ -20,72 +20,76 @@ from src.create_data_yaml import create_data_yaml import shutil from src.utils.annotation_converter import convert_to_yolo_format, ensure_directory - -load_dotenv() - -app = FastAPI() - -templates = Jinja2Templates(directory=Path(__file__).parent / "templates") - -log_file_path = os.path.join(os.getcwd(), 'app_logs.txt') -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_file_path), - logging.StreamHandler() - ] -) -logger = logging.getLogger(__name__) - -images_path = "dataset/train/images" -labels_path = "dataset/train/labels" -download_path = "dataset/train/images" - -os.makedirs(download_path, exist_ok=True) -app.mount("/images", StaticFiles(directory=download_path), name="images") - -training_status = { - "step": 0, - "status": "Idle", - "detail": "", - "completed": False, - "success": False, - "model_path": "", - "error": "", - "query": "" -} - - -def clear_directory(path): - if os.path.exists(path): - shutil.rmtree(path) - os.makedirs(path) - - -def reset_training_status(query): - global training_status - training_status = { - "step": 0, - "status": "Starting", - "detail": "Initializing...", - "completed": False, - "success": False, - "model_path": "", - "error": "", - "query": query - } - - -@app.get("/", response_class=HTMLResponse) -async def index(request: Request): - clear_directory(images_path) - clear_directory(labels_path) - return templates.TemplateResponse("search.html", {"request": request}) - - + +load_dotenv() + +app = FastAPI() + +templates = Jinja2Templates(directory=Path(__file__).parent / "templates") + +log_file_path = os.path.join(os.getcwd(), 'app_logs.txt') +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_file_path), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +images_path = "dataset/train/images" +labels_path = "dataset/train/labels" +download_path = "dataset/train/images" + +os.makedirs(download_path, exist_ok=True) +app.mount("/images", StaticFiles(directory=download_path), name="images") + +training_status = { + "step": 0, + "status": "Idle", + "detail": "", + "completed": False, + "success": False, + "model_path": "", + "error": "", + "query": "" +} + + +def clear_directory(path): + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path) + + +def reset_training_status(query): + global training_status + training_status = { + "step": 0, + "status": "Starting", + "detail": "Initializing...", + "completed": False, + "success": False, + "model_path": "", + "error": "", + "query": query + } + + +@app.get("/", response_class=HTMLResponse) +async def index(request: Request): + clear_directory(images_path) + clear_directory(labels_path) + return templates.TemplateResponse("search.html", {"request": request}) + + @app.post("/search", response_class=HTMLResponse) -async def search(request: Request, query: str = Form(...), page: int = Form(default=0)): +async def search( + request: Request, + query: str = Form(...), + page: int = Form( + default=0)): try: api_key = os.getenv("GOOGLE_API_KEY") search_engine_id = os.getenv("SEARCH_ENGINE_ID") @@ -98,7 +102,11 @@ async def search(request: Request, query: str = Form(...), page: int = Form(defa # Fetch more images to allow pagination/different selections # Fetch 30 images so user can get different results on "Search Again" - images = search_images(query, api_key, search_engine_id, num_results=30) + images = search_images( + query, + api_key, + search_engine_id, + num_results=30) if not images: return templates.TemplateResponse("search.html", { @@ -109,37 +117,42 @@ async def search(request: Request, query: str = Form(...), page: int = Form(defa # Calculate which images to show based on page/retry count # This allows fetching different subsets on each "Search Again" start_idx = page * 9 - end_idx = start_idx + 15 # Get 15 images to select from (more than 9 needed) - + # Get 15 images to select from (more than 9 needed) + end_idx = start_idx + 15 + images_subset = images[start_idx:end_idx] - + if len(images_subset) < 9: - logger.warning(f"Not enough images for page {page}, got {len(images_subset)}") + logger.warning( + f"Not enough images for page {page}, got {len(images_subset)}") images_subset = images[start_idx:] - + if not images_subset: return templates.TemplateResponse("search.html", { "request": request, "error": "No more images available. Try a different search term." }) - # Download images temporarily to extract features for balanced selection + # Download images temporarily to extract features for balanced + # selection temp_download_path = "dataset/temp_selection" os.makedirs(temp_download_path, exist_ok=True) - + try: image_paths = download_images(images_subset, temp_download_path) - + # Select balanced images (70% relevance, 30% dissimilarity) selected_images = select_balanced_images( - images_subset, - image_paths, - num_images=min(9, len(images_subset)), + images_subset, + image_paths, + num_images=min(9, len(images_subset)), relevance_weight=0.7 ) - logger.info(f"Selected {len(selected_images)} balanced images for query: {query} (page {page})") + logger.info( + f"Selected {len(selected_images)} balanced images for query: {query} (page {page})") except Exception as e: - logger.warning(f"Balanced selection failed, falling back to first 9 images: {e}") + logger.warning( + f"Balanced selection failed, falling back to first 9 images: {e}") selected_images = images_subset[:9] finally: # Clean up temporary downloads @@ -158,92 +171,92 @@ async def search(request: Request, query: str = Form(...), page: int = Form(defa "request": request, "error": f"Search failed: {str(e)}" }) - - -@app.post("/select", response_class=HTMLResponse) -async def select( - request: Request, - selected_images: list[str] = Form(...), - original_query: str = Form(...)): - try: - if not selected_images or len(selected_images) < 3: - raise HTTPException(status_code=400, - detail="Please select at least 3 images.") - - clear_directory(images_path) - clear_directory(labels_path) - + + +@app.post("/select", response_class=HTMLResponse) +async def select( + request: Request, + selected_images: list[str] = Form(...), + original_query: str = Form(...)): + try: + if not selected_images or len(selected_images) < 3: + raise HTTPException(status_code=400, + detail="Please select at least 3 images.") + + clear_directory(images_path) + clear_directory(labels_path) + local_image_paths = download_images(selected_images, download_path) images_data = [ (f"/images/{os.path.basename(path)}", os.path.basename(path)) for path in local_image_paths ] - - return templates.TemplateResponse("annotate.html", { - "request": request, - "query": original_query, - "images": images_data - }) - except Exception as e: - logger.error(f"Error during image selection: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/start-training", response_class=HTMLResponse) -async def start_training( - request: Request, - image_urls: list[str] = Form(...), - annotations: list[str] = Form(...), - original_query: str = Form(...)): - try: - reset_training_status(original_query) - - asyncio.create_task( - run_training( - image_urls, - annotations, - original_query)) - - return templates.TemplateResponse("training.html", { - "request": request, - "query": original_query - }) - except Exception as e: - logger.error(f"Error starting training: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -async def run_training(image_urls, annotations, original_query): - global training_status - try: - training_status["status"] = "Downloading" - training_status["detail"] = "Saving your annotated images" - training_status["step"] = 0 - - ensure_directory( - labels_dir := os.path.join( - "dataset", "train", "labels")) - - for image_url, annotation_json in zip(image_urls, annotations): - try: - image_name = os.path.basename(image_url) - image_path = os.path.join(images_path, image_name) - - with Image.open(image_path) as img: - img_width, img_height = img.size - - yolo_annotation = convert_to_yolo_format( - annotation_json, img_width, img_height) - - label_filename = os.path.splitext(image_name)[0] + ".txt" - label_path = os.path.join(labels_dir, label_filename) - - with open(label_path, 'w') as f: - f.write(yolo_annotation) - except Exception as e: - logger.error(f"Error processing annotation: {e}") - + + return templates.TemplateResponse("annotate.html", { + "request": request, + "query": original_query, + "images": images_data + }) + except Exception as e: + logger.error(f"Error during image selection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/start-training", response_class=HTMLResponse) +async def start_training( + request: Request, + image_urls: list[str] = Form(...), + annotations: list[str] = Form(...), + original_query: str = Form(...)): + try: + reset_training_status(original_query) + + asyncio.create_task( + run_training( + image_urls, + annotations, + original_query)) + + return templates.TemplateResponse("training.html", { + "request": request, + "query": original_query + }) + except Exception as e: + logger.error(f"Error starting training: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +async def run_training(image_urls, annotations, original_query): + global training_status + try: + training_status["status"] = "Downloading" + training_status["detail"] = "Saving your annotated images" + training_status["step"] = 0 + + ensure_directory( + labels_dir := os.path.join( + "dataset", "train", "labels")) + + for image_url, annotation_json in zip(image_urls, annotations): + try: + image_name = os.path.basename(image_url) + image_path = os.path.join(images_path, image_name) + + with Image.open(image_path) as img: + img_width, img_height = img.size + + yolo_annotation = convert_to_yolo_format( + annotation_json, img_width, img_height) + + label_filename = os.path.splitext(image_name)[0] + ".txt" + label_path = os.path.join(labels_dir, label_filename) + + with open(label_path, 'w') as f: + f.write(yolo_annotation) + except Exception as e: + logger.error(f"Error processing annotation: {e}") + training_status["step"] = 1 training_status["status"] = "Scraping" training_status["detail"] = "Finding similar images..." @@ -268,7 +281,7 @@ async def run_training(image_urls, annotations, original_query): training_status["step"] = 2 training_status["status"] = "Downloading" - + if similar_images: training_status["detail"] = f"Downloading {len(similar_images)} similar images" try: @@ -278,44 +291,44 @@ async def run_training(image_urls, annotations, original_query): logger.info("Continuing training with only annotated images") else: training_status["detail"] = "No additional images to download, using annotated images only" - - training_status["step"] = 3 - training_status["status"] = "Annotating" - training_status["detail"] = "Auto-annotating scraped images" - - auto_annotate_images(images_path, labels_dir) - - training_status["step"] = 4 - training_status["status"] = "Training" - training_status["detail"] = "Training YOLOv8 model (this may take a few minutes)" - - data_yaml_path = create_data_yaml(labels_dir, original_query) - - model_path = train_model(data_yaml_path, 'yolov8') - - if model_path and os.path.exists(model_path): - training_status["completed"] = True - training_status["success"] = True - training_status["model_path"] = model_path - training_status["status"] = "Complete" - training_status["detail"] = "Model trained successfully!" - else: - raise Exception("Model training failed - no output model found") - - except Exception as e: - logger.error(f"Training error: {e}") - training_status["completed"] = True - training_status["success"] = False - training_status["error"] = str(e) - training_status["status"] = "Error" - training_status["detail"] = str(e) - - -@app.get("/training-status") -async def get_training_status(): - return JSONResponse(training_status) - - + + training_status["step"] = 3 + training_status["status"] = "Annotating" + training_status["detail"] = "Auto-annotating scraped images" + + auto_annotate_images(images_path, labels_dir) + + training_status["step"] = 4 + training_status["status"] = "Training" + training_status["detail"] = "Training YOLOv8 model (this may take a few minutes)" + + data_yaml_path = create_data_yaml(labels_dir, original_query) + + model_path = train_model(data_yaml_path, 'yolov8') + + if model_path and os.path.exists(model_path): + training_status["completed"] = True + training_status["success"] = True + training_status["model_path"] = model_path + training_status["status"] = "Complete" + training_status["detail"] = "Model trained successfully!" + else: + raise Exception("Model training failed - no output model found") + + except Exception as e: + logger.error(f"Training error: {e}") + training_status["completed"] = True + training_status["success"] = False + training_status["error"] = str(e) + training_status["status"] = "Error" + training_status["detail"] = str(e) + + +@app.get("/training-status") +async def get_training_status(): + return JSONResponse(training_status) + + @app.get("/results", response_class=HTMLResponse) async def results(request: Request, model: str = Query(...)): query = training_status.get("query", "Unknown") @@ -345,24 +358,25 @@ async def download_model(model: str = Query(...)): try: # Normalize the path and validate it exists model_path = os.path.normpath(model) - + # Security check: ensure path is within project directory project_root = os.path.abspath(".") abs_model_path = os.path.abspath(model_path) - + if not abs_model_path.startswith(project_root): - logger.error(f"Attempted to download file outside project directory: {abs_model_path}") + logger.error( + f"Attempted to download file outside project directory: {abs_model_path}") raise HTTPException(status_code=403, detail="Access denied") - + if not os.path.exists(abs_model_path): logger.error(f"Model file not found: {abs_model_path}") raise HTTPException(status_code=404, detail="Model file not found") - + # Get the filename for the download filename = os.path.basename(abs_model_path) - + logger.info(f"Downloading model: {filename}") - + return FileResponse( path=abs_model_path, media_type="application/octet-stream", @@ -373,11 +387,11 @@ async def download_model(model: str = Query(...)): except Exception as e: logger.error(f"Error downloading model: {e}") raise HTTPException(status_code=500, detail="Failed to download model") - - -@app.get("/error", response_class=HTMLResponse) -async def error_page(request: Request, message: str = Query(...)): - return templates.TemplateResponse("error.html", { + + +@app.get("/error", response_class=HTMLResponse) +async def error_page(request: Request, message: str = Query(...)): + return templates.TemplateResponse("error.html", { "request": request, "error": message }) @@ -388,22 +402,21 @@ async def debug_api(request: Request): """Debug endpoint to test API configuration""" api_key = os.getenv("GOOGLE_API_KEY") search_engine_id = os.getenv("SEARCH_ENGINE_ID") - + debug_info = { "api_key_set": bool(api_key), "api_key_preview": f"{api_key[:10]}..." if api_key else "NOT SET", "search_engine_id_set": bool(search_engine_id), "search_engine_id": search_engine_id, } - + # Try to make a simple API call test_result = None if api_key and search_engine_id: try: test_url = ( f"https://www.googleapis.com/customsearch/v1?" - f"q=test&searchType=image&key={api_key}&cx={search_engine_id}&num=1" - ) + f"q=test&searchType=image&key={api_key}&cx={search_engine_id}&num=1") response = requests.get(test_url, timeout=5) test_result = { "status_code": response.status_code, @@ -416,7 +429,7 @@ async def debug_api(request: Request): "success": False, "error": str(e) } - + return templates.TemplateResponse("debug.html", { "request": request, "debug_info": debug_info, diff --git a/src/scrape_similar.py b/src/scrape_similar.py index 1596579..7110599 100644 --- a/src/scrape_similar.py +++ b/src/scrape_similar.py @@ -21,7 +21,8 @@ def scrape_similar_images( # Query variations to try - gradually simpler if advanced searches fail query_variations = [ f"{original_query} filetype:jpg OR filetype:png", # Specific file types - f"{original_query} clear photo", # Descriptive quality + f"{original_query} clear photo", + # Descriptive quality f"{original_query} high resolution", f"{original_query} isolated", f"{original_query} product photo", @@ -39,20 +40,20 @@ def scrape_similar_images( try: logger.debug(f"Attempting search with query: {query}") - + images = search_images( query, api_key, search_engine_id, num_results=num_results_per_image ) - + if images: logger.info(f"Got {len(images)} images from query: {query}") similar_images.extend(images) else: logger.debug(f"No images from query: {query}") - + except Exception as e: logger.warning(f"Failed to search for '{query}': {str(e)[:100]}") # Continue to next query variation @@ -60,8 +61,9 @@ def scrape_similar_images( # Remove duplicates while preserving order similar_images = list(dict.fromkeys(similar_images)) - + final_count = min(len(similar_images), total_images_to_download) - logger.info(f"Scrape similar images: collected {final_count}/{total_images_to_download} images after removing duplicates") + logger.info( + f"Scrape similar images: collected {final_count}/{total_images_to_download} images after removing duplicates") return similar_images[:total_images_to_download] diff --git a/src/search_images.py b/src/search_images.py index 6a9a19e..f992140 100644 --- a/src/search_images.py +++ b/src/search_images.py @@ -14,39 +14,45 @@ def search_images(query, api_key, search_engine_id, num_results=10): """ images = [] google_error = None - + # Try Google Custom Search first try: - images = _search_google_custom_search(query, api_key, search_engine_id, num_results) + images = _search_google_custom_search( + query, api_key, search_engine_id, num_results) if images: - logger.info(f"Successfully retrieved {len(images)} images from Google Custom Search") + logger.info( + f"Successfully retrieved {len(images)} images from Google Custom Search") return images except Exception as e: google_error = str(e) logger.warning(f"Google Custom Search failed: {google_error}") - + # Fallback to Bing Images (free, no API key needed) try: logger.info("Falling back to Bing Images for search") images = _search_bing_images(query, num_results) if images: - logger.info(f"Successfully retrieved {len(images)} images from Bing Images") + logger.info( + f"Successfully retrieved {len(images)} images from Bing Images") return images except Exception as e: logger.error(f"Bing Images fallback also failed: {str(e)}") - + # If both fail, raise an error with helpful message if google_error: raise Exception( f"Unable to search for images. Google API error: {google_error}\n\n" f"The app attempted to use a fallback image source (Bing Images) but it also failed. " - f"Please check your internet connection and try again." - ) + f"Please check your internet connection and try again.") else: raise Exception("No image search service is available") -def _search_google_custom_search(query, api_key, search_engine_id, num_results=10): +def _search_google_custom_search( + query, + api_key, + search_engine_id, + num_results=10): """Search using Google Custom Search API""" images = [] results_per_page = 10 @@ -60,9 +66,9 @@ def _search_google_custom_search(query, api_key, search_engine_id, num_results=1 try: response = requests.get(search_url, timeout=10) - + logger.debug(f"Google API Response Status: {response.status_code}") - + if response.status_code != 200: error_message = _parse_google_api_error(response) raise Exception(error_message) @@ -93,49 +99,59 @@ def _search_bing_images(query, num_results=10): images = [] max_retries = 3 retry_count = 0 - + # Clean up query: remove Google-style filters that Bing doesn't understand clean_query = query - clean_query = clean_query.replace(" filetype:jpg", "").replace(" filetype:png", "") + clean_query = clean_query.replace( + " filetype:jpg", "").replace( + " filetype:png", "") clean_query = clean_query.replace(" OR ", " ") # Replace OR with space clean_query = clean_query.strip() - - logger.debug(f"Bing Images search query cleaned: '{query}' -> '{clean_query}'") - + + logger.debug( + f"Bing Images search query cleaned: '{query}' -> '{clean_query}'") + while retry_count < max_retries: try: # Bing Images search URL search_url = "https://www.bing.com/images/search" - + params = { "q": clean_query, "count": min(num_results, 35), } - + headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - - response = requests.get(search_url, params=params, headers=headers, timeout=10) - logger.debug(f"Bing Images Response Status: {response.status_code}") - + + response = requests.get( + search_url, + params=params, + headers=headers, + timeout=10) + logger.debug( + f"Bing Images Response Status: {response.status_code}") + if response.status_code != 200: - raise Exception(f"Bing Images returned status {response.status_code}") - + raise Exception( + f"Bing Images returned status {response.status_code}") + # Extract image URLs from the HTML response using regex # Bing stores lazy-loaded images in data-src attributes # These are Bing image proxy URLs (tse1.mm.bing.net, etc.) image_pattern = r']+data-src="([^"]+)"' matches = re.findall(image_pattern, response.text) - + if not matches: - logger.debug(f"No images found on Bing Images (attempt {retry_count + 1}/{max_retries})") + logger.debug( + f"No images found on Bing Images (attempt {retry_count + 1}/{max_retries})") retry_count += 1 if retry_count < max_retries: time.sleep(1) # Wait before retrying continue raise Exception("No images found on Bing Images after retries") - + # Process URLs and decode HTML entities for url in matches: if url.startswith('http') and len(images) < num_results: @@ -143,25 +159,29 @@ def _search_bing_images(query, num_results=10): url = url.replace('&', '&') url = url.replace('\\/', '/') images.append(url) - + if not images: - logger.debug(f"No valid image URLs found (attempt {retry_count + 1}/{max_retries})") + logger.debug( + f"No valid image URLs found (attempt {retry_count + 1}/{max_retries})") retry_count += 1 if retry_count < max_retries: time.sleep(1) continue raise Exception("No valid image URLs found after retries") - - logger.info(f"Bing Images search returned {len(images)} images for query: {clean_query}") + + logger.info( + f"Bing Images search returned {len(images)} images for query: {clean_query}") return images[:num_results] - + except Exception as e: if retry_count < max_retries - 1: - logger.debug(f"Bing Images error (attempt {retry_count + 1}/{max_retries}): {str(e)}") + logger.debug( + f"Bing Images error (attempt {retry_count + 1}/{max_retries}): {str(e)}") retry_count += 1 time.sleep(1) else: - logger.error(f"Bing Images error after {max_retries} attempts: {str(e)}") + logger.error( + f"Bing Images error after {max_retries} attempts: {str(e)}") raise @@ -171,12 +191,12 @@ def _parse_google_api_error(response): data = response.json() if 'error' in data: error_obj = data['error'] - + if isinstance(error_obj, dict): message = error_obj.get('message', 'Unknown error') code = error_obj.get('code', response.status_code) status = error_obj.get('status', 'UNKNOWN') - + if status == 'PERMISSION_DENIED' or code == 403: return ( f"Google Custom Search API Access Denied (403): {message}\n\n" @@ -184,8 +204,7 @@ def _parse_google_api_error(response): f"• The Custom Search JSON API is not enabled in your Google Cloud project\n" f"• Your API key doesn't have the right permissions\n" f"• The search engine ID (CX) is incorrect or disabled\n\n" - f"The app will use Bing Images as a fallback image source." - ) + f"The app will use Bing Images as a fallback image source.") elif status == 'INVALID_ARGUMENT' or code == 400: return f"Invalid Request: {message}" elif status == 'UNAUTHENTICATED' or code == 401: @@ -196,7 +215,7 @@ def _parse_google_api_error(response): return f"API Error ({code}): {message}" else: return f"API Error: {str(error_obj)}" - except: + except BaseException: pass - + return f"Google API failed with status {response.status_code}. Using Bing Images as fallback." diff --git a/src/search_most_dissimilar_images.py b/src/search_most_dissimilar_images.py index fa1e4d3..5473e53 100644 --- a/src/search_most_dissimilar_images.py +++ b/src/search_most_dissimilar_images.py @@ -1,93 +1,94 @@ -import numpy as np -from sklearn.metrics.pairwise import cosine_distances -from PIL import Image -from torchvision import models, transforms -import torch -from src.download_images import download_images - - -# Load a pre-trained ResNet50 model for feature extraction -model = models.resnet50(weights='IMAGENET1K_V1') -model = model.eval() # Set the model to evaluation mode - -# Remove the final classification layer to extract 2048-dim features from avgpool -feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]) -feature_extractor = feature_extractor.eval() - -# Transformation for input images (resize, normalize, etc.) -transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) - - -def extract_features(image_path): - """ - Extracts features from an image using a pre-trained model. - - :param image_path: Path to the image. - :return: Feature vector of the image (2048-dim from avgpool layer). - """ - try: - image = Image.open(image_path).convert( - 'RGB') # Ensure the image is in RGB format - image_tensor = transform(image).unsqueeze( - 0) # Transform and add batch dimension - with torch.no_grad(): - # Use the feature extractor to get 2048-dim features from avgpool - features = feature_extractor(image_tensor) - features = features.flatten().numpy() - return features - except Exception as e: - print(f"Error extracting features from {image_path}: {e}") - return None - - -def select_most_dissimilar_images(image_urls, num_images): - """ - Selects the most dissimilar images from a list. - - :param image_urls: List of image URLs. - :param num_images: Number of most dissimilar images to select. - :return: List of the most dissimilar image URLs. - """ - # Download images to a local directory - # Temporary directory to save downloaded images - download_path = "dataset/train/images" - image_paths = download_images(image_urls, download_path=download_path) - - # Validate that images were downloaded - if not image_paths: - print("No images were downloaded.") - return [] - - # Extract features from each downloaded image - features = [] - valid_image_paths = [] - for path in image_paths: - feature = extract_features(path) - if feature is not None: - features.append(feature) - valid_image_paths.append(path) - - # Ensure there are enough features to proceed - if len(features) < num_images: - print("Not enough images to select the most dissimilar ones.") - return image_urls[:len(features)] # Return as many images as available - - # Convert features list to a numpy array - features = np.array(features) - - # Compute the cosine distance matrix between image features - distance_matrix = cosine_distances(features) - - # Sum the distances for each image and sort them in descending order - dissimilarity_scores = np.sum(distance_matrix, axis=1) - most_dissimilar_indices = np.argsort(dissimilarity_scores)[-num_images:] - - # Select the most dissimilar image URLs - most_dissimilar_images = [image_urls[idx] - for idx in most_dissimilar_indices] - - return most_dissimilar_images +import numpy as np +from sklearn.metrics.pairwise import cosine_distances +from PIL import Image +from torchvision import models, transforms +import torch +from src.download_images import download_images + + +# Load a pre-trained ResNet50 model for feature extraction +model = models.resnet50(weights='IMAGENET1K_V1') +model = model.eval() # Set the model to evaluation mode + +# Remove the final classification layer to extract 2048-dim features from +# avgpool +feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]) +feature_extractor = feature_extractor.eval() + +# Transformation for input images (resize, normalize, etc.) +transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + + +def extract_features(image_path): + """ + Extracts features from an image using a pre-trained model. + + :param image_path: Path to the image. + :return: Feature vector of the image (2048-dim from avgpool layer). + """ + try: + image = Image.open(image_path).convert( + 'RGB') # Ensure the image is in RGB format + image_tensor = transform(image).unsqueeze( + 0) # Transform and add batch dimension + with torch.no_grad(): + # Use the feature extractor to get 2048-dim features from avgpool + features = feature_extractor(image_tensor) + features = features.flatten().numpy() + return features + except Exception as e: + print(f"Error extracting features from {image_path}: {e}") + return None + + +def select_most_dissimilar_images(image_urls, num_images): + """ + Selects the most dissimilar images from a list. + + :param image_urls: List of image URLs. + :param num_images: Number of most dissimilar images to select. + :return: List of the most dissimilar image URLs. + """ + # Download images to a local directory + # Temporary directory to save downloaded images + download_path = "dataset/train/images" + image_paths = download_images(image_urls, download_path=download_path) + + # Validate that images were downloaded + if not image_paths: + print("No images were downloaded.") + return [] + + # Extract features from each downloaded image + features = [] + valid_image_paths = [] + for path in image_paths: + feature = extract_features(path) + if feature is not None: + features.append(feature) + valid_image_paths.append(path) + + # Ensure there are enough features to proceed + if len(features) < num_images: + print("Not enough images to select the most dissimilar ones.") + return image_urls[:len(features)] # Return as many images as available + + # Convert features list to a numpy array + features = np.array(features) + + # Compute the cosine distance matrix between image features + distance_matrix = cosine_distances(features) + + # Sum the distances for each image and sort them in descending order + dissimilarity_scores = np.sum(distance_matrix, axis=1) + most_dissimilar_indices = np.argsort(dissimilarity_scores)[-num_images:] + + # Select the most dissimilar image URLs + most_dissimilar_images = [image_urls[idx] + for idx in most_dissimilar_indices] + + return most_dissimilar_images diff --git a/src/select_balanced_images.py b/src/select_balanced_images.py index 44c02bf..738e370 100644 --- a/src/select_balanced_images.py +++ b/src/select_balanced_images.py @@ -24,7 +24,8 @@ model = models.resnet50(weights='IMAGENET1K_V1') model = model.eval() # Set the model to evaluation mode -# Remove the final classification layer to extract 2048-dim features from avgpool +# Remove the final classification layer to extract 2048-dim features from +# avgpool feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]) feature_extractor = feature_extractor.eval() @@ -42,7 +43,7 @@ def extract_features(image_path): Args: image_path: Path to the image. - + Returns: Feature vector of the image (2048-dim from avgpool layer), or None if failed. """ @@ -58,7 +59,11 @@ def extract_features(image_path): return None -def select_balanced_images(image_urls, image_paths, num_images=9, relevance_weight=0.7): +def select_balanced_images( + image_urls, + image_paths, + num_images=9, + relevance_weight=0.7): """ Selects images that balance search relevance with visual dissimilarity. @@ -68,19 +73,20 @@ def select_balanced_images(image_urls, image_paths, num_images=9, relevance_weig num_images: Number of images to select (default 9) relevance_weight: Weight for relevance score (0-1). Dissimilarity weight = 1 - relevance_weight Default 0.7 means 70% relevance, 30% dissimilarity - + Returns: List of selected image URLs, balanced between relevance and dissimilarity """ - + if len(image_urls) < num_images: - logger.warning(f"Requested {num_images} images but only {len(image_urls)} available") + logger.warning( + f"Requested {num_images} images but only {len(image_urls)} available") return image_urls - + # Extract features from all images features_list = [] valid_indices = [] - + for idx, path in enumerate(image_paths): feature = extract_features(path) if feature is not None: @@ -88,47 +94,53 @@ def select_balanced_images(image_urls, image_paths, num_images=9, relevance_weig valid_indices.append(idx) else: logger.debug(f"Skipping image {idx} - could not extract features") - + if len(features_list) < num_images: - logger.warning(f"Only {len(features_list)} images have valid features, returning top {min(len(image_urls), num_images)}") + logger.warning( + f"Only {len(features_list)} images have valid features, returning top {min(len(image_urls), num_images)}") return image_urls[:min(len(image_urls), num_images)] - + features = np.array(features_list) - + # Calculate dissimilarity scores based on visual features # Compute cosine distance matrix between image features distance_matrix = cosine_distances(features) - - # Calculate dissimilarity score for each image (sum of distances to all others) + + # Calculate dissimilarity score for each image (sum of distances to all + # others) dissimilarity_scores = np.sum(distance_matrix, axis=1) - + # Normalize both scores to 0-1 range dissimilarity_weight = 1 - relevance_weight - + # Relevance score: images earlier in search results have higher relevance # Map position (0 to len-1) to relevance (1.0 to 0.0) - relevance_scores = 1.0 - np.arange(len(features_list)) / max(1, len(features_list) - 1) - + relevance_scores = 1.0 - \ + np.arange(len(features_list)) / max(1, len(features_list) - 1) + # Normalize dissimilarity scores to 0-1 range if dissimilarity_scores.max() > dissimilarity_scores.min(): - dissimilarity_scores_norm = (dissimilarity_scores - dissimilarity_scores.min()) / (dissimilarity_scores.max() - dissimilarity_scores.min()) + dissimilarity_scores_norm = ( + dissimilarity_scores - dissimilarity_scores.min()) / ( + dissimilarity_scores.max() - dissimilarity_scores.min()) else: dissimilarity_scores_norm = dissimilarity_scores - + # Combined score: weighted combination of relevance and dissimilarity - combined_scores = (relevance_weight * relevance_scores + - dissimilarity_weight * dissimilarity_scores_norm) - + combined_scores = (relevance_weight * relevance_scores + + dissimilarity_weight * dissimilarity_scores_norm) + # Select top num_images indices by combined score selected_feature_indices = np.argsort(combined_scores)[-num_images:][::-1] - + # Map back to original image indices selected_indices = [valid_indices[idx] for idx in selected_feature_indices] - + # Return selected image URLs selected_images = [image_urls[idx] for idx in selected_indices] - - logger.info(f"Selected {len(selected_images)} images using balanced strategy " - f"(relevance_weight={relevance_weight}, dissimilarity_weight={dissimilarity_weight})") - + + logger.info( + f"Selected {len(selected_images)} images using balanced strategy " + f"(relevance_weight={relevance_weight}, dissimilarity_weight={dissimilarity_weight})") + return selected_images diff --git a/src/train_model.py b/src/train_model.py index d6f556b..85ceec2 100644 --- a/src/train_model.py +++ b/src/train_model.py @@ -9,7 +9,7 @@ def get_optimal_batch_size(): """ Determines optimal batch size based on available VRAM. - + Returns: int: Optimal batch size (16, 8, or 4) """ @@ -57,34 +57,41 @@ def train_model(data_yaml_path, model_type='yolov8'): device=device # Use GPU if available, else CPU ) - # Get the best model path - model_dir = "runs/detect" - if os.path.exists(model_dir): - train_dirs = [os.path.join(model_dir, d) for d in os.listdir(model_dir) - if os.path.isdir(os.path.join(model_dir, d)) and d.startswith('train')] - - if not train_dirs: - logger.error("No training directories found in runs/detect") - return None - - latest_train_dir = max(train_dirs, key=os.path.getmtime) - model_path = os.path.join(latest_train_dir, "weights", "best.pt") - - if os.path.exists(model_path): - logger.info(f"Model trained and saved at {model_path}") - return model_path - else: - logger.error(f"Model file not found at {model_path}") - # Try to find if last.pt exists as fallback - last_path = os.path.join(latest_train_dir, "weights", "last.pt") - if os.path.exists(last_path): - logger.info(f"best.pt not found, using last.pt at {last_path}") - return last_path - return None - else: - logger.error( - "Training directory not found at runs/detect. Training may have failed.") - return None + # Get the best model path + model_dir = "runs/detect" + if os.path.exists(model_dir): + train_dirs = [ + os.path.join( + model_dir, + d) for d in os.listdir(model_dir) if os.path.isdir( + os.path.join( + model_dir, + d)) and d.startswith('train')] + + if not train_dirs: + logger.error("No training directories found in runs/detect") + return None + + latest_train_dir = max(train_dirs, key=os.path.getmtime) + model_path = os.path.join(latest_train_dir, "weights", "best.pt") + + if os.path.exists(model_path): + logger.info(f"Model trained and saved at {model_path}") + return model_path + else: + logger.error(f"Model file not found at {model_path}") + # Try to find if last.pt exists as fallback + last_path = os.path.join( + latest_train_dir, "weights", "last.pt") + if os.path.exists(last_path): + logger.info( + f"best.pt not found, using last.pt at {last_path}") + return last_path + return None + else: + logger.error( + "Training directory not found at runs/detect. Training may have failed.") + return None except Exception as e: logger.error(f"Error during model training: {e}") diff --git a/src/utils/annotation_converter.py b/src/utils/annotation_converter.py index 27f2819..bda777a 100644 --- a/src/utils/annotation_converter.py +++ b/src/utils/annotation_converter.py @@ -1,8 +1,8 @@ -import os -import json -from PIL import Image - - +import os +import json +from PIL import Image + + def convert_to_yolo_format(json_annotation, img_width, img_height): """ Convert JSON bbox annotation to YOLO format (normalized coordinates). @@ -19,19 +19,22 @@ def convert_to_yolo_format(json_annotation, img_width, img_height): # Handle empty or None annotations if not json_annotation or json_annotation == '{}' or json_annotation == '[]': return "" - + try: # Parse JSON string to dict or list data = json.loads(json_annotation) except (json.JSONDecodeError, TypeError): return "" - + # Handle case where data is a list (backward compatibility) if isinstance(data, list): if len(data) == 0: return "" # Convert list format to dict format with rects - data = {"rects": data, "canvasWidth": img_width, "canvasHeight": img_height} + data = { + "rects": data, + "canvasWidth": img_width, + "canvasHeight": img_height} # Check if it's the new format with rects array if isinstance(data, dict) and 'rects' in data and isinstance( @@ -87,7 +90,8 @@ def convert_to_yolo_format(json_annotation, img_width, img_height): return "\n".join(yolo_lines) - # Legacy format: simple x, y, width, height (only if it's a dict with those keys) + # Legacy format: simple x, y, width, height (only if it's a dict with + # those keys) if isinstance(data, dict) and 'x' in data and 'y' in data: bbox = data @@ -105,12 +109,12 @@ def convert_to_yolo_format(json_annotation, img_width, img_height): # Return YOLO format string (class 0) return f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}" - + # If no valid format found, return empty string return "" - - -def ensure_directory(directory): - """Create directory if it doesn't exist.""" - if not os.path.exists(directory): - os.makedirs(directory) + + +def ensure_directory(directory): + """Create directory if it doesn't exist.""" + if not os.path.exists(directory): + os.makedirs(directory)