PTI + DragGAN

I came across a tool called DragGAN this weekend. Although GANs are somewhat outdated, the fun example videos prompted me to play with the technique for a bit. Running the provided demos in Google Colab is very easy. The only hiccup was that I had to upload the StyleGAN-Human model to Colab manually to add it to the GUI list, because it is not included in the original download script.

The DragGAN tutorial suggests using the PTI technique to edit your own custom images. However, there are no detailed instructions on how to combine the two techniques and pass the correct information between them. This notebook shows how it can be done: PTI produces a fine-tuned generator and a pivot latent code (w_pivot), and both are handed over to DragGAN. It runs in Google Colab on a T4 GPU.

Note that the base model used here is stylegan2_ada_ffhq, which was trained on Flickr-Faces-HQ (FFHQ). As a result, it only works on pictures of faces.

In [1]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

class Downloader(object):
    """Download files from Google Drive, either through PyDrive (requires signing in) or through gdown."""

    def __init__(self, use_pydrive):
        self.use_pydrive = use_pydrive

        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        # Sign in with the Colab user's Google account and build a PyDrive client
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_id, file_dst):
        if self.use_pydrive:
            downloaded = self.drive.CreateFile({'id':file_id})
            downloaded.FetchMetadata(fetch_all=True)
            downloaded.GetContentFile(file_dst)
        else:
            # Recent gdown versions take the file ID as a positional argument (the --id flag is deprecated)
            !gdown $file_id -O $file_dst

downloader = Downloader(True)

Step 1 - Install Packages required by PTI

In [ ]:
!pip install lpips wandb

# used for faster inference of StyleGAN by enabling C++ code compilation
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
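As far as I know, the StyleGAN2-ADA custom ops are compiled through PyTorch's C++ extension loader, which needs ninja on the PATH. A quick stdlib-only check that the install above worked; `ninja_version` is a hypothetical helper, not part of PTI or DragGAN:

```python
import shutil
import subprocess
from typing import Optional


def ninja_version(executable: str = "ninja") -> Optional[str]:
    """Return the version string printed by `executable --version`, or None if it is not on PATH."""
    path = shutil.which(executable)
    if path is None:
        return None  # not installed, or not on PATH
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=False)
    return result.stdout.strip() or None
```

In the notebook, `print(ninja_version())` should show the version installed above; `None` means ninja is not on the PATH.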

Step 2 - Download Pretrained models

In [ ]:
!git clone https://github.com/XingangPan/DragGAN.git
In [ ]:
!git clone https://github.com/danielroich/PTI.git
%cd /content/PTI/
!git checkout da94d59d15d94822e95840ab5a0aa9ba1a19c851
In [10]:
import os
image_dir_name = 'image'
os.makedirs(f'./{image_dir_name}_original', exist_ok=True)
os.makedirs(f'./{image_dir_name}_processed', exist_ok=True)
save_path = "pretrained_models"
os.makedirs(save_path, exist_ok=True)
In [11]:
downloader.download_file("125OG7SMkXI-Kf2aqiwLLHyCvSW-gZk3M", os.path.join(save_path, 'ffhq.pkl'))
In [12]:
downloader.download_file("1xPmn19T6Bdd-_RfCVlgNBbfYoh1muYxR", os.path.join(save_path, 'align.dat'))
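Before moving on, it can be worth checking that both downloads produced non-empty files, because a failed Google Drive download may only surface later, when PTI tries to load the file. A stdlib-only sketch; `missing_files` is a hypothetical helper:

```python
import os
from typing import Iterable, List


def missing_files(paths: Iterable[str]) -> List[str]:
    """Return the paths that do not exist or are empty files, in input order."""
    return [p for p in paths if not (os.path.isfile(p) and os.path.getsize(p) > 0)]
```

In the notebook this could be called as `missing_files([os.path.join(save_path, 'ffhq.pkl'), os.path.join(save_path, 'align.dat')])`; an empty list means both files are present.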

Step 3 - Configuration Setup

In [19]:
import sys
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display

from configs import paths_config, hyperparameters, global_config
from utils.align_data import pre_process_images
from scripts.run_pti import run_PTI

image_name = 'personal_image'
global_config.device = 'cuda'
paths_config.e4e = '/content/PTI/pretrained_models/e4e_ffhq_encode.pt'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = f'/content/PTI/{image_dir_name}_processed'
paths_config.stylegan2_ada_ffhq = '/content/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/content/PTI/'
paths_config.style_clip_pretrained_mappers = '/content/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False

Step 4 - Preprocess Data

TODO: upload a picture of a face to /content/PTI/image_original/personal_image.jpg. The file name (without extension) must match image_name from Step 3.

In [ ]:
original_image = Image.open(f'./{image_dir_name}_original/{image_name}.jpg')
In [ ]:
pre_process_images(f'/content/PTI/{image_dir_name}_original')
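The alignment step writes the cropped face into the `_processed` folder, and, as far as I can tell, PTI names its outputs after that file's name, which is why it has to match `image_name`. A small sketch to confirm the aligned image exists before starting the (slow) inversion; `find_processed_image` is a hypothetical helper, and it matches any extension because I am not certain which one the aligner uses:

```python
import glob
import os
from typing import Optional


def find_processed_image(processed_dir: str, image_name: str) -> Optional[str]:
    """Return the first file in processed_dir named image_name plus an extension, or None."""
    pattern = os.path.join(glob.escape(processed_dir), f"{glob.escape(image_name)}.*")
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None
```

In the notebook: `find_processed_image(paths_config.input_data_path, image_name)`. A `None` result most likely means alignment failed, for example because no face was detected.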

Step 5 - Invert images using PTI

To run PTI and load StyleGAN2-ADA, the current working directory must be the parent of 'torch_utils' and 'dnnlib' (here /content/PTI, set by the %cd in Step 2).
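As I understand it, unpickling a StyleGAN2-ADA checkpoint imports code from the `torch_utils` and `dnnlib` packages, so Python must be able to find them relative to the working directory. A stdlib-only sketch that fails early with a clear message instead of an import error deep inside `pickle.load`; `ensure_stylegan_cwd` is a hypothetical helper:

```python
import os


def ensure_stylegan_cwd(repo_dir: str) -> str:
    """chdir into repo_dir after checking it contains 'torch_utils' and 'dnnlib'; return the new cwd."""
    required = ("torch_utils", "dnnlib")
    missing = [name for name in required if not os.path.isdir(os.path.join(repo_dir, name))]
    if missing:
        raise FileNotFoundError(f"{repo_dir} is missing {missing}")
    os.chdir(repo_dir)
    return os.getcwd()
```

In the notebook this would be `ensure_stylegan_cwd('/content/PTI')` right before `run_PTI`.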

In [ ]:
model_id = run_PTI(use_wandb=False, use_multi_id_training=False)

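Visualize results

The plotting code in the last two cells of this section is commented out; uncomment it to compare the inversion before and after Pivotal Tuning.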

In [26]:
def load_generators(model_id, image_name):
  with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
    d = pickle.load(f)
    old_G = d['G_ema'].cuda()
    old_D = d['D'].cuda()

  with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
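    # The checkpoint is a full pickled module, not a state_dict, so recent PyTorch versions need weights_only=False
    new_G = torch.load(f_new, weights_only=False).cuda()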

  return old_G, old_D, new_G
In [27]:
old_G, old_D, new_G = load_generators(model_id, image_name)
In [28]:
# def plot_syn_images(syn_images):
#   for img in syn_images:
#       img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
#       plt.axis('off')
#       resized_image = Image.fromarray(img,mode='RGB').resize((256,256))
#       display(resized_image)
#       del img
#       del resized_image
#       torch.cuda.empty_cache()
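`plot_syn_images` converts generator output, which lies roughly in [-1, 1], to 8-bit pixels with `x * 127.5 + 128`, clamps the result to [0, 255], and casts to `uint8`, which truncates. A plain-Python sketch of the same per-value arithmetic (`to_uint8` is a hypothetical name used only for illustration):

```python
def to_uint8(x: float) -> int:
    """Map a generator output value in roughly [-1, 1] to a 0..255 pixel value.

    Mirrors (x * 127.5 + 128).clamp(0, 255).to(torch.uint8): scale and shift, clamp, truncate.
    """
    v = min(max(x * 127.5 + 128, 0.0), 255.0)
    return int(v)  # int() truncates toward zero; v is non-negative here, so this is a floor
```

With this mapping, -1, 0, and 1 become 0, 128, and 255; anything outside [-1, 1] is clamped.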
In [29]:
w_pivot_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}/{image_name}/0.pt'
# w_pivot = torch.load(w_pivot_path)

# old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)
# new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)

# print('Upper image is the inversion before Pivotal Tuning and the lower image is the product of pivotal tuning')
# plot_syn_images([old_image, new_image])
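The two PTI outputs that end up in DragGAN are the tuned generator (loaded by `load_generators`) and the pivot latent `w_pivot` (later loaded by DragGAN in place of `w_load`). A sketch that collects both path patterns used in the cells above in one place; `pti_output_paths` is a hypothetical helper, and the values passed to it in the notebook would come from `paths_config`:

```python
from typing import Dict


def pti_output_paths(checkpoints_dir: str, embedding_base_dir: str, input_data_id: str,
                     pti_results_keyword: str, model_id: str, image_name: str) -> Dict[str, str]:
    """Return the paths of the PTI outputs, using the same f-string patterns as the cells above."""
    return {
        # fine-tuned generator written by run_PTI
        "generator": f"{checkpoints_dir}/model_{model_id}_{image_name}.pt",
        # pivot latent found during inversion
        "w_pivot": f"{embedding_base_dir}/{input_data_id}/{pti_results_keyword}/{image_name}/0.pt",
    }
```

In the notebook: `pti_output_paths(paths_config.checkpoints_dir, paths_config.embedding_base_dir, paths_config.input_data_id, paths_config.pti_results_keyword, model_id, image_name)`.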

Export

In [31]:
def export_updated_pickle(old_G, old_D, new_G, output_path):
  tmp = {}
  tmp['G'] = old_G.eval().requires_grad_(False).cpu()
  tmp['G_ema'] = new_G.eval().requires_grad_(False).cpu()
  tmp['D'] = old_D.eval().requires_grad_(False).cpu()
  tmp['training_set_kwargs'] = None
  tmp['augment_pipe'] = None

  with open(output_path, 'wb') as f:
      pickle.dump(tmp, f)

output_path = f'{paths_config.checkpoints_dir}/stylegan2_{image_name}.pkl'
export_updated_pickle(old_G, old_D, new_G, output_path)
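As far as I can tell, DragGAN reads the pickle through StyleGAN2-ADA's network-pickle loader, which checks for the `G`, `D`, and `G_ema` entries and then uses `G_ema` as the generator; that is why the export keeps the original `G` and `D` next to the tuned generator. A sketch of a sanity check on the exported file, assuming the key layout written by `export_updated_pickle`; `check_exported_pickle` is a hypothetical helper, and in the notebook it has to run from /content/PTI so the network classes can be unpickled:

```python
import pickle

# Keys written by export_updated_pickle above.
REQUIRED_KEYS = ("G", "D", "G_ema", "training_set_kwargs", "augment_pipe")


def check_exported_pickle(path: str) -> list:
    """Return the expected top-level keys that are missing from the pickle at `path`."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    if not isinstance(data, dict):
        return list(REQUIRED_KEYS)
    return [k for k in REQUIRED_KEYS if k not in data]
```

An empty list from `check_exported_pickle(output_path)` means all expected entries are present.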
In [32]:
import locale
# Work around the "A UTF-8 locale is required" error that Colab shell (!) commands can raise
locale.getpreferredencoding = lambda: "UTF-8"

!mkdir -p /content/DragGAN/checkpoints
!cp $output_path /content/DragGAN/checkpoints
!cp $w_pivot_path /content/DragGAN/checkpoints
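As far as I can tell from `visualizer_drag_gradio.py`, DragGAN builds its model list from the `.pkl` files in its checkpoints directory, keyed by file name without extension. That is why the pickle is named `stylegan2_{image_name}.pkl` and why the default model is later patched to `stylegan2_personal_image`. A stdlib-only sketch of the copy cell plus that mapping (my reconstruction, not DragGAN's code; `copy_into` and `list_checkpoints` are hypothetical helpers):

```python
import os
import shutil
from typing import Dict, Iterable


def copy_into(dst_dir: str, files: Iterable[str]) -> None:
    """Copy each file into dst_dir, creating the directory first (like mkdir -p followed by cp)."""
    os.makedirs(dst_dir, exist_ok=True)
    for src in files:
        shutil.copy(src, dst_dir)


def list_checkpoints(checkpoint_dir: str) -> Dict[str, str]:
    """Map each .pkl file's name without extension to its full path."""
    return {
        os.path.splitext(name)[0]: os.path.join(checkpoint_dir, name)
        for name in sorted(os.listdir(checkpoint_dir))
        if name.endswith(".pkl")
    }
```

For example, `copy_into('/content/DragGAN/checkpoints', [output_path, w_pivot_path])` followed by `list_checkpoints('/content/DragGAN/checkpoints')` should include a `stylegan2_personal_image` entry; `0.pt` is copied too but is not listed, since it is not a `.pkl`.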

DragGAN

In [ ]:
%cd /content/DragGAN
!git checkout c5e88b3eaf64c33a9e82782d75b4329d16711c3a
In [ ]:
!pip install -r requirements.txt
In [35]:
# Not needed here: we use our own checkpoint instead of DragGAN's sample models.
# !python scripts/download_model.py

Patch the DragGAN Python scripts to:

  • use our custom w_pivot from PTI as the initial latent code
  • set the default model in the GUI to our own checkpoint
  • bypass the watermark, which fails because of a font issue
In [36]:
!sed -i 's#None.*w_load#torch.load("/content/DragGAN/checkpoints/0.pt"),#' /content/DragGAN/visualizer_drag_gradio.py
!sed -i 's/stylegan2_lions_512_pytorch/stylegan2_personal_image/' /content/DragGAN/visualizer_drag_gradio.py
!sed -i 's/d = ImageDraw/return input_image_array  # d = ImageDraw/' /content/DragGAN/viz/renderer.py
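The `sed` one-liners are easy to break if the upstream files change. A Python sketch of the same first-match-per-line substitutions, which also makes it easy to preview their effect on a single line before editing files. The patterns and replacements are copied from the cell above; `sed_like`, `patch_file`, and `PATCHES` are hypothetical names:

```python
import re
from typing import Dict, List, Tuple

Patch = Tuple[str, str]  # (regex pattern, replacement)

# The same substitutions as the sed commands above, grouped by the file they target.
PATCHES: Dict[str, List[Patch]] = {
    "/content/DragGAN/visualizer_drag_gradio.py": [
        (r"None.*w_load", 'torch.load("/content/DragGAN/checkpoints/0.pt"),'),
        (r"stylegan2_lions_512_pytorch", "stylegan2_personal_image"),
    ],
    "/content/DragGAN/viz/renderer.py": [
        (r"d = ImageDraw", "return input_image_array  # d = ImageDraw"),
    ],
}


def sed_like(text: str, patches: List[Patch]) -> str:
    """Replace the first match of each pattern on every line, like sed 's/a/b/' without the g flag."""
    out = []
    for line in text.splitlines(keepends=True):
        for pattern, replacement in patches:
            # '.' does not match '\n', so '.*' stays within the line, as in sed
            line = re.sub(pattern, replacement, line, count=1)
        out.append(line)
    return "".join(out)


def patch_file(path: str, patches: List[Patch]) -> None:
    """Rewrite `path` in place with sed_like (like sed -i)."""
    with open(path, encoding="utf-8") as f:
        text = f.read()
    with open(path, "w", encoding="utf-8") as f:
        f.write(sed_like(text, patches))
```

To apply it: `for path, patches in PATCHES.items(): patch_file(path, patches)`. Use it instead of the `sed` cell, not in addition to it: the watermark patch is not idempotent and would prepend a second `return` comment. A made-up sample line such as `"    d = ImageDraw.Draw(img)\n"` becomes `"    return input_image_array  # d = ImageDraw.Draw(img)\n"`.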
In [ ]:
!python /content/DragGAN/visualizer_drag_gradio.py