# -*- coding: utf-8 -*-
##------------------------------------------
## Motion Supervised co-part Segmentation
## Demo Ver 0.01
##
## 2024.07.13 Masahiro Izutsu
##------------------------------------------
## motion_seg.py
# Usage examples (10-segment / 5-segment / supervised models)
# 10seg: source_image = './sample/image/16.png'
# target_video = './sample/videos/04.mp4'
# index = [2] (lips)
#
# 10seg: source_image = './sample/image/25.png'
# target_video = './sample/videos/11.mp4'
# index = [7,9] (eyes)
#
# 5seg: source_image = './sample/image/27.png'
# target_video = './sample/videos/02.mp4'
# index = [3,4,5] (hair)
#
# 5seg: source_image = './sample/image/27.png'
# target_video = './sample/videos/04.mp4'
# use_source_segmentation = True
# index = [3,4,5] (face)
#
# 5seg: source_image = './sample/image/23.png'
# target_video = './sample/videos/07.mp4'
# index = [1] (beard)
#
# super: source_image = './sample/image/16.png'
# target_video = './sample/videos/04.mp4'
# index = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] (everything except the hair)
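#
# A hypothetical command-line run for the first example above (flags as
# defined in parse_args() below; adjust the paths to your environment):
#   python motion_seg.py -c 0 --source_image ./sample/image/16.png \
#       --target_video ./sample/videos/04.mp4 --swap_index 2
# Leaving --swap_index at its default only visualizes the segmentation.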
# ANSI color escape codes
GREEN = '\033[1;32m'
RED = '\033[1;31m'
NOCOLOR = '\033[0m'
YELLOW = '\033[1;33m'
# Constant definitions
DEF_CONFIG_10SEG = './config/vox-256-sem-10segments.yaml'
DEF_CHECKPOINT_10SEG = './sample/vox-10segments.pth.tar'
DEF_CONFIG_5SEG = './config/vox-256-sem-5segments.yaml'
DEF_CHECKPOINT_5SEG = './sample/vox-5segments.pth.tar'
DEF_CONFIG_SUPER = './config/vox-256-sem-10segments.yaml'
DEF_CHECKPOINT_SUPER = './sample/vox-first-order.pth.tar'
DEF_RESULT_10SEG_VIDEO = './results/result_10seg.mp4'
DEF_RESULT_5SEG_VIDEO = './results/result_5seg.mp4'
DEF_RESULT_SUPER_VIDEO = './results/result_super.mp4'
DEF_RESULT_10SEG_IMAGE = './results/result_10seg.png'
DEF_RESULT_5SEG_IMAGE = './results/result_5seg.png'
DEF_RESULT_SUPER_IMAGE = './results/result_super.png'
DEF_RESULT_10SEG_INDEX_LIP = [2]
# import
import os
import argparse
import imageio.v2 as imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as mpatches
import torch
import torch.nn.functional as F
from skimage.transform import resize
from skimage import img_as_ubyte
from part_swap2 import load_checkpoints
from part_swap2 import load_face_parser
from part_swap2 import make_video
import my_dialog
import my_imagetool
import my_videotool
import my_movieplay
import warnings
warnings.simplefilter('ignore', UserWarning)    # suppress UserWarning messages
# Title
title = 'Motion Supervised co-part Segmentation Ver. 0.01'
sub_title = ('10-segments model', '5-segments model', 'supervised part-swaps')
# Parses arguments for the application
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--category', default = '0', type = str, help = 'Category (0:10-segment, 1:5-segment, 2:supervised) Default value is 0')
parser.add_argument("--config", default='', help="path to config")
parser.add_argument("--checkpoint", default='', help="path to checkpoint to restore")
parser.add_argument("--source_image", default='', help="path to source image")
parser.add_argument("--target_video", default='', help="path to target video")
parser.add_argument("--result_video", default='', help="path to output")
parser.add_argument("--swap_index", default="-1", type=lambda x: list(map(int, x.split(','))), help='index of swaped parts')
parser.add_argument("--hard", action="store_true", help="use hard segmentation labels for blending")
parser.add_argument("--use_source_segmentation", action="store_true", help="use source segmentation for swaping")
parser.add_argument("--first_order_motion_model", action="store_true", help="use first order model for alignment")
parser.add_argument("--supervised", action="store_true", help="use supervised segmentation labels for blending. Only for faces.")
parser.add_argument("--cpu", action="store_true", help="cpu mode")
return parser
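# Note: --swap_index is parsed by the lambda above, e.g. "7,9" -> [7, 9];
# the default "-1" becomes [-1], which the main block below treats as a
# request to visualize the segmentation instead of swapping parts.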
# Display basic information
def display_info(opt, title):
if opt.category[0] == '0':
cat = f'{opt.category}: ** {sub_title[0]} **'
elif opt.category[0] == '1':
cat = f'{opt.category}: ** {sub_title[1]} **'
elif opt.category[0] == '2':
cat = f'{opt.category}: ** {sub_title[2]} **'
else:
cat = f'{opt.category}: ** setup **'
print('\n' + GREEN + title + ': Starting application...' + NOCOLOR)
print('\n - ' + YELLOW + 'category : ' + NOCOLOR, cat)
print(' - ' + YELLOW + 'config : ' + NOCOLOR, opt.config)
print(' - ' + YELLOW + 'checkpoint : ' + NOCOLOR, opt.checkpoint)
print(' - ' + YELLOW + 'source_image : ' + NOCOLOR, opt.source_image)
print(' - ' + YELLOW + 'target_video : ' + NOCOLOR, opt.target_video)
print(' - ' + YELLOW + 'result_video : ' + NOCOLOR, opt.result_video)
print(' - ' + YELLOW + 'swap_index : ' + NOCOLOR, opt.swap_index)
print(' - ' + YELLOW + 'hard : ' + NOCOLOR, opt.hard)
print(' - ' + YELLOW + 'use_source_segmentation : ' + NOCOLOR, opt.use_source_segmentation)
print(' - ' + YELLOW + 'first_order_motion_model: ' + NOCOLOR, opt.first_order_motion_model)
print(' - ' + YELLOW + 'supervised : ' + NOCOLOR, opt.supervised)
print(' - ' + YELLOW + 'cpu : ' + NOCOLOR, opt.cpu)
print(' ')
# Visualize the segmentation
def visualize_segmentation(image, network, supervised=False, hard=True, colormap='gist_rainbow', cpu=False):
    with torch.no_grad():
        inp = torch.tensor(image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            inp = inp.cuda()
if supervised:
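            # Supervised path: the face parser runs at 512x512 and expects
            # inputs normalized with its stored mean/std (see part_swap2)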
inp = F.interpolate(inp, size=(512, 512))
inp = (inp - network.mean) / network.std
mask = torch.softmax(network(inp)[0], dim=1)
mask = F.interpolate(mask, size=image.shape[:2])
else:
mask = network(inp)['segmentation']
mask = F.interpolate(mask, size=image.shape[:2], mode='bilinear')
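        # Hard labels: keep, per pixel, only the strongest segment (one-hot mask)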
if hard:
mask = (torch.max(mask, dim=1, keepdim=True)[0] == mask).float()
colormap = plt.get_cmap(colormap)
num_segments = mask.shape[1]
mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy()
color_mask = 0
patches = []
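    # One legend patch and one RGB color layer per segment; segment 0 is drawn black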
for i in range(num_segments):
if i != 0:
color = np.array(colormap((i - 1) / (num_segments - 1)))[:3]
else:
color = np.array((0, 0, 0))
patches.append(mpatches.Patch(color=color, label=str(i)))
color_mask += mask[..., i:(i+1)] * color.reshape(1, 1, 3)
fig, ax = plt.subplots(1, 2, figsize=(12,6), dpi=64)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1) # 2024.07.13
ax[0].imshow(color_mask)
ax[1].imshow(0.3 * image + 0.7 * color_mask)
ax[1].legend(handles=patches)
ax[0].axis('off')
ax[1].axis('off')
# End processing (display and save the result)
def end_process(save_path, dispf = False):
if len(save_path) > 0:
plt.savefig(save_path)
if dispf:
plt.show()
plt.close()
return
# Display the segmentation
def segment_disp(opt, msg = 'Visualize segmentation', maxsize = my_imagetool.WINDOW_WIDTH, loop_f = True):
source_image = imageio.imread(opt.source_image)
source_image = resize(source_image, (256, 256))[..., :3]
if opt.supervised:
face_parser = load_face_parser(opt.cpu)
        visualize_segmentation(source_image, face_parser, supervised = True, hard = opt.hard, colormap = 'tab20', cpu = opt.cpu)
    else:
        reconstruction_module, segmentation_module = load_checkpoints(opt.config, checkpoint = opt.checkpoint, blend_scale = 1, cpu = opt.cpu)
        visualize_segmentation(source_image, segmentation_module, hard = opt.hard, cpu = opt.cpu)
    end_process(opt.result_video)
    my_imagetool.image2disp(opt.result_video, winname = msg, maxsize = maxsize, loop_f = loop_f)
# Generate the video
def part_swap(opt):
    # Check that the input files exist
if not os.path.isfile(opt.source_image):
print(RED + f"File not found !! '{opt.source_image}' " + NOCOLOR)
return
if not os.path.isfile(opt.target_video):
print(RED + f"File not found !! '{opt.target_video}' " + NOCOLOR)
return
    # Read the still image and the target video
source_image = imageio.imread(opt.source_image)
target_video, fps = my_videotool.read_video(opt.target_video)
    # Resize to 256x256
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]
    # Run the part swap on the still image / video
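    # blend_scale: (256 / 4) / 512 = 0.125 when the supervised 512x512 face parser is used, 1 otherwise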
blend_scale = (256 / 4) / 512 if opt.supervised else 1
reconstruction_module, segmentation_module = load_checkpoints(opt.config, opt.checkpoint, blend_scale=blend_scale,
first_order_motion_model=opt.first_order_motion_model, cpu=opt.cpu)
if opt.supervised:
face_parser = load_face_parser(opt.cpu)
else:
face_parser = None
predictions = make_video(opt.swap_index, source_image, target_video, reconstruction_module, segmentation_module,
face_parser, hard=opt.hard, use_source_segmentation=opt.use_source_segmentation, cpu=opt.cpu)
    # Output file names
    out_path1 = ''      # processed result
    out_path2 = ''      # still image / source video / processed result, side by side
ext = ''
if len(opt.result_video) > 0:
base_dir_pair = os.path.split(opt.source_image)
s_name, ext = os.path.splitext(base_dir_pair[1])
base_dir_pair = os.path.split(opt.target_video)
d_name, ext = os.path.splitext(base_dir_pair[1])
name, ext = os.path.splitext(opt.result_video)
out_path1 = name + '_' + s_name + '_' + d_name + ext
out_path2 = name + '_' + s_name + '_' + d_name + '_a' + ext
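        # e.g. source '16.png', target '04.mp4' and result './results/result_10seg.mp4'
        # give out_path1 = './results/result_10seg_16_04.mp4' and
        # out_path2 = './results/result_10seg_16_04_a.mp4'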
    # Save result 1
if out_path1 != '':
if ext == '.gif':
imageio.mimsave(out_path1, [img_as_ubyte(frame) for frame in predictions], fps = fps, loop = 0)
else:
imageio.mimsave(out_path1, [img_as_ubyte(frame) for frame in predictions], fps = fps)
print(f" Saving... → '{out_path1}'")
        # Add the audio track
my_videotool.add_audio(opt.target_video, out_path1, log_f = False)
        # Display the generated video 1
my_movieplay.movie_play(out_path1, title = 'Processed result image 1')
    # Compose still image / source video / result into a single clip
ani = my_videotool.img_movie3x1(source_image, target_video, predictions, interval = fps)
    # Save result 2
if out_path2 != '':
my_videotool.save_video(ani, out_path2)
print(f" Saving... → '{out_path2}'")
        # Add the audio track
my_videotool.add_audio(opt.target_video, out_path2, log_f = False)
        # Display input image / source video / generated video 2
my_movieplay.movie_play(out_path2, title = 'Processed result image 2')
print('\n Finished.')
return out_path1, out_path2
# Main entry point
if __name__ == "__main__":
parser = parse_args()
opt = parser.parse_args()
segment_f = opt.swap_index == [-1] # visualizing the segmentation flag
if len(opt.source_image) == 0:
        opt.source_image = my_dialog.select_image_file('Source image ', './sample/images')
if len(opt.source_image) == 0:
exit(0)
if len(opt.target_video) == 0 and not segment_f:
        opt.target_video = my_dialog.select_movie_file('Target video ', './sample/videos')
if len(opt.target_video) == 0:
exit(0)
    # Category-specific preprocessing
if opt.category[0] == '0': # 0: 10-segment
opt.supervised = False
opt.config = DEF_CONFIG_10SEG if len(opt.config) == 0 else opt.config
opt.checkpoint = DEF_CHECKPOINT_10SEG if len(opt.checkpoint) == 0 else opt.checkpoint
if segment_f:
opt.result_video = DEF_RESULT_10SEG_IMAGE
opt.hard = True
else:
opt.result_video = DEF_RESULT_10SEG_VIDEO if len(opt.result_video) == 0 else opt.result_video
    elif opt.category[0] == '1':    # 1: 5-segment
opt.supervised = False
opt.config = DEF_CONFIG_5SEG if len(opt.config) == 0 else opt.config
opt.checkpoint = DEF_CHECKPOINT_5SEG if len(opt.checkpoint) == 0 else opt.checkpoint
if segment_f:
opt.result_video = DEF_RESULT_5SEG_IMAGE
opt.hard = True
else:
opt.result_video = DEF_RESULT_5SEG_VIDEO if len(opt.result_video) == 0 else opt.result_video
elif opt.category[0] == '2': # 2: supervised
opt.supervised = True
opt.config = DEF_CONFIG_SUPER if len(opt.config) == 0 else opt.config
opt.checkpoint = DEF_CHECKPOINT_SUPER if len(opt.checkpoint) == 0 else opt.checkpoint
opt.first_order_motion_model = True
if segment_f:
opt.result_video = DEF_RESULT_SUPER_IMAGE
opt.hard = True
else:
opt.result_video = DEF_RESULT_SUPER_VIDEO if len(opt.result_video) == 0 else opt.result_video
display_info(opt, title)
if segment_f:
segment_disp(opt, msg = sub_title[int(opt.category[0])], maxsize = my_imagetool.WINDOW_WIDTH, loop_f = False)
else:
        part_swap(opt)