Source code for sketchkit.renderer3d.renderer3d

from PIL import Image

import numpy as np
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from copy import deepcopy


from sketchkit.core.sketch3d import Sketch3D
from sketchkit.core.camera import Camera
from sketchkit.core.sketch import Curve, Path, Point, Sketch, Vertex


[docs] @dataclass class Render3DOptions: canvas_size: tuple[int, int] = (512, 512) background_color: tuple[int, int, int] = (255, 255, 255) stroke_color: tuple[int, int, int] | None = None stroke_width: float | None = None
# back_end: str = "cairo" # "cairo" or "pydiffvg" # device: str = "cuda" # "cpu" or "cuda"
class Renderer3D(ABC):
    """Abstract base for renderers that draw a ``Sketch3D`` from one or more cameras.

    Subclasses implement :meth:`_render_batch`. This base class provides the
    shared helpers: projecting Bézier control points through cameras and
    resolving per-curve stroke colors and widths.
    """

    # Public class-level default. It is shared by every instance that does not
    # assign its own ``render_options``. ``render()`` deep-copies it before
    # applying dict overrides; assign a new object rather than mutating it.
    render_options = Render3DOptions()

    def project_curves_ndc(
        self,
        curves: torch.Tensor,
        cameras: list[Camera] | Camera | None = None,
    ) -> torch.Tensor:
        """Project Bézier control points through each camera's matrices.

        Args:
            curves: Tensor with ``N * P * 3`` elements, indexed as ``[N, P, 3]``
                (``N`` curves, ``P`` control points each). It is detached and
                copied to CPU for the projection.
            cameras: A single ``Camera`` or a list of cameras. Required despite
                the ``None`` default.

        Returns:
            float32 tensor of shape ``[V, N, P, 2]`` (one slice per camera) on
            the same device as ``curves``. The last axis holds
            ``(K[0,0] * x/z + K[0,2], K[1,1] * y/z + K[1,2])``, computed from
            camera-space coordinates ``extrinsics @ [x, y, z, 1]``.

        Raises:
            ValueError: if ``cameras`` is ``None``.
        """
        if cameras is None:
            raise ValueError("Must provide 'cameras' for projection.")
        if isinstance(cameras, Camera):
            cameras = [cameras]

        device = curves.device
        num_curves, num_points = curves.shape[0], curves.shape[1]
        verts = curves.detach().cpu().numpy().reshape(-1, 3)
        # Homogeneous world coordinates [x, y, z, 1].
        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)

        out = np.zeros((len(cameras), num_curves, num_points, 2), dtype=np.float32)
        for vi, camera in enumerate(cameras):
            # asarray lets list / CPU-tensor matrices work; no-op for ndarrays.
            ext = np.asarray(camera.extrinsics)
            K = np.asarray(camera.intrinsics)
            # Camera coordinates: Xc = E @ Xw
            Xc_homo = (ext @ verts_homo.T).T
            # Clamp depth so points at or behind the camera don't divide by zero.
            z = np.clip(Xc_homo[:, 2:3], 1e-6, None)
            x_norm = Xc_homo[:, 0:1] / z
            y_norm = Xc_homo[:, 1:2] / z
            u = K[0, 0] * x_norm + K[0, 2]
            v = K[1, 1] * y_norm + K[1, 2]
            out[vi] = np.concatenate([u, v], axis=1).reshape(num_curves, num_points, 2)
        return torch.from_numpy(out).to(device)

    @staticmethod
    def _normalize_colors(
        colors: np.ndarray | torch.Tensor | list | None,
    ) -> np.ndarray | None:
        """Convert per-curve colors to a float32 array of shape ``[N, 3]`` or ``[N, 4]``.

        Returns ``None`` when ``colors`` is ``None``. Torch tensors are detached
        and moved to CPU first, so GPU tensors and tensors requiring grad are
        accepted.

        Raises:
            ValueError: if the result is not 2-D with 3 or 4 channels.
        """
        if colors is None:
            return None
        if isinstance(colors, torch.Tensor):
            colors = colors.detach().cpu().numpy()
        arr = np.asarray(colors, dtype=np.float32)
        if arr.ndim != 2 or arr.shape[1] not in (3, 4):
            raise ValueError("Colors must be shaped [N,3] or [N,4].")
        return arr

    def _get_sketch_colors(self, sketch: Sketch3D) -> np.ndarray | None:
        """Collect one color per curve, taken from each curve's start vertex.

        Returns ``None`` if no curve has a color. Otherwise returns a float32
        array with one row per curve. Values are passed through as stored
        (no rescaling). Uncolored curves are filled with black, sized to match
        the first colored curve's channel count. For a 4th channel (presumably
        alpha) the fill value is 1.0.
        """
        colors = [curve.p_start.color for path in sketch.paths for curve in path.curves]
        first = next((c for c in colors if c is not None), None)
        if first is None:
            return None

        # Match the fill to the channel count of the provided colors, so that
        # RGBA colors mixed with uncolored curves still stack into one array.
        n_channels = int(np.size(first))
        fill = [0.0] * n_channels
        if n_channels == 4:
            fill[3] = 1.0
        return np.array([fill if c is None else c for c in colors], dtype=np.float32)

    @staticmethod
    def _get_option(options: Render3DOptions | dict | None, name: str):
        """Read ``name`` from an options object or dict; ``None`` when absent."""
        if options is None:
            return None
        if isinstance(options, dict):
            return options.get(name)
        return getattr(options, name, None)

    def _resolve_stroke_colors(
        self, sketch: Sketch3D, options: Render3DOptions, num_curves: int
    ) -> np.ndarray:
        """Resolve stroke colors, preferring the global option over sketch colors.

        - If ``options.stroke_color`` is set (0-255 scale), it is divided by 255
          and tiled to ``[num_curves, 3]``.
        - Otherwise the sketch's per-curve colors are used as stored.
        - If the sketch has no colors, all curves are black.

        ``options`` may also be a dict or ``None``.
        """
        stroke_color = self._get_option(options, "stroke_color")
        if stroke_color is not None:
            rgb = np.array([c / 255.0 for c in stroke_color], dtype=np.float32)
            return np.tile(rgb, (num_curves, 1))

        colors_arr = self._get_sketch_colors(sketch)
        if colors_arr is None:
            colors_arr = np.zeros((num_curves, 3), dtype=np.float32)
        return colors_arr

    def _get_sketch_widths(self, sketch: Sketch3D) -> np.ndarray | None:
        """Collect per-control-point widths, one row per curve.

        - Each row is ``curve.widths`` when set; that array is assumed to hold
          4 entries.
        - Otherwise the row is the start vertex's ``thickness`` repeated 4 times.
        - Returns ``None`` when no curve provides either.
        - Otherwise, curves providing neither get a default width of 2.0.
        """
        rows = []
        for path in sketch.paths:
            for curve in path.curves:
                if curve.widths is not None:
                    rows.append(curve.widths)
                elif curve.p_start.thickness is not None:
                    rows.append(np.full(4, curve.p_start.thickness, dtype=np.float32))
                else:
                    rows.append(None)

        if all(row is None for row in rows):
            return None
        default = np.full(4, 2.0, dtype=np.float32)
        return np.array([default if row is None else row for row in rows], dtype=np.float32)

    def _resolve_stroke_widths(
        self, sketch: Sketch3D, options: Render3DOptions, num_curves: int
    ) -> np.ndarray | float:
        """Resolve stroke widths, preferring the global option over sketch widths.

        Returns:
            - ``float(options.stroke_width)`` when that option is set;
            - otherwise the ``[N, 4]`` array from the sketch;
            - otherwise 2.0.

        ``options`` may be an options object, a dict, or ``None``.
        ``num_curves`` is accepted for signature parity with
        ``_resolve_stroke_colors`` but is not used.
        """
        sw = self._get_option(options, "stroke_width")
        if sw is not None:
            return float(sw)
        widths_arr = self._get_sketch_widths(sketch)
        return 2.0 if widths_arr is None else widths_arr

    @staticmethod
    def _curve_stroke_width(stroke_width, curve_idx: int) -> float:
        """Pick the single stroke width for curve ``curve_idx``; 1.0 if unspecified."""
        if stroke_width is None:
            return 1.0
        # Include numpy scalars: np.float32 is not a subclass of Python float.
        if isinstance(stroke_width, (int, float, np.integer, np.floating)):
            return float(stroke_width)
        if isinstance(stroke_width, np.ndarray):
            if stroke_width.ndim == 0:
                return float(stroke_width)
            if stroke_width.ndim == 1:
                # Per-curve scalar: shape (N,)
                return float(stroke_width[curve_idx])
            if stroke_width.ndim == 2:
                # Per-control-point widths, shape (N, 4). The 2D curve carries a
                # single thickness, so use the starting width.
                return float(stroke_width[curve_idx, 0])
        return 1.0

    def _sketch_from_projected(
        self,
        projected: np.ndarray,
        colors: np.ndarray | None,
        stroke_width: float | np.ndarray,
        canvas_size: tuple[int, int] | None = None,
    ) -> Sketch:
        """Build a 2D ``Sketch`` (one single-curve path per entry) from projected points.

        Args:
            projected: Per-curve control points. ``ctrl[:, 0]`` / ``ctrl[:, 1]``
                are x / y. Indices 0 and 3 are the endpoints; 1 and 2 are the
                inner control points.
            colors: Optional per-curve colors; only the first 3 channels are used.
            stroke_width: Scalar, per-curve ``(N,)`` array, or ``(N, 4)`` array.
            canvas_size: Scale applied to x and y. Defaults to
                ``self.render_options.canvas_size``; pass the per-call size to
                honor ``render_options`` overrides.
        """
        size = canvas_size if canvas_size is not None else self.render_options.canvas_size
        scale_x, scale_y = float(size[0]), float(size[1])

        paths: list[Path] = []
        for curve_idx, ctrl in enumerate(projected):
            px = ctrl[:, 0] * scale_x
            py = ctrl[:, 1] * scale_y  # no vertical flip is applied
            rgb = None
            if colors is not None:
                rgb = tuple(float(c) for c in colors[curve_idx][:3])
            sw = self._curve_stroke_width(stroke_width, curve_idx)

            v_start = Vertex(px[0], py[0], thickness=sw, color=rgb, opacity=1.0)
            v_end = Vertex(px[3], py[3], thickness=sw, color=rgb, opacity=1.0)
            p_ctrl1 = Point(px[1], py[1])
            p_ctrl2 = Point(px[2], py[2])
            paths.append(Path([Curve(v_start, v_end, p_ctrl1, p_ctrl2)]))
        return Sketch(paths)

    def render(
        self,
        sketch3d: Sketch3D,
        cameras: list[Camera],
        render_options: Render3DOptions | dict | None = None,
    ) -> list[Image.Image]:
        """Resolve render options, then delegate to :meth:`_render_batch`.

        Args:
            sketch3d: The 3D sketch to render.
            cameras: Cameras forwarded unchanged to ``_render_batch``.
            render_options:
                - ``None``: use ``self.render_options``.
                - dict: a deep copy of ``self.render_options`` with matching
                  public attributes overridden. Unknown, private and dunder
                  keys are ignored.
                - any other value: used as-is.

        Returns:
            Whatever ``_render_batch`` returns (annotated as a list of images).
        """
        if render_options is None:
            opts = self.render_options
        elif isinstance(render_options, dict):
            opts = deepcopy(self.render_options)
            for key, value in render_options.items():
                # Skip private/dunder names so keys like "__class__" cannot
                # clobber object internals.
                if not key.startswith("_") and hasattr(opts, key):
                    setattr(opts, key, value)
        else:
            opts = render_options
        return self._render_batch(sketch3d, cameras, opts)

    @abstractmethod
    def _render_batch(self, *args, **kwargs) -> list[Image.Image]:
        """Render the sketch for every camera; implemented by subclasses.

        ``render`` calls this as ``_render_batch(sketch3d, cameras, opts)``,
        where ``opts`` is the resolved options object.
        """