PTI + DragGAN
I came across a tool called DragGAN this weekend. Although GANs are somewhat outdated, the fun example videos prompted me to play with the technique for a bit. Running the provided demos in Google Colab is very easy. The only hiccup I ran into was that the StyleGAN-Human model is not included in the original download script, so I had to upload it to Colab manually to get it into the GUI list (a rough sketch of that is shown after the links below).
- DragGAN
- edit StyleGAN images by "dragging" points from one spot to another
- https://arxiv.org/abs/2305.10973
- https://github.com/XingangPan/DragGAN
- Pivotal Tuning Inversion (PTI)
- enables StyleGAN editing on non-GAN-generated images
- https://arxiv.org/abs/2106.05744
- https://github.com/danielroich/PTI
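For reference, this is roughly how I got the StyleGAN-Human checkpoint into the GUI. It is a sketch rather than an official procedure: it assumes the DragGAN visualizer lists whatever *.pkl files it finds under its checkpoints/ folder, and the file name is just the one from the StyleGAN-Human release I used.
# Sketch (not part of the original notebook): upload the StyleGAN-Human .pkl from
# your machine and drop it into DragGAN's checkpoints folder so the GUI can list it.
import os
from google.colab import files

uploaded = files.upload()  # e.g. stylegan_human_v2_512.pkl
os.makedirs('/content/DragGAN/checkpoints', exist_ok=True)
for name in uploaded:
    os.rename(name, f'/content/DragGAN/checkpoints/{name}')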
The DragGAN tutorial suggests using the PTI technique to work with your own custom images. However, there are no detailed instructions on how to combine the two techniques and pass the correct information between them. This notebook shows how it can be done; it runs in Google Colab on a T4 GPU.
Note that the base model we use here is stylegan2_ada_ffhq, which has been trained on Flickr-Faces-HQ (FFHQ). As such, it will only work on pictures of faces.
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Helper for fetching files from Google Drive: uses PyDrive when authenticated, otherwise falls back to gdown
class Downloader(object):
def __init__(self, use_pydrive):
self.use_pydrive = use_pydrive
if self.use_pydrive:
self.authenticate()
def authenticate(self):
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
self.drive = GoogleDrive(gauth)
def download_file(self, file_id, file_dst):
if self.use_pydrive:
downloaded = self.drive.CreateFile({'id':file_id})
downloaded.FetchMetadata(fetch_all=True)
downloaded.GetContentFile(file_dst)
else:
!gdown --id $file_id -O $file_dst
downloader = Downloader(True)
Step 1 - Install Packages required by PTI¶
!pip install lpips wandb
# ninja is needed to compile StyleGAN's custom C++/CUDA ops, which speeds up inference
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
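As a quick sanity check (not part of the original notebook), verify that the ninja binary now resolves:
!ninja --version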
Step 2 - Download Pretrained models¶
!git clone https://github.com/XingangPan/DragGAN.git
!git clone https://github.com/danielroich/PTI.git
%cd /content/PTI/
!git checkout da94d59d15d94822e95840ab5a0aa9ba1a19c851
import os
image_dir_name = 'image'
os.makedirs(f'./{image_dir_name}_original', exist_ok=True)
os.makedirs(f'./{image_dir_name}_processed', exist_ok=True)
save_path = "pretrained_models"
os.makedirs(save_path, exist_ok=True)
downloader.download_file("125OG7SMkXI-Kf2aqiwLLHyCvSW-gZk3M", os.path.join(save_path, 'ffhq.pkl'))
downloader.download_file("1xPmn19T6Bdd-_RfCVlgNBbfYoh1muYxR", os.path.join(save_path, 'align.dat'))
Step 3 - Configuration Setup¶
import sys
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display
from configs import paths_config, hyperparameters, global_config
from utils.align_data import pre_process_images
from scripts.run_pti import run_PTI
image_name = 'personal_image'
global_config.device = 'cuda'
paths_config.e4e = '/content/PTI/pretrained_models/e4e_ffhq_encode.pt'  # only used for the 'w+' first-inversion mode, so it is not downloaded here
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = f'/content/PTI/{image_dir_name}_processed'
paths_config.stylegan2_ada_ffhq = '/content/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/content/PTI/'
paths_config.style_clip_pretrained_mappers = '/content/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False
Step 4 - Preprocess Data¶
TODO: upload a picture of a face to /content/PTI/image_original/personal_image.jpg, either via the Colab file browser or with the snippet below.
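A minimal way to do this from within the notebook, assuming you select a single .jpg from your machine:
# Upload a photo from the local machine and store it under the name the rest of the notebook expects
import os
from google.colab import files

uploaded = files.upload()  # select one .jpg containing a face
for name in uploaded:
    os.rename(name, f'./{image_dir_name}_original/{image_name}.jpg')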
original_image = Image.open(f'./{image_dir_name}_original/{image_name}.jpg')
pre_process_images(f'/content/PTI/{image_dir_name}_original')
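To eyeball what PTI will actually invert, display the original next to the aligned crop. The .jpeg extension and output folder are what pre_process_images appears to use via paths_config.input_data_path:
# Show the uploaded picture and the aligned/cropped face written by pre_process_images
aligned_image = Image.open(f'{paths_config.input_data_path}/{image_name}.jpeg')
display(original_image.resize((256, 256)))
display(aligned_image.resize((256, 256)))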
Step 5 - Invert images using PTI¶
In order to run PTI and use StyleGAN2-ada, the cwd should be the parent of 'torch_utils' and 'dnnlib'.
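A quick check (not in the original notebook) that the working directory is still /content/PTI, which contains both packages:
import os
# Unpickling the StyleGAN2-ada model imports torch_utils and dnnlib, so both must be importable from here
print(os.getcwd(), os.path.isdir('torch_utils'), os.path.isdir('dnnlib'))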
model_id = run_PTI(use_wandb=False, use_multi_id_training=False)
Visualize results¶
def load_generators(model_id, image_name):
with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
d = pickle.load(f)
old_G = d['G_ema'].cuda()
old_D = d['D'].cuda()
with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
new_G = torch.load(f_new).cuda()
return old_G, old_D, new_G
old_G, old_D, new_G = load_generators(model_id, image_name)
def plot_syn_images(syn_images):
    # Convert each NCHW float image in [-1, 1] to a uint8 HWC array and display it at 256x256
    for img in syn_images:
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
        resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
        display(resized_image)
        del img
        del resized_image
        torch.cuda.empty_cache()

w_pivot_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}/{image_name}/0.pt'
w_pivot = torch.load(w_pivot_path)
old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
print('Upper image is the inversion before Pivotal Tuning and the lower image is the product of Pivotal Tuning')
plot_syn_images([old_image, new_image])
Export¶
def export_updated_pickle(old_G, old_D, new_G, output_path):
tmp = {}
tmp['G'] = old_G.eval().requires_grad_(False).cpu()
tmp['G_ema'] = new_G.eval().requires_grad_(False).cpu()
tmp['D'] = old_D.eval().requires_grad_(False).cpu()
tmp['training_set_kwargs'] = None
tmp['augment_pipe'] = None
with open(output_path, 'wb') as f:
pickle.dump(tmp, f)
output_path = f'{paths_config.checkpoints_dir}/stylegan2_{image_name}.pkl'
export_updated_pickle(old_G, old_D, new_G, output_path)
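As a quick check (not in the original notebook) that the exported pickle has the layout of a regular StyleGAN2-ada checkpoint, reload it and list its keys:
# Expect the usual StyleGAN2-ada entries: 'G', 'G_ema', 'D', 'training_set_kwargs', 'augment_pipe'
with open(output_path, 'rb') as f:
    exported = pickle.load(f)
print(list(exported.keys()))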
# Colab workaround: avoids "NotImplementedError: A UTF-8 locale is required" from some shell commands below
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!mkdir -p /content/DragGAN/checkpoints
!cp $output_path /content/DragGAN/checkpoints
!cp $w_pivot_path /content/DragGAN/checkpoints
DragGAN¶
%cd /content/DragGAN
!git checkout c5e88b3eaf64c33a9e82782d75b4329d16711c3a
!pip install -r requirements.txt
# !python scripts/download_model.py
Make a few fixes to the DragGAN Python scripts:
- use our custom w_pivot from PTI
- set the default model in the GUI to our own
- bypass the watermark due to a font issue
!sed -i 's#None.*w_load#torch.load("/content/DragGAN/checkpoints/0.pt"),#' /content/DragGAN/visualizer_drag_gradio.py
!sed -i 's/stylegan2_lions_512_pytorch/stylegan2_personal_image/' /content/DragGAN/visualizer_drag_gradio.py
!sed -i 's/d = ImageDraw/return input_image_array # d = ImageDraw/' /content/DragGAN/viz/renderer.py
!python /content/DragGAN/visualizer_drag_gradio.py