# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict as TyDict
from typing import List, Sequence

import torch
import torch.nn as nn

from depth_anything_3.model.dpt import DPT
from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate


class GSDPT(DPT):
    """DPT head for raw Gaussian-splatting parameters.

    Extends the base DPT head with a shallow convolutional ``images_merger``
    that injects RGB image information into the fused features before the
    final prediction convolution.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "linear",
        conf_activation: str = "sigmoid",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
        conf_dim: int = 1,
        norm_type: str = "idt",  # "idt" or "layer"; used to match the legacy GS-DPT head
        fusion_block_inplace: bool = False,
    ) -> None:
        super().__init__(
            dim_in=dim_in,
            patch_size=patch_size,
            output_dim=output_dim,
            activation=activation,
            conf_activation=conf_activation,
            features=features,
            out_channels=out_channels,
            pos_embed=pos_embed,
            down_ratio=down_ratio,
            head_name="raw_gs",
            use_sky_head=False,
            norm_type=norm_type,
            fusion_block_inplace=fusion_block_inplace,
        )
        self.conf_dim = conf_dim
        if conf_dim and conf_dim > 1:
            assert (
                conf_activation == "linear"
            ), "use linear prediction when using view-dependent opacity"
        merger_out_dim = features if feature_only else features // 2
        self.images_merger = nn.Sequential(
            nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1),  # fewer channels first
            nn.GELU(),
            nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1),
            nn.GELU(),
        )

    # -------------------------------------------------------------------------
    # Internal forward (single chunk)
    # -------------------------------------------------------------------------
    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        images: torch.Tensor,
    ) -> TyDict[str, torch.Tensor]:
        B, _, C = feats[0].shape
        ph, pw = H // self.patch_size, W // self.patch_size

        # 1) Per stage: drop non-patch tokens, reshape to 2D maps, project, resize
        resized_feats = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats[take_idx][:, patch_start_idx:]  # [B*S, N_patch, C]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)  # align scale
            resized_feats.append(x)

        # 2) Fusion pyramid (main branch only)
        fused = self._fuse(resized_feats)
        fused = self.scratch.output_conv1(fused)

        # 3) Upsample to target resolution, optionally add position encoding again
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)
        fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
        # Inject the RGB image information here
        fused = fused + self.images_merger(images)
        if self.pos_embed:
            fused = self._add_pos_embed(fused, W, H)

        # 4) Shared neck (output_conv1 is already applied after fusion above)
        # feat = self.scratch.output_conv1(fused)
        feat = fused

        # 5) Main head: logits -> activate_head_gs or single-channel activation
        main_logits = self.scratch.output_conv2(feat)
        outs: TyDict[str, torch.Tensor] = {}
        if self.has_conf:
            pred, conf = activate_head_gs(
                main_logits,
                activation=self.activation,
                conf_activation=self.conf_activation,
                conf_dim=self.conf_dim,
            )
            outs[self.head_main] = pred.squeeze(1)
            outs[f"{self.head_main}_conf"] = conf.squeeze(1)
        else:
            outs[self.head_main] = self._apply_activation_single(main_logits).squeeze(1)
        return outs
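

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the model): a minimal, self-contained demo
# of the two shape-level building blocks used in `_forward_impl`, runnable
# with torch alone. All concrete numbers below (patch_size=14, H=W=224, C=32,
# features=256) are assumed example values, not values mandated by the model.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, W, C, patch_size = 2, 224, 224, 32, 14
    ph, pw = H // patch_size, W // patch_size  # 16 x 16 patch grid

    # Token -> feature-map reshape, as done per stage in `_forward_impl`:
    # [B, N_patch, C] -> [B, C, N_patch] -> [B, C, ph, pw].
    tokens = torch.randn(B, ph * pw, C)  # patch tokens, prefix tokens already dropped
    fmap = tokens.permute(0, 2, 1).reshape(B, C, ph, pw)
    assert fmap.shape == (B, C, ph, pw)

    # The `images_merger` path: a three-conv stack lifting raw RGB to the
    # fused feature width (features // 2 when feature_only=False) so it can
    # be added to the upsampled DPT features at full resolution.
    merger_out_dim = 256 // 2
    merger = nn.Sequential(
        nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1),
        nn.GELU(),
        nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1),
        nn.GELU(),
        nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1),
        nn.GELU(),
    )
    rgb = torch.randn(B, 3, H, W)
    print(merger(rgb).shape)  # torch.Size([2, 128, 224, 224])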