sketchkit.utils package
Submodules
sketchkit.utils.dataset module
- class sketchkit.utils.dataset.SketchDataset(root: str | Path | None = None, load_all: bool = False, cislab_source: bool = False, skip_integrity_check: bool = False)
Bases: ABC, Sequence
Abstract base class for sketch datasets.
This class provides a common interface for all sketch datasets in SketchKit. It handles dataset initialization, caching, and downloading, and provides standard methods for accessing sketch data.
- root
Root directory where the dataset is stored.
- Type:
str
- items_metadata
DataFrame containing metadata for all dataset items.
- Type:
pandas.DataFrame
- load_all
Whether to load all data into memory at initialization.
- Type:
bool
- cislab_source
Whether to use CISLAB as the data source.
- Type:
bool
- raw_data
List holding the data loaded into memory, with one entry per item. When load_all is True, all items are loaded into it at initialization; an item that has not been loaded is None (raw_data[idx] is None). __getitem__ returns items from this list.
- Type:
list
Note
Subclasses must implement all abstract methods to provide dataset-specific functionality for integrity checking, downloading, metadata loading, and data access.
- abstractmethod _check_integrity() → bool
Check if the dataset files are complete and valid.
This method should verify that all necessary dataset files exist and are not corrupted. It’s called during initialization to determine if the dataset needs to be downloaded.
- Returns:
True if the dataset is complete and valid, False otherwise.
- Return type:
bool
- abstractmethod _download()
Download the dataset files to the root directory.
This method should implement the logic to download all necessary dataset files from their source location to the local root directory. It’s called when the integrity check fails during initialization.
- Raises:
RuntimeError – If download fails or encounters an error.
- abstractmethod _get_single(idx: int) → Sketch
Retrieve a sketch item by index.
- Parameters:
idx (int) – Index of the item to retrieve. Must be in range [0, len(dataset)).
- Returns:
The sketch object at the specified index.
- Return type:
Sketch
- Raises:
IndexError – If idx is out of bounds.
- abstractmethod _load_all()
Load all dataset items into memory.
This method should load all dataset items into memory for faster access. It’s called when load_all is True during initialization or when explicitly requested.
Note
The loaded data should be stored in self.raw_data for later access by __getitem__.
- abstractmethod _load_items_metadata()
Load metadata for all items in the dataset into a pandas DataFrame.
This method should populate the items_metadata DataFrame with information about all available items in the dataset. The metadata is used to provide quick access to item information without loading the actual data.
Note
This method is called after successful integrity check and should populate self.items_metadata as a pandas DataFrame with appropriate columns for dataset-specific metadata.
- extra_repr() → str
Return extra information for the string representation.
This method can be overridden by subclasses to provide additional dataset-specific information in the string representation.
- Returns:
Additional information to include in __repr__. Empty by default.
- Return type:
str
- unload()
Unload the raw data from memory by deleting the reference to raw_data and setting it to None. This frees memory: after the call, the previously loaded data is no longer accessible through raw_data, and the memory it occupied can be garbage collected once no other references to it remain.
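The initialization flow and the abstract-method contract described above can be illustrated with a self-contained sketch. ToySketchDatasetBase and InMemoryToyDataset are hypothetical names, and the base class only imitates the documented behaviour (integrity check, download on failure, metadata loading, optional load_all, raw_data lookup, unload); it is not SketchKit's source. As further assumptions, a plain list of dicts stands in for the items_metadata DataFrame, a list of (x, y) tuples stands in for a Sketch, and __getitem__ falls back to _get_single once raw_data has been unloaded.

```python
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, List, Optional


class ToySketchDatasetBase(ABC, Sequence):
    """Hypothetical stand-in that imitates the documented SketchDataset flow."""

    def __init__(self, load_all: bool = False, skip_integrity_check: bool = False):
        self.load_all = load_all
        # Assumed order: verify, download on failure, then read metadata.
        if not skip_integrity_check and not self._check_integrity():
            self._download()
        self._load_items_metadata()
        # One slot per item; None marks an item that is not in memory.
        self.raw_data: Optional[List[Any]] = [None] * len(self.items_metadata)
        if load_all:
            self._load_all()

    @abstractmethod
    def _check_integrity(self) -> bool: ...

    @abstractmethod
    def _download(self) -> None: ...

    @abstractmethod
    def _load_items_metadata(self) -> None: ...

    @abstractmethod
    def _load_all(self) -> None: ...

    @abstractmethod
    def _get_single(self, idx: int) -> Any: ...

    def __len__(self) -> int:
        return len(self.items_metadata)

    def __getitem__(self, idx: int) -> Any:
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Serve from memory when loaded; otherwise read the single item.
        if self.raw_data is not None and self.raw_data[idx] is not None:
            return self.raw_data[idx]
        return self._get_single(idx)

    def extra_repr(self) -> str:
        return ""

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.extra_repr()})"

    def unload(self) -> None:
        # Dropping the reference lets the garbage collector reclaim the data.
        self.raw_data = None


class InMemoryToyDataset(ToySketchDatasetBase):
    """Hypothetical subclass: each 'sketch' is just a list of (x, y) points."""

    _SOURCE = [[(0, 0), (1, 1)], [(0, 0), (2, 0)], [(1, 2), (3, 4)]]

    def _check_integrity(self) -> bool:
        return True  # nothing on disk to verify in this toy example

    def _download(self) -> None:
        raise RuntimeError("nothing to download")

    def _load_items_metadata(self) -> None:
        # Stand-in for a pandas DataFrame: one record per item.
        self.items_metadata = [{"id": i, "n_points": len(s)} for i, s in enumerate(self._SOURCE)]

    def _load_all(self) -> None:
        self.raw_data = [list(s) for s in self._SOURCE]

    def _get_single(self, idx: int) -> Any:
        return list(self._SOURCE[idx])

    def extra_repr(self) -> str:
        return f"n_items={len(self)}"


ds = InMemoryToyDataset(load_all=True)
print(ds, len(ds), ds[1])
ds.unload()
print(ds[1])  # served by _get_single after unload
```

Keeping the load/lookup logic in the base class means a concrete dataset only has to answer the five abstract questions (is it complete, how to fetch it, what items exist, how to load everything, how to load one item).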
sketchkit.utils.file module
- sketchkit.utils.file.dir_md5(path, ignore_hidden=True)
Calculates the MD5 hash of all files in a directory.
- Parameters:
path – Path to the directory.
ignore_hidden – Whether to ignore hidden files (starting with ‘.’). Defaults to True.
- Returns:
The MD5 hash as a hexadecimal string.
- Return type:
str
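A possible shape for dir_md5 is sketched below. dir_md5_sketch is a hypothetical name, and visiting files in sorted order while also feeding each file's relative path into the hash are assumptions made so that the digest is deterministic and sensitive to renames; SketchKit's helper may combine files differently, so the digests are not guaranteed to match.

```python
import hashlib
import os


def dir_md5_sketch(path, ignore_hidden=True, chunk_size=1 << 20):
    """Hypothetical re-implementation: one MD5 digest over every file under `path`."""
    md5 = hashlib.md5()
    for root, dirs, files in os.walk(path):
        if ignore_hidden:
            # Prune hidden directories in place so os.walk does not descend into them.
            dirs[:] = [d for d in dirs if not d.startswith(".")]
            files = [f for f in files if not f.startswith(".")]
        dirs.sort()  # deterministic traversal order
        for name in sorted(files):
            full = os.path.join(root, name)
            # Assumption: include the relative path so renamed files change the digest.
            rel = os.path.relpath(full, path).replace(os.sep, "/")
            md5.update(rel.encode("utf-8"))
            with open(full, "rb") as fh:
                # Stream in chunks so large files are not read into memory at once.
                for chunk in iter(lambda: fh.read(chunk_size), b""):
                    md5.update(chunk)
    return md5.hexdigest()
```

Comparing such a digest against a stored value is one way a subclass's _check_integrity could decide whether a download is needed.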
- sketchkit.utils.file.download_with_gdown(output_folder, gdrive_id: str, filename: str)
Download a dataset from Google Drive.
- Parameters:
output_folder – The folder to download the dataset into.
gdrive_id – The Google Drive ID of the dataset.
filename – The name of the downloaded file.
- Returns:
None
- sketchkit.utils.file.download_with_wget(url: str, file_path: str, overwrite=True, pg_bar=True, desc=None)
Downloads a file from a URL with a progress bar using tqdm.
- Parameters:
url – The URL to download the file from.
file_path – The local path where the file will be saved.
overwrite – Whether to overwrite the file if it already exists. Defaults to True.
pg_bar – Whether to show a progress bar. Defaults to True.
desc – Description text to show in the progress bar. Defaults to None.
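The download-with-progress pattern can be sketched with the standard library alone. download_sketch is a hypothetical name: despite its name, SketchKit's download_with_wget may rely on a different HTTP client, and writing to a temporary .part file before renaming it is an added assumption that keeps a half-finished download from looking complete. tqdm is used only if it is installed.

```python
import os
import urllib.request

try:
    from tqdm import tqdm  # optional progress bar, as the docstring describes
except ImportError:
    tqdm = None


def download_sketch(url, file_path, overwrite=True, pg_bar=True, desc=None, chunk_size=1 << 16):
    """Hypothetical stand-in for download_with_wget using urllib (+ optional tqdm)."""
    if os.path.exists(file_path) and not overwrite:
        return file_path  # keep the existing file untouched
    tmp_path = file_path + ".part"
    with urllib.request.urlopen(url) as resp, open(tmp_path, "wb") as out:
        # Content-Length may be missing; tqdm then shows a counter without a total.
        total = int(resp.headers.get("Content-Length") or 0) or None
        bar = tqdm(total=total, unit="B", unit_scale=True, desc=desc) if (pg_bar and tqdm) else None
        try:
            while True:
                chunk = resp.read(chunk_size)
                if not chunk:
                    break
                out.write(chunk)
                if bar is not None:
                    bar.update(len(chunk))
        finally:
            if bar is not None:
                bar.close()
    os.replace(tmp_path, file_path)  # expose the file only once it is complete
    return file_path
```

Streaming in fixed-size chunks keeps memory use flat regardless of file size, and each chunk doubles as a progress-bar update.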
- sketchkit.utils.file.extract_files(file_path, output_dir, remove_sourcefile=True)
Extracts files from an archive to a specified folder.
- Parameters:
file_path – Path to the archive file to extract.
output_dir – Path to the folder where files will be extracted.
remove_sourcefile – Whether to remove the source archive file after extraction. Defaults to True.
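A minimal sketch of extract_files, assuming the archive is in a format that Python's shutil.unpack_archive recognises (zip, tar, gztar, bztar, xztar). extract_files_sketch is a hypothetical name, and the real helper may support other formats or extract differently.

```python
import os
import shutil


def extract_files_sketch(file_path, output_dir, remove_sourcefile=True):
    """Hypothetical stand-in: unpack an archive, then optionally delete it."""
    os.makedirs(output_dir, exist_ok=True)
    # The archive format is inferred from the file extension.
    shutil.unpack_archive(file_path, output_dir)
    if remove_sourcefile:
        os.remove(file_path)  # only reached if extraction did not raise
```

Deleting the archive after a successful unpack means a failed extraction leaves the source file in place for a retry.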
sketchkit.utils.geometry module
- sketchkit.utils.geometry.bezier_lengths(x: ndarray, steps: int = 64) → ndarray
Calculate the lengths of cubic Bézier curves.
- Parameters:
x – [N, 4, 2] Control points (P0, P1, P2, P3) as numpy array
steps – Number of sampling steps along the curve
- Returns:
[N, 1] Lengths of the curves as numpy array
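One way to compute these lengths, consistent with the steps parameter, is to sample each cubic curve at steps + 1 evenly spaced parameter values and sum the chord lengths of the resulting polyline. bezier_lengths_sketch is a hypothetical name, and the chord-sum method is an assumption about what sampling means here; it slightly underestimates the true length, with the error shrinking as steps grows.

```python
import numpy as np


def bezier_lengths_sketch(x: np.ndarray, steps: int = 64) -> np.ndarray:
    """Approximate cubic Bézier lengths. x: [N, 4, 2] control points -> [N, 1]."""
    t = np.linspace(0.0, 1.0, steps + 1)[None, :, None]    # [1, S, 1]
    p0, p1, p2, p3 = (x[:, i, None, :] for i in range(4))  # each [N, 1, 2]
    mt = 1.0 - t
    # Bernstein form of the cubic Bézier curve, evaluated at every sample t.
    pts = mt**3 * p0 + 3 * mt**2 * t * p1 + 3 * mt * t**2 * p2 + t**3 * p3  # [N, S, 2]
    seg = np.linalg.norm(np.diff(pts, axis=1), axis=-1)     # chord lengths [N, S-1]
    return seg.sum(axis=1, keepdims=True)


curves = np.array([[[0, 0], [1, 4 / 3], [2, 8 / 3], [3, 4]]], dtype=float)
print(bezier_lengths_sketch(curves))  # about 5 for this straight segment
```

All curves are evaluated in one broadcasted expression, so the cost is a single [N, steps + 1, 2] array rather than a Python loop over curves.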
- sketchkit.utils.geometry.gauss_legendre_nodes_weights(m: int, device: device, dtype: dtype = torch.float32)
Return the nodes and weights of the m-point Gauss-Legendre quadrature rule on the interval [0, 1].
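The [0, 1] rule follows from NumPy's numpy.polynomial.legendre.leggauss, which returns the m-point rule on [-1, 1], via the affine map u = (x + 1) / 2, which also halves the weights. The sketch below returns NumPy arrays instead of tensors; since the documented function takes a device and dtype, torch.as_tensor(nodes, device=device, dtype=dtype) would be a natural final step. gauss_legendre_01 and cubic_bezier_length_gl are hypothetical names, and using the rule to integrate the Bézier speed |B'(t)| is an illustrative assumption that this page does not state.

```python
import numpy as np


def gauss_legendre_01(m: int):
    """m-point Gauss-Legendre nodes and weights mapped from [-1, 1] to [0, 1]."""
    nodes, weights = np.polynomial.legendre.leggauss(m)  # rule on [-1, 1]
    # u = (x + 1) / 2 has Jacobian 1/2, so the weights are halved.
    return (nodes + 1.0) / 2.0, weights / 2.0


def cubic_bezier_length_gl(p: np.ndarray, m: int = 16) -> float:
    """Hypothetical use: length of one cubic Bézier (p: [4, 2]) as the integral of |B'(t)|."""
    u, w = gauss_legendre_01(m)
    t = u[:, None]  # [m, 1]
    # B'(t) = 3[(1-t)^2 (P1-P0) + 2(1-t)t (P2-P1) + t^2 (P3-P2)]
    d = 3 * ((1 - t) ** 2 * (p[1] - p[0]) + 2 * (1 - t) * t * (p[2] - p[1]) + t**2 * (p[3] - p[2]))
    return float(np.sum(w * np.linalg.norm(d, axis=1)))
```

Because an m-point rule integrates polynomials of degree up to 2m − 1 exactly, gauss_legendre_01(3) already reproduces the integral of t⁵ over [0, 1], namely 1/6, up to rounding.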
- sketchkit.utils.geometry.is_line(x: ndarray, epsilon: float = 1e-06)
Check whether each curve is a straight line.
- Parameters:
x – [N, 4, 2] Control points (P0, P1, P2, P3) as numpy array
epsilon (float, optional) – Threshold for floating point comparison. Defaults to 1e-6.
- Returns:
For each curve, True if it is a line, False otherwise.
- Return type:
[N] bool
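A cubic curve whose four control points are collinear is a straight segment, so is_line can be sketched as a collinearity test on the control points. is_line_sketch is a hypothetical name; checking all three pairwise cross products of P1 − P0, P2 − P0 and P3 − P0 is an assumed approach that also handles the degenerate case P0 = P3.

```python
import numpy as np


def _cross2(u, v):
    """z-component of the 2D cross product, vectorised over leading axes."""
    return u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]


def is_line_sketch(x: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
    """[N, 4, 2] control points -> [N] bool, True where all four points are collinear."""
    d = x[:, 1:, :] - x[:, :1, :]  # offsets P1-P0, P2-P0, P3-P0: [N, 3, 2]
    c = np.stack([_cross2(d[:, 0], d[:, 1]),
                  _cross2(d[:, 0], d[:, 2]),
                  _cross2(d[:, 1], d[:, 2])], axis=1)  # [N, 3]
    # In 2D the offsets lie on one line iff every pairwise cross product vanishes.
    return np.all(np.abs(c) <= epsilon, axis=1)
```

Because epsilon bounds an absolute cross product, its effective tolerance scales with the square of the coordinates; normalising the control points first keeps the threshold meaningful.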
- sketchkit.utils.geometry.points_collinear(a, b, c, epsilon=1e-06)
Check if three points are collinear.
- Parameters:
a (np.ndarray) – First point (2,).
b (np.ndarray) – Second point (2,).
c (np.ndarray) – Third point (2,).
epsilon (float, optional) – Threshold for floating point comparison. Defaults to 1e-6.
- Returns:
True if points are collinear, False otherwise.
- Return type:
bool
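Three points are collinear exactly when the triangle they form has zero area, which the 2D cross product (b − a) × (c − a) measures (it equals twice the signed area). points_collinear_sketch is a hypothetical name, and comparing the absolute cross product against epsilon is an assumed reading of the documented threshold.

```python
import numpy as np


def points_collinear_sketch(a, b, c, epsilon=1e-6):
    """True if a, b, c (each shape (2,)) lie on one line, via the 2D cross product."""
    a, b, c = (np.asarray(p, dtype=float) for p in (a, b, c))
    ab, ac = b - a, c - a
    # Twice the signed area of triangle abc; (near) zero means collinear.
    cross = ab[0] * ac[1] - ab[1] * ac[0]
    return bool(abs(cross) <= epsilon)
```

The cross product avoids divisions, so unlike a slope comparison it needs no special case for vertical lines.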