# NOTE(review): the three lines below appear to be VCS paste residue
# (author name / commit message / commit hash). They are bare identifiers,
# so left active they raise NameError at import time; commented out.
# linhaotong
# update
# b9f87ab
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model loading and state dict conversion utilities.
"""
from typing import Dict, Tuple
import torch
from depth_anything_3.utils.logger import logger
def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert general model state dict to match current model architecture.

    Keys are rewritten through a fixed, ordered sequence of substring
    substitutions (module-prefix normalization, camera-token renaming,
    head renaming, output renaming); tensor values pass through untouched.

    Args:
        state_dict: Original state dictionary
    Returns:
        Converted state dictionary
    """

    def _substitute(sd, old, new):
        # One substring rename applied to every key; values are untouched.
        return {key.replace(old, new): value for key, value in sd.items()}

    # Normalize the module prefix and backbone naming first: the
    # camera-token removal below matches against the normalized key.
    state_dict = _substitute(state_dict, "module.", "model.")
    state_dict = _substitute(state_dict, ".net.", ".backbone.")

    # Drop the legacy camera token if present (pop is a no-op otherwise;
    # _substitute already returned a fresh dict, so this never mutates
    # the caller's input).
    state_dict.pop("model.backbone.pretrained.camera_token", None)

    # Remaining renames, applied in the same order as the original
    # conversion sequence (camera tokens, head naming, output naming,
    # GS-DPT head naming).
    renames = (
        (".camera_token_extra", ".camera_token"),
        ("model.all_heads.camera_cond_head", "model.cam_enc"),
        ("model.all_heads.camera_head", "model.cam_dec"),
        (".more_mlps.", ".backbone."),
        (".fc_rot.", ".fc_qvec."),
        ("model.all_heads.head", "model.head"),
        ("output_conv2_additional.sky_mask", "sky_output_conv2"),
        ("_ray.", "_aux."),
        ("gaussian_param_head.", "gs_head."),
    )
    for old, new in renames:
        state_dict = _substitute(state_dict, old, new)
    return state_dict
def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert metric model state dict to match current model architecture.

    Prepends the "module." prefix to every key so the dict matches the
    layout expected by :func:`convert_general_state_dict`, then delegates
    the actual renaming to it.

    Args:
        state_dict: Original metric state dictionary
    Returns:
        Converted state dictionary
    """
    prefixed = {f"module.{key}": value for key, value in state_dict.items()}
    return convert_general_state_dict(prefixed)
def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]:
    """
    Load pretrained weights for a single model.

    Args:
        model: Model instance to load weights into
        model_path: Path to the pretrained weights
        is_metric: Whether this is a metric model (keys get the "module."
            prefix added before conversion)
    Returns:
        Tuple of (missed_keys, unexpected_keys) as reported by
        ``model.load_state_dict(..., strict=False)``
    """
    # NOTE(review): torch.load on an untrusted checkpoint can execute
    # arbitrary code via pickle; pass weights_only=True if these
    # checkpoints are plain tensor dicts — confirm with checkpoint owners.
    state_dict = torch.load(model_path, map_location="cpu")
    if is_metric:
        state_dict = convert_metric_state_dict(state_dict)
    else:
        state_dict = convert_general_state_dict(state_dict)
    # strict=False: report key mismatches instead of raising.
    missed, unexpected = model.load_state_dict(state_dict, strict=False)
    # Bug fix: logger.info("Missed keys:", missed) passed the value as a
    # positional logging argument with no placeholder in the message —
    # stdlib logging raises a formatting error in the handler (message
    # lost) and loguru silently drops the argument. Interpolate instead.
    logger.info(f"Missed keys: {missed}")
    logger.info(f"Unexpected keys: {unexpected}")
    return missed, unexpected
def load_pretrained_nested_weights(
    model, main_model_path: str, metric_model_path: str
) -> Tuple[list, list]:
    """
    Load pretrained weights for a nested model with both main and metric branches.

    The main checkpoint is remapped under the ``model.da3.`` prefix and the
    metric checkpoint under ``model.da3_metric.``; the two dicts are then
    merged (metric entries win on key collision) and loaded non-strictly.

    Args:
        model: Nested model instance
        main_model_path: Path to main model weights
        metric_model_path: Path to metric model weights
    Returns:
        Tuple of (missed_keys, unexpected_keys) as reported by
        ``model.load_state_dict(..., strict=False)``
    """
    # Load and remap main-branch weights.
    # NOTE(review): torch.load on an untrusted checkpoint can execute
    # arbitrary code via pickle; consider weights_only=True if these
    # checkpoints are plain tensor dicts.
    state_dict0 = torch.load(main_model_path, map_location="cpu")
    state_dict0 = convert_general_state_dict(state_dict0)
    state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()}
    # Load and remap metric-branch weights.
    state_dict1 = torch.load(metric_model_path, map_location="cpu")
    state_dict1 = convert_metric_state_dict(state_dict1)
    state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()}
    # Combine state dictionaries; the distinct prefixes keep the branches
    # disjoint, so update() should not overwrite anything.
    combined_state_dict = state_dict0.copy()
    combined_state_dict.update(state_dict1)
    missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)
    # Consistency fix: use the module logger (as load_pretrained_weights
    # does) instead of bare print, so key reports go through the project's
    # logging configuration.
    logger.info(f"Missed keys: {missed}")
    logger.info(f"Unexpected keys: {unexpected}")
    return missed, unexpected