Skip to content

Commit a7cc8ef

Browse files
authored
Update main.py
1 parent 95891ac commit a7cc8ef

File tree

1 file changed

+73
-131
lines changed

1 file changed

+73
-131
lines changed

main.py

Lines changed: 73 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,41 @@
11
import os
2-
import sys
32
import zipfile
4-
import pickle
5-
6-
from flask import Flask, request, render_template, redirect, url_for, send_from_directory, flash
7-
from werkzeug.utils import secure_filename
8-
93
import numpy as np
104
import tensorflow as tf
115
from tensorflow.keras.models import load_model
6+
from fastapi import FastAPI, UploadFile, File, Request
7+
from fastapi.responses import HTMLResponse
8+
from fastapi.templating import Jinja2Templates
9+
from fastapi.staticfiles import StaticFiles
10+
from io import BytesIO
1211
from PIL import Image, UnidentifiedImageError
1312

14-
# ---------------------------
15-
# Settings and global constants
16-
# ---------------------------
17-
UPLOAD_FOLDER = 'uploads'
13+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14+
STATIC_DIR = os.path.join(BASE_DIR, "static")
15+
TEMPLATES_DIR = os.path.join(BASE_DIR, "templates")
16+
UPLOAD_FOLDER = os.path.join(BASE_DIR, "uploads")
17+
ZIP_PATH = os.path.join(BASE_DIR, "photos.zip")
1818
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
1919

20-
ZIP_PATH = 'photos.zip'
21-
MODEL_PATH = 'resnet50_local.h5'
22-
FEATURES_PATH = 'image_features.pkl'
23-
TOP_K = 5
24-
SIMILARITY_THRESHOLD = 0.45 # For display purposes, consider as "60%"
20+
app = FastAPI()
21+
model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False, pooling='avg')
22+
templates = Jinja2Templates(directory=TEMPLATES_DIR)
2523

26-
app = Flask(__name__)
27-
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
28-
app.secret_key = 'your_secret_key_here' # Used for flash messages
29-
30-
# ---------------------------
31-
# Function for correctly locating resources
32-
# (useful when packaging with PyInstaller, among others)
33-
# ---------------------------
34-
def resource_path(relative_path):
35-
try:
36-
base_path = sys._MEIPASS
37-
except Exception:
38-
base_path = os.path.abspath(".")
39-
return os.path.join(base_path, relative_path)
24+
@app.get("/gallery")
25+
async def get_gallery():
26+
images = get_images_from_zip()
27+
with zipfile.ZipFile(ZIP_PATH, "r") as archive:
28+
saved_images = []
29+
for image_key in images[:20]: # Show first 20 images
30+
try:
31+
saved_name = extract_and_save_image(archive, image_key)
32+
saved_images.append(saved_name)
33+
except Exception as e:
34+
print(f"Error processing image {image_key}: {e}")
35+
return {"photos": saved_images}
4036

41-
# ---------------------------
42-
# Functions for image processing and feature extraction
43-
# ---------------------------
44-
model = load_model(MODEL_PATH)
37+
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
38+
app.mount("/uploads", StaticFiles(directory=UPLOAD_FOLDER), name="uploads")
4539

4640
def preprocess_image(image: Image.Image):
4741
image = image.resize((224, 224))
@@ -53,116 +47,64 @@ def preprocess_image(image: Image.Image):
5347
return image_array
5448

5549
def get_image_features(image: Image.Image):
56-
preprocessed_image = preprocess_image(image)
57-
features = model.predict(preprocessed_image)
58-
return features.flatten()
50+
return model.predict(preprocess_image(image))
5951

60-
def cosine_similarity(features1, features2):
61-
num = np.dot(features1, features2)
62-
den = np.linalg.norm(features1) * np.linalg.norm(features2)
63-
return num / (den + 1e-9)
52+
def compare_images(image1_features, image2_features):
53+
return np.linalg.norm(image1_features - image2_features)
6454

65-
def extract_features_from_zip():
66-
image_features = {}
67-
with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
68-
image_keys = [name for name in archive.namelist() if name.lower().endswith(('.jpg', '.jpeg', '.png'))]
69-
for image_key in image_keys:
70-
try:
71-
with archive.open(image_key) as image_file:
72-
image = Image.open(image_file).convert('RGB')
73-
features = get_image_features(image)
74-
image_features[image_key] = features
75-
except UnidentifiedImageError:
76-
print(f"Cannot open image: {image_key}")
77-
except Exception as e:
78-
print(f"Error processing {image_key}: {e}")
79-
with open(FEATURES_PATH, 'wb') as f:
80-
pickle.dump(image_features, f)
81-
print(f"Features extracted and saved to {FEATURES_PATH}")
82-
83-
def load_precomputed_features():
84-
if not os.path.exists(FEATURES_PATH):
85-
extract_features_from_zip()
86-
with open(FEATURES_PATH, 'rb') as f:
87-
image_features = pickle.load(f)
88-
return image_features
89-
90-
image_features = load_precomputed_features()
55+
def get_images_from_zip():
56+
with zipfile.ZipFile(ZIP_PATH, "r") as archive:
57+
return [name for name in archive.namelist() if name.lower().endswith((".jpg", ".jpeg", ".png"))]
9158

