import argparse
from pprint import pprint

import torch

from zoedepth.utils.easydict import EasyDict as edict
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn as nn
from zoedepth.utils.misc import count_parameters, parallelize
from zoedepth.trainers.builder import get_trainer
import torch.multiprocessing as mp
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.data.data_mono import MixedNYUKITTI
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.trainers.loss import GradL1Loss, SILogLoss, FocalLossV1
import zoedepth.utils.logging as logging

class DepthRouter(nn.Module):
    """Fuse several candidate depth maps into one via learned per-pixel weights.

    A small conv net predicts one logit per candidate map at every pixel; a
    softmax over the channel dimension turns those logits into convex
    weights, and the output is the weighted sum of the candidates.

    Note: ``conv2`` produces ``num_maps`` channels (one weight per candidate).
    A single-channel softmax is identically 1, which would make the output
    the plain sum of the candidates and give the router zero gradient.

    Args:
        num_maps: Number of candidate depth maps stacked along dim 1 of the
            input (4 in this script: near / middle / wide / ultra).
        hidden_channels: Width of the intermediate conv layer.
    """

    def __init__(self, num_maps=4, hidden_channels=16):
        super().__init__()
        self.conv1 = nn.Conv2d(num_maps, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, num_maps, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()  # kept for attribute compatibility; not used in forward
        self.softmax = nn.Softmax(dim=1)

    def forward(self, depth_maps):
        """Return the routed depth map.

        Args:
            depth_maps: Tensor whose dim 1 holds ``num_maps`` candidate depth
                maps, e.g. shape (B, num_maps, H, W).

        Returns:
            Tensor of shape (B, 1, H, W): per-pixel convex combination of the
            candidates.
        """
        hidden = self.relu(self.conv1(depth_maps))
        # Per-pixel weights over the candidates; they sum to 1 along dim 1.
        weights = self.softmax(self.conv2(hidden))
        return torch.sum(weights * depth_maps, dim=1, keepdim=True)

def main(config):
    """Train a DepthRouter on top of the depth model described by ``config``.

    Builds the depth model with ``build_model``, creates the training loader,
    and optimizes only the router's parameters (Adam, SILog loss) for
    ``config.epochs`` epochs through :func:`train_router`.

    Args:
        config: Config object from ``get_config``; must provide at least
            ``epochs`` plus whatever ``build_model``, ``DepthDataLoader`` and
            ``train_router`` read from it.
    """
    device = torch.device('cuda')

    # Depth model that produces the candidate depth maps; only the router is optimized.
    depth_model = build_model(config).to(device)

    train_loader = DepthDataLoader(config, 'train').data

    router = DepthRouter().to(device)
    criterion_d = SILogLoss()
    # The initial lr is overwritten on every step by the schedule in train_router.
    optimizer = optim.Adam(router.parameters(), 1e-4)

    # train_router advances this module-level counter to drive the lr schedule.
    global global_step
    global_step = 0

    for epoch in range(1, config.epochs + 1):
        print('\nEpoch: %03d - %03d' % (epoch, config.epochs))
        loss_train = train_router(depth_model, router, train_loader, optimizer,
                                  criterion_d, epoch, device, config)
        print('Training loss: %.4f' % float(loss_train))

    # TODO: validation on an 'online_eval' DepthDataLoader (see validate_router),
    # metric logging and saving the trained router weights are not implemented yet.
def validate_router(focal_model, router_model):
    """Placeholder for router validation; not implemented yet.

    Args:
        focal_model: Currently unused.
        router_model: Currently unused.

    Returns:
        None, always.
    """
    return None

def _router_lr(step, total_steps, min_lr=1e-4, max_lr=3e-3, power=0.9):
    """Return the lr for ``step`` of a symmetric poly warm-up / decay schedule.

    The rate rises from ``min_lr`` to ``max_lr`` over the first half of
    ``total_steps`` and falls back to ``min_lr`` over the second half.
    Progress is clamped, so the rate always stays within [min_lr, max_lr]
    and never turns negative, even if ``step`` overshoots ``total_steps``.
    """
    half = max(total_steps / 2.0, 1.0)
    progress = min(max(step / half, 0.0), 2.0)
    if progress < 1.0:
        return (max_lr - min_lr) * progress ** power + min_lr
    return (min_lr - max_lr) * (progress - 1.0) ** power + max_lr


def train_router(depth_model, router, train_loader, optimizer, criterion_d, epoch, device, config):
    """Run one training epoch of ``router`` on top of the frozen ``depth_model``.

    For each batch, the depth model's 'near', 'middle', 'wide' and 'ultra'
    predictions are concatenated along dim 1 and fused by the router. The
    SILog loss against the ground-truth depth is then back-propagated into
    the router only.

    Args:
        depth_model: Model whose output dict contains the four prediction keys
            above; run in eval mode under ``torch.no_grad``.
        router: Module being trained.
        train_loader: Iterable of batches with 'image', 'depth' and 'mask'
            entries; ``len(train_loader)`` sets the lr schedule length.
        optimizer: Optimizer over the router's parameters; its lr is set on
            every step.
        criterion_d: Loss called as ``criterion_d(pred, gt, mask=..., interpolate=True,
            return_interpolated=True)`` and returning ``(loss, _)``.
        epoch: 1-based epoch index (used for the progress bar).
        device: Device the batch tensors are moved to.
        config: Must provide ``epochs`` and ``use_amp``.

    Returns:
        Sample-weighted average training loss for the epoch, as a float.

    Side effects:
        Increments the module-level ``global_step`` once per batch.
    """
    global global_step
    depth_model.eval()
    router.train()
    depth_loss = logging.AverageMeter()
    scaler = amp.GradScaler(enabled=config.use_amp)
    # Schedule length comes from the real loader size rather than a hard-coded step count.
    total_steps = len(train_loader) * config.epochs

    for batch_idx, batch in enumerate(train_loader):
        global_step += 1

        current_lr = _router_lr(global_step, total_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        input_RGB = batch['image'].to(device)
        depth_gt = batch['depth'].to(device)
        mask = batch["mask"].to(device).to(torch.bool)

        # The depth model is frozen: no autograd graph is recorded for it.
        with torch.no_grad():
            preds = depth_model(input_RGB)
            depth_maps = torch.cat((preds['near'], preds['middle'], preds['wide'], preds['ultra']), dim=1)
        focal_depth = router(depth_maps)

        loss_d, _ = criterion_d(focal_depth, depth_gt, mask=mask, interpolate=True, return_interpolated=True)
        depth_loss.update(loss_d.item(), input_RGB.size(0))

        optimizer.zero_grad()
        scaler.scale(loss_d).backward()
        scaler.step(optimizer)
        # Mandatory after every step(): without it GradScaler raises on the next
        # step() when AMP is enabled, and the loss scale is never adjusted.
        scaler.update()

        logging.progress_bar(batch_idx, len(train_loader), config.epochs, epoch,
                             ('Depth Loss: %.4f (%.4f)' %
                              (depth_loss.val, depth_loss.avg)))

    return depth_loss.avg

def infer(model, rgb):
    """Inference entry point; currently a stub that ignores its inputs.

    Returns:
        None, always.
    """
    return None


if __name__ == '__main__':
    mp.set_start_method('forkserver')
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="Name of the depth model to build, e.g. zoedepth_nk_generic")
    parser.add_argument("-p", "--pretrained_resource", type=str, required=True,
                        help="Pretrained resource to load weights from, e.g. 'local::/path/to/checkpoint.pt'. Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to train the router on")
    # NOTE: version_name is parsed but not currently forwarded to get_config.
    parser.add_argument("-v", "--version_name", type=str, required=False,
                        default='generic', help="version_name")
    parser.add_argument("--batch_size", type=int, required=False,
                        default=12, help="Training batch size")
    parser.add_argument("--epochs", type=int, required=False,
                        default=5, help="Number of router training epochs")

    args, unknown_args = parser.parse_known_args()
    # Extra "--key value" / "--key=value" flags become config overrides instead of being dropped.
    overwrite = {**parse_unknown(unknown_args), "pretrained_resource": args.pretrained_resource}
    config = get_config(args.model, "eval", args.dataset, **overwrite)
    # add new parameters
    config.batch_size = args.batch_size
    config.distributed = False
    config.epochs = args.epochs
    config.router_model = 'zoedepth'
    config.router_pretrained_resource = 'local::/home/yss/桌面/code-master/Long-tail/weights/ZoeDepthv1_09-May_13-44-e2e00279f3b7_best.pt'

    main(config)
    print('done')