Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """This file contains useful layout utilities for images. They are: | |
| - add_border: Add a border to an image. | |
| - cat/hcat/vcat: Join images by arranging them in a line. If the images have different | |
| sizes, they are aligned as specified (start, end, center). Allows you to specify a gap | |
| between images. | |
| Images are assumed to be float32 tensors with shape (channel, height, width). | |
| """ | |
| from typing import Any, Generator, Iterable, Literal, Union | |
| import torch | |
| from torch import Tensor | |
# Where an item sits along one axis of its container.
Alignment = Literal["start", "center", "end"]
# Direction along which images are laid out.
Axis = Literal["horizontal", "vertical"]
# A color given as a scalar, an iterable of per-channel values, or a tensor.
Color = Union[
    int,
    float,
    Iterable[int],
    Iterable[float],
    Tensor,
]
| def _sanitize_color(color: Color) -> Tensor: # "#channel" | |
| # Convert tensor to list (or individual item). | |
| if isinstance(color, torch.Tensor): | |
| color = color.tolist() | |
| # Turn iterators and individual items into lists. | |
| if isinstance(color, Iterable): | |
| color = list(color) | |
| else: | |
| color = [color] | |
| return torch.tensor(color, dtype=torch.float32) | |
| def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: | |
| it = iter(iterable) | |
| yield next(it) | |
| for item in it: | |
| yield delimiter | |
| yield item | |
| def _get_main_dim(main_axis: Axis) -> int: | |
| return { | |
| "horizontal": 2, | |
| "vertical": 1, | |
| }[main_axis] | |
| def _get_cross_dim(main_axis: Axis) -> int: | |
| return { | |
| "horizontal": 1, | |
| "vertical": 2, | |
| }[main_axis] | |
| def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: | |
| assert base >= overlay | |
| offset = { | |
| "start": 0, | |
| "center": (base - overlay) // 2, | |
| "end": base - overlay, | |
| }[align] | |
| return slice(offset, offset + overlay) | |
def overlay(
    base: Tensor,  # "channel base_height base_width"
    overlay: Tensor,  # "channel overlay_height overlay_width"
    main_axis: Axis,
    main_axis_alignment: Alignment,
    cross_axis_alignment: Alignment,
) -> Tensor:  # "channel base_height base_width"
    """Paste ``overlay`` onto a copy of ``base`` at the requested alignment.

    Args:
        base: The background image. It is cloned, not modified.
        overlay: The image to paste; it may not exceed ``base`` in height or
            width.
        main_axis: Which spatial direction counts as the main axis.
        main_axis_alignment: Placement of the overlay along the main axis.
        cross_axis_alignment: Placement of the overlay along the cross axis.

    Returns:
        A new tensor shaped like ``base`` with the overlay written into the
        aligned region.
    """
    # The overlay must fit inside the base (equal sizes are allowed).
    _, base_height, base_width = base.shape
    _, overlay_height, overlay_width = overlay.shape
    assert base_height >= overlay_height and base_width >= overlay_width

    # Compute the occupied span on the main dimension.
    main_dim = _get_main_dim(main_axis)
    main_slice = _compute_offset(
        base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
    )

    # Compute the occupied span on the cross dimension.
    cross_dim = _get_cross_dim(main_axis)
    cross_slice = _compute_offset(
        base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
    )

    # Build the index in a list because the two spatial positions are filled
    # in by dimension number; position 0 stays as the Ellipsis.
    selector = [..., None, None]
    selector[main_dim] = main_slice
    selector[cross_dim] = cross_slice

    result = base.clone()
    # Index with a tuple: a list is only interpreted as a multidimensional
    # index through legacy NumPy-style sequence handling.
    result[tuple(selector)] = overlay
    return result
