GanAttack/GanInverter/editing/ganspace.py

27 lines
1.2 KiB
Python

from .base_editing import BaseEditing
import torch
class GANSpace(BaseEditing):
def __init__(self, opts):
super(GANSpace, self).__init__()
ganspace_pca = torch.load(opts.edit_path, map_location='cpu')
self.pca_idx, self.start, self.end, self.strength = opts.ganspace_directions
self.code_mean = ganspace_pca['mean'].cuda()
self.code_comp = ganspace_pca['comp'].cuda()[self.pca_idx]
self.code_std = ganspace_pca['std'].cuda()[self.pca_idx]
if opts.edit_save_path == '':
self.save_folder = f'ganspace_{self.pca_idx}_{self.start}_{self.end}_{self.strength}'
else:
self.save_folder = opts.edit_save_path
def edit_code(self, code):
edit_codes = []
for c in code:
w_centered = c - self.code_mean
w_coord = torch.sum(w_centered[0].reshape(-1) * self.code_comp.reshape(-1)) / self.code_std
delta = (self.strength - w_coord) * self.code_comp * self.code_std
delta_padded = torch.zeros(c.shape).to('cuda')
delta_padded[self.start:self.end] += delta.repeat(self.end - self.start, 1)
edit_codes.append(c + delta_padded)
return torch.stack(edit_codes)