import os
import cv2
import numpy as np
import torch
from PIL import Image
import gdown
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
from sketchkit.animation.RIFE.model.RIFE_HDv3 import Model
from sketchkit.utils.file import download_with_wget, CISLAB_CDN, extract_files
class RIFE:
    """
    A class for sketch animation using the method of paper "Real-Time Intermediate Flow Estimation for Video Frame Interpolation"
    in ECCV 2022.

    Attributes:
        method (str): Method name; also the sub-directory the weights are cached in.
        device (str): The device that is used.
        model_dir (str): Directory containing the pretrained weights (``train_log``).
        model_path (str): Path to the ``flownet.pkl`` weight file.
        model (Model): The loaded ``RIFE_HDv3`` model used for interpolation.
    """

    def __init__(self, device: str = "cuda", cislab_source: bool = False):
        """
        Initializes the RIFE animator.

        Downloads the pretrained weights on first use, then loads the model.

        Args:
            device (str, optional): The device that is used.
            cislab_source (bool, optional): If the weights are not cached yet,
                download them from the CISLAB CDN instead of Google Drive.
        """
        self.method = "RIFE"
        self.device = device
        # NOTE(review): this turns autograd off process-wide, not only for this
        # model. Kept for backward compatibility; ``run`` additionally wraps
        # inference in ``torch.no_grad()`` so it does not depend on it.
        torch.set_grad_enabled(False)
        if torch.cuda.is_available():
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
        self._download_models(cislab_source=cislab_source)
        self.model = Model(self.model_dir, device)

    def _download_models(self, cislab_source=False):
        """
        Download and extract the pretrained weights if they are not cached yet.

        Sets ``self.model_dir`` and ``self.model_path``. The cache root is
        ``<home>/.sketchkit/weights/animation/<method>``, where ``<home>`` is
        ``$HOME`` or, if that is unset/empty, the user's home directory.

        Args:
            cislab_source (bool, optional): Fetch the archive from the CISLAB CDN
                instead of Google Drive.

        Raises:
            RuntimeError: If the download finished without producing the archive.
        """
        home = os.environ.get("HOME") or os.path.expanduser("~")
        root = os.path.join(home, ".sketchkit/weights/animation/", self.method)
        os.makedirs(root, exist_ok=True)
        self.model_dir = os.path.join(root, "train_log")
        self.model_path = os.path.join(self.model_dir, "flownet.pkl")
        if os.path.exists(self.model_path):
            return

        filename = "RIFEv4.25_0919.zip"
        archive_path = os.path.join(root, filename)
        if cislab_source:
            cislab_url = f"{CISLAB_CDN}/datasets/Sketchy/{filename}"
            print(f"Attempting download from CISLAB: {cislab_url}")
            download_with_wget(
                cislab_url, archive_path, desc="Downloading from CISLAB CDN"
            )
        else:
            gdrive_id = "1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg"
            print("Using gdown for Google Drive download...")
            gdown.download(id=gdrive_id, output=archive_path, quiet=False)
        # Some downloaders report failure without raising; do not claim success
        # (or try to extract) when no archive was written.
        if not os.path.isfile(archive_path):
            raise RuntimeError(
                f"Failed to download the RIFE weights archive to {archive_path}"
            )
        if cislab_source:
            print("CISLAB download completed successfully!")
        else:
            print("Google Drive download completed successfully!")
        extract_files(
            file_path=archive_path, output_dir=root, remove_sourcefile=True
        )
        # The archive also ships Python model sources; this package imports its
        # own RIFE_HDv3 instead, so delete them. Tolerate them being absent so a
        # different archive layout does not fail after the weights are in place.
        for source_file in ("IFNet_HDv3.py", "refine.py", "RIFE_HDv3.py"):
            try:
                os.remove(os.path.join(self.model_dir, source_file))
            except FileNotFoundError:
                pass

    def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) image in [0, 255] to a (1, C, H, W) float32 tensor in [0, 1] on ``self.device``."""
        # ascontiguousarray also copes with negative-stride views (e.g. img[..., ::-1]).
        chw = np.ascontiguousarray(image.transpose(2, 0, 1))
        return (torch.tensor(chw, dtype=torch.float32, device=self.device) / 255.0).unsqueeze(0)

    @staticmethod
    def _to_image(frame: torch.Tensor, h: int, w: int) -> np.ndarray:
        """Convert a (1, C, PH, PW) tensor in [0, 1] to an (h, w, C) uint8 image cropped to h x w."""
        # Round (instead of truncating) so keyframes survive the /255 -> *255 round
        # trip exactly, and clamp so out-of-range outputs cannot wrap in the uint8 cast.
        pixels = (frame[0, :, :h, :w] * 255.0).round().clamp(0, 255).byte()
        return np.ascontiguousarray(pixels.cpu().numpy().transpose(1, 2, 0))

    def run(
        self,
        input_image_list: list[np.ndarray],
        inner_frames: int
    ) -> list[np.ndarray]:
        """
        Generate sketch animation or inbetweening for a list of keyframes.

        For each pair of consecutive keyframes, ``inner_frames`` frames are
        interpolated at the evenly spaced timesteps ``k / (inner_frames + 1)``
        for ``k = 1 .. inner_frames``.

        Args:
            input_image_list (list[numpy.ndarray]): a list of keyframes in shape (H, W, 3), with values in [0, 255].
                Both keyframes of a pair must have the same (H, W).
            inner_frames (int): number of intermediate frames.

        Returns:
            frame_list (list[numpy.ndarray]): a list of uint8 image frames in size H*W*3 in [0, 255]:
                the first keyframe, then for each consecutive pair the interpolated
                frames followed by the next keyframe. Empty when fewer than two
                keyframes are given.
        """
        n_steps = inner_frames + 1
        frame_list = []
        with torch.no_grad():
            for img_i in range(len(input_image_list) - 1):
                img0 = self._to_tensor(input_image_list[img_i])
                img1 = self._to_tensor(input_image_list[img_i + 1])
                h, w = img0.shape[-2:]
                # Pad bottom/right so H and W are multiples of 64 before calling the
                # model (presumably required by the network's downsampling — TODO confirm).
                ph = ((h - 1) // 64 + 1) * 64
                pw = ((w - 1) // 64 + 1) * 64
                padding = (0, pw - w, 0, ph - h)
                img0 = F.pad(img0, padding)
                img1 = F.pad(img1, padding)
                # Convert each frame immediately, cropped with this pair's own size,
                # rather than holding every padded tensor on the device until the end.
                if img_i == 0:
                    frame_list.append(self._to_image(img0, h, w))
                for step in range(1, n_steps):
                    middle = self.model.inference(img0, img1, step / n_steps)
                    frame_list.append(self._to_image(middle, h, w))
                frame_list.append(self._to_image(img1, h, w))
        return frame_list