79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
import random
|
|
import shutil
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import numpy as np
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
import open_clip
|
|
|
|
# Folder whose images are collected (recursively) further down in the script.
SOURCE = Path("alle_meine_fotos/")
# Cosine-similarity cutoff: a pair scoring above this counts as a duplicate.
THRESHOLD = 0.95

# Output layout: one folder per (split, label) combination.
DIRS = [
    "dataset/train/wallpaper",
    "dataset/train/no_wallpaper",
    "dataset/val/wallpaper",
    "dataset/val/no_wallpaper",
]
# Create the whole tree up front; folders that already exist are left alone.
for target_dir in map(Path, DIRS):
    target_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# --- Compute embeddings ---
# Tell the user before the model is created.
print("Lade CLIP Modell...")
# Only the first and third element of the returned triple are used below:
# the model itself and the transform applied to every image before encoding.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
)
# The model is only used for inference in this script, never trained.
clip_model.eval()
|
|
|
|
def get_embedding(img_path):
    """Return the CLIP embedding of one image file as a flat NumPy vector.

    Args:
        img_path: Path object of the image to embed (its ``.name`` is used
            in the error message).

    Returns:
        The flattened output of ``clip_model.encode_image`` as a 1-D
        ``numpy.ndarray``, or ``None`` if anything went wrong while opening,
        preprocessing or encoding the image. Failures are printed, not
        raised, so that one broken file does not abort the whole run.
    """
    try:
        # Open inside a context manager so the file handle is released as
        # soon as the pixels have been converted, instead of staying open
        # until garbage collection (one leaked handle per image otherwise).
        with Image.open(img_path) as raw:
            batch = preprocess(raw.convert("RGB")).unsqueeze(0)
        # Inference only: no gradients needed.
        with torch.no_grad():
            features = clip_model.encode_image(batch)
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid should
        # the model ever be moved to another device.
        return features.cpu().numpy().flatten()
    except Exception as e:  # deliberate best effort: report and skip the file
        print(f" Fehler bei {img_path.name}: {e}")
        return None
|
|
|
|
# File extensions (compared in lower case) that are treated as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

# Walk SOURCE recursively. Comparing the lower-cased suffix also picks up
# files such as "IMG_0001.JPG" or "photo.jpeg", which the plain "*.jpg" /
# "*.png" patterns miss on case-sensitive file systems. Sorting makes the
# processing order (and therefore which member of a duplicate group is kept
# later on) reproducible instead of depending on directory listing order.
all_images = sorted(
    p
    for p in SOURCE.rglob("*")
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)

total = len(all_images)
print(f"{total} Bilder gefunden, berechne Embeddings...")

# Parallel lists: embeddings[k] belongs to valid_images[k]. Images whose
# embedding could not be computed (get_embedding returned None) are left out
# of both lists.
embeddings, valid_images = [], []
for idx, p in enumerate(all_images, start=1):
    # "\r" keeps the progress indicator on a single console line.
    print(f" {idx}/{total}: {p.name}", end="\r")
    emb = get_embedding(p)
    if emb is not None:
        embeddings.append(emb)
        valid_images.append(p)
|
|
|
|
# --- Filter out duplicates ---
print(f"\nFiltere Duplikate (Threshold={THRESHOLD})...")
if embeddings:
    # Pairwise cosine similarity between all embeddings (n x n matrix).
    sim_matrix = cosine_similarity(embeddings)
else:
    # cosine_similarity raises on an input without any sample, which would
    # crash the script when no image was found or none could be embedded.
    sim_matrix = np.empty((0, 0))

# Greedy pass in list order: every image that has not been dropped yet keeps
# its place and drops all *later* images whose similarity to it exceeds
# THRESHOLD. Images that were already dropped do not drop others.
to_skip = set()
for i in range(len(valid_images)):
    if i in to_skip:
        continue
    # Vectorised comparison of row i against all later columns.
    hits = np.flatnonzero(sim_matrix[i, i + 1:] > THRESHOLD)
    # Shift the offsets back to absolute indices into valid_images.
    to_skip.update((hits + i + 1).tolist())

unique_images = [p for i, p in enumerate(valid_images) if i not in to_skip]
print(f"{len(valid_images) - len(unique_images)} Duplikate entfernt, {len(unique_images)} verbleiben.\n")

# The model and the similarity data are not needed for labelling; drop the
# references so their memory can be reclaimed. get_embedding() must not be
# called after this point because it reads the global clip_model.
del clip_model, embeddings, sim_matrix
|
|
|
|
# --- Labelling loop ---
# Share of labelled images that is copied into the training split; the rest
# goes to the validation split.
TRAIN_FRACTION = 0.8

for img_path in unique_images:
    # Keep the file open only while matplotlib takes over the pixel data, so
    # no handle stays open during the (possibly long) wait for user input.
    with Image.open(img_path) as img:
        plt.imshow(img)
    plt.title(img_path.name)
    plt.axis("off")
    plt.show(block=False)
    # Run the GUI event loop briefly so the window is actually drawn before
    # input() blocks the main thread.
    plt.pause(0.001)

    label = input("Wallpaper? (y/n/q): ").strip().lower()
    plt.close()

    if label == "q":
        # Stop labelling; images copied so far stay in place.
        break
    if label not in ("y", "n"):
        # Any other answer skips this image without copying it.
        continue

    folder = "wallpaper" if label == "y" else "no_wallpaper"
    split = "train" if random.random() < TRAIN_FRACTION else "val"
    # Images are collected recursively, so the bare file name is not unique:
    # "a/img.jpg" and "b/img.jpg" would overwrite each other. Flatten the
    # path relative to SOURCE into the target name instead; files directly
    # inside SOURCE keep their original name.
    flat_name = "__".join(img_path.relative_to(SOURCE).parts)
    shutil.copy(img_path, f"dataset/{split}/{folder}/{flat_name}")