9259
def extract_and_save_image(archive, image_key):
9360
with archive.open(image_key) as image_file:
94-
image = Image.open(image_file).convert('RGB')
61+
image = Image.open(image_file)
9562
safe_image_name = os.path.basename(image_key)
9663
image_path = os.path.join(UPLOAD_FOLDER, safe_image_name)
9764
image.save(image_path)
9865
return safe_image_name
9966

100-
def find_similar_images(uploaded_image: Image.Image):
101-
uploaded_image_features = get_image_features(uploaded_image)
102-
similarities = []
103-
for image_key, features in image_features.items():
104-
sim_val = cosine_similarity(uploaded_image_features, features)
105-
if sim_val >= SIMILARITY_THRESHOLD:
106-
similarities.append((image_key, sim_val))
107-
similarities.sort(key=lambda x: x[1], reverse=True)
108-
similar_images = []
109-
if similarities:
110-
with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
111-
for image_key, sim_val in similarities[:TOP_K]:
112-
saved_image_name = extract_and_save_image(archive, image_key)
113-
similar_images.append((saved_image_name, sim_val))
114-
return similar_images
67+
@app.get("/get_all_photos")
68+
async def get_all_photos():
69+
images = get_images_from_zip()
70+
with zipfile.ZipFile(ZIP_PATH, "r") as archive:
71+
saved_images = []
72+
for image_key in images:
73+
try:
74+
saved_name = extract_and_save_image(archive, image_key)
75+
saved_images.append(saved_name)
76+
except Exception as e:
77+
print(f"Error processing image {image_key}: {e}")
78+
return {"photos": saved_images}
11579

116-
def allowed_file(filename):
117-
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png', 'bmp', 'gif'}
118-
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
80+
@app.get("/", response_class=HTMLResponse)
81+
async def read_root(request: Request):
82+
return templates.TemplateResponse("upload_form.html", {"request": request})
11983

120-
# ---------------------------
121-
# Flask routes
122-
# ---------------------------
123-
@app.route('/', methods=['GET', 'POST'])
124-
def index():
125-
result_text = None
126-
similar_images = []
127-
if request.method == 'POST':
128-
if 'image' not in request.files:
129-
flash('No file was uploaded.')
130-
return redirect(request.url)
131-
file = request.files['image']
132-
if file.filename == '':
133-
flash('No file selected.')
134-
return redirect(request.url)
135-
if file and allowed_file(file.filename):
84+
@app.post("/find_similar/")
85+
async def find_similar_images(file: UploadFile = File(...)):
86+
uploaded_image = Image.open(BytesIO(await file.read()))
87+
uploaded_image_features = get_image_features(uploaded_image)
88+
images = get_images_from_zip()
89+
similarities = []
90+
with zipfile.ZipFile(ZIP_PATH, "r") as archive:
91+
for image_key in images:
13692
try:
137-
uploaded_image = Image.open(file.stream).convert('RGB')
93+
with archive.open(image_key) as image_file:
94+
image = Image.open(image_file).convert("RGB")
95+
similarity = compare_images(uploaded_image_features, get_image_features(image))
96+
similarities.append((image_key, similarity))
13897
except UnidentifiedImageError:
139-
flash('The uploaded file is not a valid image.')
140-
return redirect(request.url)
98+
print(f"Cannot identify image file: {image_key}")
14199
except Exception as e:
142-
flash(str(e))
143-
return redirect(request.url)
144-
similar_images = find_similar_images(uploaded_image)
145-
if similar_images:
146-
result_text = "Similar images found (similarity ≥ 60%):"
147-
else:
148-
result_text = "No matching images found."
149-
return render_template('index.html', result_text=result_text, similar_images=similar_images)
150-
151-
@app.route('/uploads/<filename>')
152-
def uploaded_file(filename):
153-
return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
154-
155-
@app.route('/internet-search')
156-
def internet_search():
157-
return "Internet search functionality is not implemented yet."
158-
159-
@app.route('/folder-search')
160-
def folder_search():
161-
return "Folder search functionality is not implemented yet."
162-
163-
@app.route('/about')
164-
def about():
165-
return render_template('about.html')
166-
167-
if __name__ == '__main__':
168-
app.run(debug=True)
100+
print(f"Error processing image {image_key}: {e}")
101+
similarities.sort(key=lambda x: x[1])
102+
similar_images = []
103+
with zipfile.ZipFile(ZIP_PATH, "r") as archive:
104+
for image_key, _ in similarities[:5]:
105+
similar_images.append(extract_and_save_image(archive, image_key))
106+
return {"filename": file.filename, "similar_images": similar_images}
107+
108+
if __name__ == "__main__":
109+
import uvicorn
110+
uvicorn.run(app, host="0.0.0.0", port=3000)

0 commit comments

Comments
 (0)