ADD: Filter duplicate images out
This commit is contained in:
parent
cc301ce395
commit
8edc250058
@ -1,13 +1,15 @@
|
|||||||
# quick_labeler.py
|
|
||||||
import random
|
import random
|
||||||
import shutil, os
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
import open_clip
|
||||||
|
|
||||||
# Root folder that is scanned recursively for photos to label.
SOURCE = Path("alle_meine_fotos/")
# Cosine-similarity cutoff above which two images count as duplicates.
THRESHOLD = 0.95
||||||
|
|
||||||
|
|
||||||
DIRS = [
|
DIRS = [
|
||||||
"dataset/train/wallpaper", "dataset/train/no_wallpaper",
|
"dataset/train/wallpaper", "dataset/train/no_wallpaper",
|
||||||
@ -16,7 +18,50 @@ DIRS = [
|
|||||||
# Create every target label directory up front (no-op when it already exists).
for target_dir in DIRS:
    Path(target_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# --- Embeddings berechnen ---
# Load the pretrained CLIP model once; it is only needed for the
# duplicate-filtering pass and is released again afterwards.
print("Lade CLIP Modell...")
model_bundle = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model, _, preprocess = model_bundle
clip_model.eval()
|
||||||
|
|
||||||
|
def get_embedding(img_path):
    """Embed one image with CLIP and return it as a flat numpy vector.

    Returns None (after printing the error) when the file cannot be read
    or encoded, so the caller can simply skip bad files.
    """
    try:
        rgb = Image.open(img_path).convert("RGB")
        batch = preprocess(rgb).unsqueeze(0)
        with torch.no_grad():
            features = clip_model.encode_image(batch)
        return features.numpy().flatten()
    except Exception as e:
        # Deliberately broad: a single unreadable file must not abort the run.
        print(f" Fehler bei {img_path.name}: {e}")
        return None
|
||||||
|
|
||||||
|
# Collect every jpg/png under SOURCE (recursively) and embed each one.
all_images = [p for pattern in ("**/*.jpg", "**/*.png") for p in SOURCE.glob(pattern)]
print(f"{len(all_images)} Bilder gefunden, berechne Embeddings...")

embeddings = []
valid_images = []
for i, p in enumerate(all_images):
    # "\r" keeps the progress display on a single console line.
    print(f" {i+1}/{len(all_images)}: {p.name}", end="\r")
    emb = get_embedding(p)
    if emb is None:
        # Unreadable image — get_embedding already printed the error.
        continue
    embeddings.append(emb)
    valid_images.append(p)
|
||||||
|
|
||||||
|
# --- Duplikate rausfiltern ---
# Greedy near-duplicate removal: any image whose CLIP embedding is more
# similar than THRESHOLD to an earlier kept image is dropped.
print(f"\nFiltere Duplikate (Threshold={THRESHOLD})...")
# cosine_similarity raises ValueError on an empty input list; fall back to
# an empty matrix so the script also survives a folder with no readable images.
sim_matrix = cosine_similarity(embeddings) if embeddings else np.zeros((0, 0))
to_skip = set()
n_images = len(valid_images)
for i in range(n_images):
    if i in to_skip:
        # Already marked as a duplicate of an earlier kept image;
        # it does not claim further duplicates of its own.
        continue
    for j in range(i + 1, n_images):
        if sim_matrix[i, j] > THRESHOLD:
            to_skip.add(j)

unique_images = [p for i, p in enumerate(valid_images) if i not in to_skip]
print(f"{len(valid_images) - len(unique_images)} Duplikate entfernt, {len(unique_images)} verbleiben.\n")
|
||||||
|
|
||||||
|
# Free the CLIP model and the embedding data; only the deduplicated
# image list is needed for the labeling loop below.
del clip_model
del embeddings
del sim_matrix
|
||||||
|
|
||||||
|
# --- Labelschleife ---
|
||||||
|
for img_path in unique_images:
|
||||||
img = Image.open(img_path)
|
img = Image.open(img_path)
|
||||||
plt.imshow(img)
|
plt.imshow(img)
|
||||||
plt.title(img_path.name)
|
plt.title(img_path.name)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user