# Fake_Detection / app.py — by ElBeh
# Last change: "Update app.py" (commit 92f6098, verified), 2.97 kB
import keras
import numpy as np
import streamlit as st
import random
import os
from PIL import Image, ImageOps
from io import BytesIO
from huggingface_hub import snapshot_download
def random_crop(img, min_size=160, max_size=2048, ratio=5/8):
    """Return a randomly placed crop of *img* whose height/width is *ratio*.

    The crop width is drawn uniformly from ``[min_size, min(max_size, width)]``
    and the height is derived from it via *ratio*. If that height does not fit,
    the crop is limited by the image height and the width is recomputed from
    *ratio*. The crop position is then chosen uniformly inside the image.

    Images narrower than *min_size* no longer raise ``ValueError`` (the
    previous ``random.randint`` call got an empty range); the lower bound is
    lowered to the largest width that is available.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a PIL image).
        min_size: Preferred minimum crop width in pixels.
        max_size: Maximum crop width in pixels.
        ratio: Crop height divided by crop width.

    Returns:
        Whatever ``img.crop`` returns for the chosen box.
    """
    width, height = img.size
    upper = min(max_size, width)
    # Clamp so that images narrower than min_size still yield a valid range;
    # for normal images this is exactly min_size (same random draws as before).
    lower = min(min_size, upper)
    crop_width = random.randint(lower, upper)
    # max(1, ...) avoids a zero-height crop for extremely narrow images.
    crop_height = max(1, int(crop_width * ratio))
    if crop_height > height:
        # Too tall for the image: let the height bound the crop instead.
        crop_height = height
        crop_width = max(1, int(crop_height / ratio))
    left = random.randint(0, width - crop_width)
    top = random.randint(0, height - crop_height)
    return img.crop((left, top, left + crop_width, top + crop_height))
def jpg_compression(img):
    """Re-encode *img* as an in-memory JPEG and decode it again.

    The JPEG quality is drawn uniformly from 65..100 (inclusive) on each
    call, so repeated calls give different compression strengths. The image
    is converted to RGB before encoding.

    Returns:
        The image re-opened with ``Image.open`` from the in-memory buffer.
    """
    buffer = BytesIO()
    img.convert("RGB").save(buffer, 'JPEG', quality=random.randint(65, 100))
    # Rewind so Image.open reads the freshly written JPEG from the start.
    buffer.seek(0)
    return Image.open(buffer)
def get_prediction(img):
    """Classify a single image with the module-level Keras ``model``.

    The image is turned into a NumPy array and given a leading batch axis of
    length 1. The prediction row for that single sample is returned.
    """
    batch = np.array(img)[np.newaxis, ...]
    return model.predict(batch)[0, :]
# Class labels, indexed in the same order as the model's prediction vector
# (label i is shown next to predictions[i]; 'real' is the only non-fake class).
models = [
    'DDPM',
    'Glide',
    'Latent Diffusion',
    'Palette',
    'Stable Diffusion',
    'VQ Diffusion',
    'real',
    'unseen_fake',
]
# --- Page header and model selection ----------------------------------------
st.title("Fake Detection")
st.divider()

st.subheader("Modelvariant")
# index=None starts the box empty; Streamlit then returns None until the
# user picks one of the two options.
variant = st.selectbox(
    label="Choose the model",
    options=("ResNet50v2-Basemodel", "ResNet50v2-Finetuned"),
    index=None,
    placeholder="Choose a model",
    label_visibility="hidden",
)
st.write("You selected model: ", variant)

# When on, only a real/fake verdict is shown instead of per-class probabilities.
binary = st.toggle(label="Activate binary classification")
st.divider()
@st.cache_resource
def _load_model(repo_id):
    """Download *repo_id* from the Hugging Face Hub and load it with Keras.

    Streamlit re-executes the whole script on every widget interaction
    (toggles, the re-crop button, new uploads). Without caching, every rerun
    repeated ``snapshot_download`` and ``keras.models.load_model``.
    ``st.cache_resource`` keeps one loaded model per ``repo_id`` for the
    lifetime of the server process.
    """
    local_model_path = snapshot_download(repo_id=repo_id)
    return keras.models.load_model(local_model_path)


# NOTE(review): with no selection yet (variant is None) this falls through to
# the finetuned model, exactly as the original if/else did, even though the
# page reports "You selected model: None". Confirm that default is intended.
if variant == "ResNet50v2-Basemodel":
    model = _load_model("ElBeh/ma_basemodel")
else:
    model = _load_model("ElBeh/ma_finetuned_model")
# --- Optional preprocessing switches ----------------------------------------
st.subheader("Image Preprocessing")
crop = st.toggle(label="random crop")
compress = st.toggle(label="jpeg compression")
st.divider()

# Holds the uploaded file object returned by st.file_uploader (None until a
# file has been uploaded); despite its name it is not a file name string.
file_name = st.file_uploader(label="Choose an image...")
# --- Classification of the uploaded image -----------------------------------
if file_name is not None:
    col1, col2 = st.columns(2)
    image = Image.open(file_name)
    # Apply the EXIF orientation tag so the image is shown/classified upright.
    image = ImageOps.exif_transpose(image)
    # Normalise to 200x200 RGB; images already in that form pass through
    # untouched (and are therefore never randomly cropped).
    if image.size != (200, 200) or image.mode != 'RGB':
        # resize() never changes the colour mode, so convert explicitly.
        # Otherwise an RGBA/L/P upload (e.g. a PNG) would reach the model with
        # the wrong channel count whenever JPEG compression is off.
        image = image.convert("RGB")
        if crop:
            image = random_crop(image)
        image = image.resize((200, 200), Image.LANCZOS)
    if compress:
        image = jpg_compression(image)
    col1.image(image, use_column_width=True)
    predictions = get_prediction(image)
    if binary:
        col2.header("Prediction")
        # Look the "real" class up by label instead of hard-coding index 6.
        if predictions[models.index('real')] > 0.5:
            col2.markdown(":green[real image!]")
        else:
            col2.markdown(":red[fake image!]")
    else:
        col2.header("Probabilities")
        for idx, p in enumerate(predictions):
            col2.text(f"{models[idx]}: {round(p * 100, 2)}%")
    # Clicking the button reruns the script, which draws a new random crop.
    # Offered in both display modes, since the crop applies to both.
    if crop:
        st.button("re-crop")