def cat(
    main_axis: Axis,
    *images: Tensor,  # "channel _ _"
    align: Alignment = "center",
    gap: int = 8,
    gap_color: Color = 1,
) -> Tensor:  # "channel height width"
    """Arrange images in a line. The interface resembles a CSS div with flexbox.

    Args:
        main_axis: Direction along which the images are concatenated.
        *images: The images to join. The first image determines the device
            and the channel count used for the separators.
        align: Placement of each image along the cross axis when it is
            shorter than the longest image on that axis.
        gap: Thickness of the separator inserted between neighboring images;
            no separator is inserted when this is not positive.
        gap_color: Fill color for both the separators and the padding.

    Returns:
        A float32 tensor holding the padded images (and separators)
        concatenated along the main axis.

    Raises:
        IndexError: If no images are given.
    """
    device = images[0].device
    gap_color = _sanitize_color(gap_color).to(device)
    # One fill value per channel, broadcastable over height and width.
    fill = gap_color[:, None, None]

    main_dim = _get_main_dim(main_axis)
    cross_dim = _get_cross_dim(main_axis)

    # Find the maximum image side length in the cross axis dimension.
    cross_axis_length = max(image.shape[cross_dim] for image in images)

    # Pad every image to that length by pasting it onto a filled canvas.
    padded_images = []
    for image in images:
        padded_shape = list(image.shape)
        padded_shape[cross_dim] = cross_axis_length
        base = torch.ones(padded_shape, dtype=torch.float32, device=device) * fill
        padded_images.append(overlay(base, image, main_axis, "start", align))

    # Intersperse separators if necessary.
    if gap > 0:
        c, _, _ = images[0].shape
        # Sizes are (height, width), i.e. tensor dims 1 and 2, hence the -1.
        separator_size = [gap, gap]
        separator_size[cross_dim - 1] = cross_axis_length
        separator = (
            torch.ones((c, *separator_size), dtype=torch.float32, device=device) * fill
        )
        padded_images = list(_intersperse(padded_images, separator))

    return torch.cat(padded_images, dim=main_dim)
def hcat(
    *images: Tensor,  # "channel _ _"
    align: Literal["start", "center", "end", "top", "bottom"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Tensor:  # "channel height width"
    """Shorthand for a horizontal linear concatenation.

    Args:
        *images: The images to place side by side.
        align: Vertical placement of images shorter than the tallest one.
            "top" and "bottom" are aliases for "start" and "end".
        gap: Width of the separator between neighboring images.
        gap_color: Fill color for separators and padding.

    Returns:
        The result of ``cat`` along the horizontal axis.
    """
    # Translate the direction-specific aliases into generic alignments.
    cross_alignment = {
        "start": "start",
        "center": "center",
        "end": "end",
        "top": "start",
        "bottom": "end",
    }[align]
    return cat(
        "horizontal",
        *images,
        align=cross_alignment,
        gap=gap,
        gap_color=gap_color,
    )
def vcat(
    *images: Tensor,  # "channel _ _"
    align: Literal["start", "center", "end", "left", "right"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Tensor:  # "channel height width"
    """Shorthand for a vertical linear concatenation.

    Args:
        *images: The images to stack from top to bottom.
        align: Horizontal placement of images narrower than the widest one.
            "left" and "right" are aliases for "start" and "end".
        gap: Height of the separator between neighboring images.
        gap_color: Fill color for separators and padding.

    Returns:
        The result of ``cat`` along the vertical axis.
    """
    # Translate the direction-specific aliases into generic alignments.
    cross_alignment = {
        "start": "start",
        "center": "center",
        "end": "end",
        "left": "start",
        "right": "end",
    }[align]
    return cat(
        "vertical",
        *images,
        align=cross_alignment,
        gap=gap,
        gap_color=gap_color,
    )
def add_border(
    image: Tensor,  # "channel height width"
    border: int = 8,
    color: Color = 1,
) -> Tensor:  # "channel new_height new_width"
    """Surround an image with a solid border.

    Args:
        image: The image to frame.
        border: Border thickness added on each of the four sides.
        color: Border color: a scalar or one value per channel.

    Returns:
        A new float32 tensor on the image's device whose height and width
        are each larger by ``2 * border``, with the image in the middle.
    """
    # Move the color to the image's device only. Casting it to the image's
    # dtype as well would truncate fractional colors for integer images even
    # though the canvas below is always float32.
    color = _sanitize_color(color).to(image.device)
    c, h, w = image.shape
    result = torch.empty(
        (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
    )
    # Fill the whole canvas with the border color, then paste the image.
    result[:] = color[:, None, None]
    result[:, border : h + border, border : w + border] = image
    return result