make cli a mvp
This commit is contained in:
parent
11012b7db5
commit
4df206981f
2 changed files with 81 additions and 35 deletions
|
|
@ -1,6 +1,10 @@
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
from protestswap.protestswap import ProtestFaceSwapper
|
from protestswap.protestswap import ProtestFaceSwapper
|
||||||
|
|
||||||
logger = logging.getLogger("protestswap")
|
logger = logging.getLogger("protestswap")
|
||||||
|
|
@ -10,19 +14,47 @@ INSIGHTFACE_VAR = "INSIGHTFACE_ROOT_DIR"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="protestswap", description="Swaps faces of people on protests"
|
||||||
|
)
|
||||||
|
parser.add_argument("target", help="The image to swap the faces on")
|
||||||
|
parser.add_argument(
|
||||||
|
"-t", "--template", action="append", help="The template faces to paste on"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(args.target)
|
||||||
|
|
||||||
|
if not args.template:
|
||||||
|
print("You need at least one template (with '-t')!")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
insightface_dir = None
|
insightface_dir = None
|
||||||
if os.path.exists(DEFAULT_INSIGHTFACE_DIR):
|
if os.environ.get(INSIGHTFACE_VAR):
|
||||||
insightface_dir = DEFAULT_INSIGHTFACE_DIR
|
|
||||||
elif os.environ.get(INSIGHTFACE_VAR):
|
|
||||||
insightface_dir = os.environ.get(INSIGHTFACE_VAR)
|
insightface_dir = os.environ.get(INSIGHTFACE_VAR)
|
||||||
|
elif os.path.exists(DEFAULT_INSIGHTFACE_DIR):
|
||||||
|
insightface_dir = DEFAULT_INSIGHTFACE_DIR
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No directory at '{DEFAULT_INSIGHTFACE_DIR}' and '{INSIGHTFACE_VAR}' not set yet."
|
f"No directory at '{DEFAULT_INSIGHTFACE_DIR}' and '{INSIGHTFACE_VAR}' not set yet."
|
||||||
)
|
)
|
||||||
insightface_dir = f"{os.environ.get('HOME')}/.cache/insightface"
|
insightface_dir = f"{os.environ.get('HOME')}/.cache/insightface"
|
||||||
|
|
||||||
# swapper = ProtestFaceSwapper(insightface_dir)
|
swapper = ProtestFaceSwapper(insightface_dir)
|
||||||
# swapper.prepare_model()
|
swapper.prepare_model()
|
||||||
|
|
||||||
|
swapper.set_target_path(args.target)
|
||||||
|
|
||||||
|
for template in args.template:
|
||||||
|
print(f"Adding template {template}...")
|
||||||
|
swapper.add_template_path(template)
|
||||||
|
|
||||||
|
print("Detecting faces")
|
||||||
|
swapper.detect_target_faces()
|
||||||
|
|
||||||
|
swapped = swapper.swap()
|
||||||
|
cv2.imwrite("result.png", swapped)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import random # TODO: seed!
|
import random # TODO: seed!
|
||||||
|
|
@ -14,37 +14,61 @@ from insightface.app.common import Face
|
||||||
|
|
||||||
logger = logging.getLogger("protestswap")
|
logger = logging.getLogger("protestswap")
|
||||||
|
|
||||||
SWAPPING_MODEL = 'inswapper_128.onnx'
|
SWAPPING_MODEL = "inswapper_128.onnx"
|
||||||
|
|
||||||
|
|
||||||
class ProtestFaceSwapper:
|
class ProtestFaceSwapper:
|
||||||
def __init__(self, root_dir):
|
def __init__(self, root_dir):
|
||||||
self._app : FaceAnalysis = FaceAnalysis(name='buffalo_l', root=insightface_dir)
|
self._root_dir = root_dir
|
||||||
|
self._app: FaceAnalysis = FaceAnalysis(name="buffalo_l", root=root_dir)
|
||||||
self._model: INSwapper
|
self._model: INSwapper
|
||||||
self._source : cv2.typing.MatLike|None = None
|
self._target: cv2.typing.MatLike | None = None
|
||||||
self._source_faces : list[Face] = []
|
self._target_faces: list[Face] = []
|
||||||
self._template_faces: list[Face] = []
|
self._template_faces: list[Face] = []
|
||||||
|
|
||||||
self._app.prepare(ctx_id=0, det_size=(640, 640))
|
self._app.prepare(ctx_id=0, det_size=(640, 640))
|
||||||
|
|
||||||
def prepare_model(self) -> None:
|
def prepare_model(self) -> None:
|
||||||
self._model = insightface.model_zoo.get_model(SWAPPING_MODEL, download=True, download_zip=True) # Revealed type: "INSwapper"
|
self._model = insightface.model_zoo.get_model(
|
||||||
|
os.path.join(self._root_dir, "models", SWAPPING_MODEL),
|
||||||
|
root=self._root_dir,
|
||||||
|
download=False,
|
||||||
|
download_zip=False,
|
||||||
|
) # Revealed type: "INSwapper"
|
||||||
|
print(f"Model: {self._model}")
|
||||||
|
|
||||||
def set_source(self, path: str) -> None:
|
def set_target_path(self, path: str) -> None:
|
||||||
self._source = cv2.imread(path)
|
self.set_target(cv2.imread(path))
|
||||||
|
|
||||||
def detect_source_faces(self):
|
def set_target(self, image: cv2.UMat) -> None:
|
||||||
self._source_faces = self._app.get(self._source)
|
self._target = image
|
||||||
|
|
||||||
def add_template_faces(self, template_path: str):
|
def detect_target_faces(self):
|
||||||
self._template_faces += [self._app.get(cv2.imread(template_path))]
|
self._target_faces = self._app.get(self._target)
|
||||||
|
print(f"Found {len(self._target_faces)} faces!")
|
||||||
|
|
||||||
|
def add_template_path(self, template_path: str):
|
||||||
|
self.add_template(cv2.imread(template_path))
|
||||||
|
|
||||||
|
def add_template(self, template: cv2.UMat):
|
||||||
|
self._template_faces += self._app.get(template)
|
||||||
|
|
||||||
def swap(self) -> cv2.typing.MatLike:
|
def swap(self) -> cv2.typing.MatLike:
|
||||||
work_copy = self._source_faces.copy()
|
work_copy = self._target.copy()
|
||||||
for i, face in enumerate(self._source_faces):
|
for i, face in enumerate(self._target_faces):
|
||||||
print(f"Replacing face {i+1}/{len(self._source_faces)}...", end="", flush=True);
|
print(
|
||||||
|
f"Replacing face {i + 1}/{len(self._target_faces)}...",
|
||||||
|
# end="",
|
||||||
|
# flush=True,
|
||||||
|
)
|
||||||
imposed_face = random.choice(self._template_faces)
|
imposed_face = random.choice(self._template_faces)
|
||||||
work_copy = self._model.get(work_copy, face, imposed_face, paste_back=True)
|
print(
|
||||||
|
f"work_copy: {type(work_copy)}, face: {type(face)}, imposed: {
|
||||||
|
type(imposed_face)
|
||||||
|
}"
|
||||||
|
)
|
||||||
|
work_copy = self._model.get(
|
||||||
|
work_copy, face, imposed_face, paste_back=True)
|
||||||
print(" done!")
|
print(" done!")
|
||||||
|
|
||||||
return work_copy
|
return work_copy
|
||||||
|
|
@ -52,13 +76,3 @@ class ProtestFaceSwapper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def readFile(path: str) -> cv2.typing.MatLike | None:
|
def readFile(path: str) -> cv2.typing.MatLike | None:
|
||||||
return cv2.imread(path)
|
return cv2.imread(path)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
swapper = ProtestFaceSwapper(insightface_dir)
|
|
||||||
swapper.prepare_model()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue