From 4df206981fcef15744508a0c8653047ef1f8f1fe Mon Sep 17 00:00:00 2001 From: Prunebutt Date: Sun, 19 Oct 2025 23:33:47 +0200 Subject: [PATCH] make cli a mvp --- src/protestswap/cli.py | 42 ++++++++++++++++--- src/protestswap/protestswap.py | 74 ++++++++++++++++++++-------------- 2 files changed, 81 insertions(+), 35 deletions(-) diff --git a/src/protestswap/cli.py b/src/protestswap/cli.py index 67c4726..16d0d8f 100644 --- a/src/protestswap/cli.py +++ b/src/protestswap/cli.py @@ -1,6 +1,10 @@ import os import logging +import argparse + +import cv2 + from protestswap.protestswap import ProtestFaceSwapper logger = logging.getLogger("protestswap") @@ -10,19 +14,47 @@ INSIGHTFACE_VAR = "INSIGHTFACE_ROOT_DIR" def main(): + parser = argparse.ArgumentParser( + prog="protestswap", description="Swaps faces of people on protests" + ) + parser.add_argument("target", help="The image to swap the faces on") + parser.add_argument( + "-t", "--template", action="append", help="The template faces to paste on" + ) + + args = parser.parse_args() + + print(args.target) + + if not args.template: + print("You need at least one template (with '-t')!") + exit(1) + insightface_dir = None - if os.path.exists(DEFAULT_INSIGHTFACE_DIR): - insightface_dir = DEFAULT_INSIGHTFACE_DIR - elif os.environ.get(INSIGHTFACE_VAR): + if os.environ.get(INSIGHTFACE_VAR): insightface_dir = os.environ.get(INSIGHTFACE_VAR) + elif os.path.exists(DEFAULT_INSIGHTFACE_DIR): + insightface_dir = DEFAULT_INSIGHTFACE_DIR else: logger.warning( f"No directory at '{DEFAULT_INSIGHTFACE_DIR}' and '{INSIGHTFACE_VAR}' not set yet." ) insightface_dir = f"{os.environ.get('HOME')}/.cache/insightface" - # swapper = ProtestFaceSwapper(insightface_dir) - # swapper.prepare_model() + swapper = ProtestFaceSwapper(insightface_dir) + swapper.prepare_model() + + swapper.set_target_path(args.target) + + for template in args.template: + print(f"Adding template {template}...") + swapper.add_template_path(template) + + print("Detecting faces") + swapper.detect_target_faces() + + swapped = swapper.swap() + cv2.imwrite("result.png", swapped) if __name__ == "__main__": diff --git a/src/protestswap/protestswap.py b/src/protestswap/protestswap.py index 1b7f584..8e6a981 100644 --- a/src/protestswap/protestswap.py +++ b/src/protestswap/protestswap.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 -import logging import os +import logging import cv2 -import random # TODO: seed! +import random # TODO: seed! import insightface from insightface.app import FaceAnalysis @@ -14,51 +14,65 @@ from insightface.app.common import Face logger = logging.getLogger("protestswap") -SWAPPING_MODEL = 'inswapper_128.onnx' +SWAPPING_MODEL = "inswapper_128.onnx" class ProtestFaceSwapper: def __init__(self, root_dir): - self._app : FaceAnalysis = FaceAnalysis(name='buffalo_l', root=insightface_dir) - self._model : INSwapper - self._source : cv2.typing.MatLike|None = None - self._source_faces : list[Face] = [] - self._template_faces : list[Face] = [] + self._root_dir = root_dir + self._app: FaceAnalysis = FaceAnalysis(name="buffalo_l", root=root_dir) + self._model: INSwapper + self._target: cv2.typing.MatLike | None = None + self._target_faces: list[Face] = [] + self._template_faces: list[Face] = [] self._app.prepare(ctx_id=0, det_size=(640, 640)) def prepare_model(self) -> None: - self._model = insightface.model_zoo.get_model(SWAPPING_MODEL, download=True, download_zip=True) # Revealed type: "INSwapper" + self._model = insightface.model_zoo.get_model( + os.path.join(self._root_dir, "models", SWAPPING_MODEL), + root=self._root_dir, + download=False, + download_zip=False, + ) # Revealed type: "INSwapper" + print(f"Model: {self._model}") - def set_source(self, path: str) -> None: - self._source = cv2.imread(path) + def set_target_path(self, path: str) -> None: + self.set_target(cv2.imread(path)) - def detect_source_faces(self): - self._source_faces = self._app.get(self._source) + def set_target(self, image: cv2.UMat) -> None: + self._target = image - def add_template_faces(self, template_path: str): - self._template_faces += [self._app.get(cv2.imread(template_path))] + def detect_target_faces(self): + self._target_faces = self._app.get(self._target) + print(f"Found {len(self._target_faces)} faces!") + + def add_template_path(self, template_path: str): + self.add_template(cv2.imread(template_path)) + + def add_template(self, template: cv2.UMat): + self._template_faces += self._app.get(template) def swap(self) -> cv2.typing.MatLike: - work_copy = self._source_faces.copy() - for i, face in enumerate(self._source_faces): - print(f"Replacing face {i+1}/{len(self._source_faces)}...", end="", flush=True); + work_copy = self._target.copy() + for i, face in enumerate(self._target_faces): + print( + f"Replacing face {i + 1}/{len(self._target_faces)}...", + # end="", + # flush=True, + ) imposed_face = random.choice(self._template_faces) - work_copy = self._model.get(work_copy, face, imposed_face, paste_back=True) + print( + f"work_copy: {type(work_copy)}, face: {type(face)}, imposed: { + type(imposed_face) + }" + ) + work_copy = self._model.get( + work_copy, face, imposed_face, paste_back=True) print(" done!") return work_copy @staticmethod - def readFile(path: str) -> cv2.typing.MatLike|None: + def readFile(path: str) -> cv2.typing.MatLike | None: return cv2.imread(path) - - -def main(): - - swapper = ProtestFaceSwapper(insightface_dir) - swapper.prepare_model() - - -if __name__ == '__main__': - main()