import argparse
from pprint import pprint

import torch

from zoedepth.utils.easydict import EasyDict as edict
import torch.optim as optim
from tqdm import tqdm
import torch.cuda.amp as amp
import torch.nn as nn
from zoedepth.utils.misc import count_parameters, parallelize
from zoedepth.trainers.builder import get_trainer
import torch.multiprocessing as mp
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.data.data_mono import MixedNYUKITTI
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.trainers.loss import GradL1Loss, SILogLoss, FocalLossV1
import zoedepth.utils.logging as logging
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,compute_every_err,
                        count_parameters)
from zoedepth.models.transformer_decoder.router import *
#"-p", "local::/home/yss/桌面/code-master/Focus/weights/ZoeDepthNKgeneric_15-May_20-43-38d049ca35b3_best.pt"



def slice_image(image, slice_size):
    """Cut a batched image into a grid of tiles stacked along the channel axis.

    Args:
        image: 4-D tensor indexed as (batch, channels, height, width).
        slice_size: Pair ``(tile_height, tile_width)``.

    Returns:
        Tensor of shape (batch, channels * n_tiles, tile_height, tile_width).
        Tiles are ordered row-major (left to right, then top to bottom).
        Rows/columns at the bottom/right that do not fill a whole tile are
        dropped.
    """
    tile_h, tile_w = slice_size[0], slice_size[1]
    _, _, full_h, full_w = image.shape
    n_rows, n_cols = full_h // tile_h, full_w // tile_w

    tiles = [
        image[:, :, r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    # Concatenating on dim=1 places each tile's channels one after another.
    return torch.cat(tiles, dim=1)


def reconstruct_slices(slices, num_rows, num_cols):
    """Reassemble a row-major stack of tiles into one single-channel image.

    Args:
        slices: Tensor of shape (batch, n_tiles, tile_height, tile_width);
            channel ``i`` holds the tile at grid position
            ``(i // num_cols, i % num_cols)``.
        num_rows: Number of tile rows in the grid.
        num_cols: Number of tile columns in the grid.

    Returns:
        Tensor of shape (batch, 1, tile_height * num_rows,
        tile_width * num_cols) with the same dtype and device as ``slices``.
    """
    bs, _, slice_h, slice_w = slices.shape
    # The canvas width must scale with the number of *columns*; sizing it by
    # num_rows only worked for square grids.
    image = torch.zeros(1, bs, slice_h * num_rows, slice_w * num_cols,
                        dtype=slices.dtype, device=slices.device)

    # Iterate tile-by-tile: after the transpose each item is (batch, h, w),
    # which broadcasts into the (1, batch, h, w) window of the canvas.
    for i, tile in enumerate(slices.transpose(0, 1)):
        row, col = divmod(i, num_cols)
        image[:, :,
              row * slice_h:(row + 1) * slice_h,
              col * slice_w:(col + 1) * slice_w] = tile

    return image.transpose(1, 0)


class ChannelAttention(nn.Module):
    """Re-weight channels by a learned gate, then average them into one map.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction_ratio: Accepted for interface compatibility; not used, the
            hidden width of the gating MLP is fixed at 256.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(in_channels, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, in_channels),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return the channel-gated input averaged over channels.

        Args:
            x: Tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, 1, height, width).
        """
        n_batch, n_channels = x.size(0), x.size(1)
        # One scalar descriptor per channel from global average pooling.
        pooled = self.avg_pool(x).view(n_batch, n_channels)
        gate = self.fc(pooled).view(n_batch, n_channels, 1, 1)
        gated = x * gate
        return gated.mean(dim=1, keepdim=True)

class Decoder(nn.Module):
    """Coarse-to-fine fusion decoder that ends in a 4-way softmax map.

    Three input feature maps are each projected to ``out_channels`` by a
    two-layer conv stack, fused progressively (coarse to fine) with
    ``SelectiveFeatureFusion`` and 2x bilinear upsampling, fused once more
    with a fourth, unprojected feature map, and finally mapped to 4 channels
    normalised with a softmax over the channel axis.

    Args:
        in_channels: Indexable of at least three channel counts, one per
            projected input (coarsest first).
        out_channels: Common channel width used after projection.
    """

    @staticmethod
    def _projection(channels, out_channels):
        """Build conv3x3-BN-ReLU (doubling width) then conv3x3-BN-ReLU to out_channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels=int(channels),
                      out_channels=channels * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels * 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=int(channels * 2),
                      out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Registration order is kept (bot_conv, skip_conv1, skip_conv2, ...)
        # so parameter names and initialisation order stay the same.
        self.bot_conv = self._projection(in_channels[0], out_channels)
        self.skip_conv1 = self._projection(in_channels[1], out_channels)
        self.skip_conv2 = self._projection(in_channels[2], out_channels)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.fusion1 = SelectiveFeatureFusion(out_channels)
        self.fusion2 = SelectiveFeatureFusion(out_channels)
        self.fusion3 = SelectiveFeatureFusion(out_channels)

        self.seg = nn.Sequential(
            nn.Conv2d(in_channels=out_channels,
                      out_channels=4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(in_channels=4,
                      out_channels=4, kernel_size=3, stride=1, padding=1),
            nn.Softmax(dim=1),
        )

    def forward(self, x_blocks):
        """Decode four feature maps into a 4-channel softmax map.

        Args:
            x_blocks: Sequence of exactly four tensors. Each of the first
                three must be twice the spatial size of the previous one;
                the fourth is fused without projection, so it presumably
                already has ``out_channels`` channels at 2x the third map's
                resolution -- TODO confirm against the caller.

        Returns:
            Tensor with 4 channels (softmax over dim 1) at the resolution of
            the fourth input.
        """
        coarse, mid, fine, full_res = x_blocks

        fused = self.up(self.bot_conv(coarse))
        fused = self.up(self.fusion1(self.skip_conv1(mid), fused))
        fused = self.up(self.fusion2(self.skip_conv2(fine), fused))

        # The last input skips the projection stack and is fused directly.
        fused = self.fusion3(full_res, fused)
        return self.seg(fused)


class SelectiveFeatureFusion(nn.Module):
    """Blend two same-shaped feature maps with per-pixel learned gates.

    The two inputs are concatenated and passed through conv stacks that
    end in a 2-channel sigmoid map; channel 0 gates the local features and
    channel 1 gates the global features.

    Args:
        in_channel: Channel count of each of the two inputs.
    """

    @staticmethod
    def _conv_bn_relu(c_in, c_out):
        """Return the layers of one conv3x3 -> BatchNorm -> ReLU unit."""
        return [
            nn.Conv2d(in_channels=c_in, out_channels=c_out,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]

    def __init__(self, in_channel=64):
        super().__init__()
        doubled = in_channel * 2
        halved = int(in_channel / 2)

        # 2C -> 2C -> C
        self.conv1 = nn.Sequential(
            *self._conv_bn_relu(int(doubled), doubled),
            *self._conv_bn_relu(int(doubled), in_channel),
        )
        # C -> C -> C/2
        self.conv2 = nn.Sequential(
            *self._conv_bn_relu(in_channel, int(in_channel)),
            *self._conv_bn_relu(in_channel, halved),
        )
        # C/2 -> 2 gate logits (one per input branch)
        self.conv3 = nn.Conv2d(in_channels=halved,
                               out_channels=2, kernel_size=3, stride=1, padding=1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x_local, x_global):
        """Return ``x_local * g0 + x_global * g1`` with learned gates g0, g1.

        Args:
            x_local: Tensor (batch, in_channel, H, W).
            x_global: Tensor with the same shape as ``x_local``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        stacked = torch.cat((x_local, x_global), dim=1)
        gates = self.sigmoid(self.conv3(self.conv2(self.conv1(stacked))))

        # Keep a singleton channel axis so each gate broadcasts over channels.
        gate_local = gates[:, 0:1, :, :]
        gate_global = gates[:, 1:2, :, :]
        return x_local * gate_local + x_global * gate_global
    

def set_max_to_one(seg_map, depth_maps):
    """Pick, per pixel, the depth value of the highest-scoring channel.

    Only the first four channels of ``seg_map`` are considered. Exactly one
    channel is selected per pixel: when several channels tie for the
    maximum, a single one (as returned by ``argmax``) is used instead of
    summing the depth values of all tied channels.

    Args:
        seg_map: Tensor (batch, C>=1, H, W) of per-channel scores.
        depth_maps: Tensor broadcastable against ``seg_map[:, :4]``,
            holding one candidate depth map per channel.

    Returns:
        Tensor of shape (batch, 1, H, W) with the selected depth values.
    """
    scores = seg_map[:, :4, :, :]
    winner = scores.argmax(dim=1, keepdim=True)
    # Hard one-hot mask over the channel axis, built in the depth dtype so
    # the product below does not change precision.
    one_hot = torch.zeros_like(scores, dtype=depth_maps.dtype).scatter_(1, winner, 1)

    return torch.sum(one_hot * depth_maps, dim=1).unsqueeze(1)


class CrossAttention(nn.Module):
    """Dot-product cross attention between two 2-D feature maps.

    Queries come from ``x_q`` and keys/values from ``x_k``; all three are
    produced by 1x1 convolutions to ``out_dim`` channels. Values pass
    through a softplus before being aggregated.

    Args:
        in_dim_q: Channel count of the query input.
        in_dim_k: Channel count of the key/value input.
        out_dim: Channel count of the projections and of the output.
    """

    def __init__(self, in_dim_q, in_dim_k, out_dim):
        super().__init__()

        # 1x1 projections for query, key and value.
        self.query_conv = nn.Conv2d(in_dim_q, out_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim_k, out_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim_k, out_dim, kernel_size=1)

        # Normalise attention over key positions (last axis of the energy).
        self.softmax = nn.Softmax(dim=-1)
        # Not used in forward(); kept so the module's attributes are unchanged.
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()

    def forward(self, x_q, x_k):
        """Attend from every position of ``x_q`` to every position of ``x_k``.

        Args:
            x_q: Tensor (batch, in_dim_q, Hq, Wq).
            x_k: Tensor (batch, in_dim_k, Hk, Wk).

        Returns:
            Tensor (batch, out_dim, Hq, Wq).
        """
        batch_size, _, height_q, width_q = x_q.size()

        query = self.query_conv(x_q).flatten(2).transpose(1, 2)  # (B, Nq, C)
        key = self.key_conv(x_k).flatten(2)                      # (B, C, Nk)

        energy = torch.bmm(query, key)                           # (B, Nq, Nk)
        attention = self.softmax(energy)

        value = self.softplus(self.value_conv(x_k)).flatten(2)   # (B, C, Nk)

        # (B, C, Nk) @ (B, Nk, Nq) -> (B, C, Nq). This is already channel-major,
        # so it can be viewed directly as (B, C, Hq, Wq). An extra permute here
        # would make the tensor non-contiguous and either break the view or
        # mix up the channel and spatial axes whenever out_dim > 1.
        out = torch.bmm(value, attention.transpose(1, 2))

        return out.view(batch_size, -1, height_q, width_q)


class Router(nn.Module):
    """Compute an attention map between focal depth maps and a generic map.

    Both inputs go through a small conv stack and are pooled to a fixed
    48 x 64 grid; the result is the softmax (over dim 1) of the dot-product
    between all query and key positions.

    Args:
        in_dim_q: Channel count of the query (focal) input.
        in_dim_k: Channel count expected by the key branch. ``forward``
            tiles the generic input 4x along channels before this branch.
        out_dim: Channel count of the projected queries and keys.
    """

    @staticmethod
    def _branch(c_in, c_out, grid):
        """conv3x3 -> BN -> ReLU -> conv3x3 -> adaptive average pool to ``grid``."""
        return nn.Sequential(
            nn.Conv2d(in_channels=c_in,
                      out_channels=c_out * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_out * 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=c_out * 2,
                      out_channels=c_out, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveAvgPool2d(grid),
        )

    def __init__(self, in_dim_q=4, in_dim_k=4, out_dim=4):
        super().__init__()
        rows, cols = 48, 64
        # Focal maps act as the query, the generic map as the key.
        self.query_conv = self._branch(in_dim_q, out_dim, [rows, cols])
        self.key_conv = self._branch(in_dim_k, out_dim, [rows, cols])
        self.wh = [rows, cols]
        self.softmax = nn.Softmax(dim=1)
        self.scale = 3072  # equals 48 * 64; not used in forward()

    def forward(self, focal, generic):
        """Return the (batch, 3072, 3072) query-by-key attention tensor.

        Args:
            focal: Tensor (batch, in_dim_q, H, W) used as the query source.
            generic: 4-D tensor used as the key source; it is repeated 4x
                along the channel axis before the key branch.

        Returns:
            Tensor (batch, 48*64, 48*64); entries sum to 1 along dim 1
            (the query-position axis).
        """
        n_batch, _, _, _ = generic.shape
        n_positions = self.wh[0] * self.wh[1]

        tiled_generic = generic.repeat(1, 4, 1, 1)
        keys = self.key_conv(tiled_generic).view(n_batch, -1, n_positions)
        queries = self.query_conv(focal).view(n_batch, -1, n_positions)
        queries = queries.permute(0, 2, 1)

        return self.softmax(torch.bmm(queries, keys))


def main(config):
    """Train a router network on top of a frozen depth model.

    Builds the depth model and the train / online-eval data loaders from
    ``config``, freezes the depth model on the GPU, creates a
    ``MultiScaleMaskedTransformerDecoder`` router with an Adam optimiser and
    runs ``train_router`` once per epoch. Resets the module-level
    ``global_step`` counter to 0 before training.

    Args:
        config: Configuration object; ``config.epochs`` is read here and the
            whole object is forwarded to the model/data builders and to
            ``train_router``.
    """
    global global_step

    # Frozen depth backbone.
    depth_model = build_model(config)

    train_loader = DepthDataLoader(config, 'train').data
    test_loader = DepthDataLoader(config, 'online_eval').data

    depth_model = depth_model.eval().cuda()
    for frozen_param in depth_model.parameters():
        frozen_param.requires_grad = False

    # Trainable router.
    in_channels, num_classes = 256, 2
    router = MultiScaleMaskedTransformerDecoder(in_channels, num_classes).cuda()

    device = torch.device('cuda')
    criterion_d = SILogLoss()
    # Depth ranges per focal bucket; each entry is a 3-element list that is
    # forwarded unchanged to train_router.
    focal_length = {
        'near': [0.001, 3.0, 0.25],
        'middle': [3.0, 10.0, 0.25],
        'wide': [10.0, 25.0, 0.3],
        'ultra': [25.0, 80.0, 0.4],
    }
    optimizer = optim.Adam(router.parameters(), 1e-4)

    global_step = 0

    for epoch in range(1, config.epochs + 1):
        print('\nEpoch: %03d - %03d' % (epoch, config.epochs))
        train_router(depth_model, router, train_loader, test_loader, optimizer,
                     criterion_d, epoch, device, focal_length, config)
def validate_router(depth_model, router, test_loader, config):
    """Evaluate the router on ``test_loader`` and print the averaged metrics.

    The router is switched to eval mode for the duration of the evaluation
    and switched back to train mode afterwards if it was training on entry,
    so calling this from inside a training loop does not silently leave the
    router in eval mode. The depth model is put in eval mode and left there.

    Args:
        depth_model: Model called as ``depth_model(image)``; must return a
            mapping with keys 'near', 'middle', 'wide', 'ultra', 'generic'
            and 'x_blocks'.
        router: Module called as ``router(x_blocks[:3], x_blocks[-1])``.
        test_loader: Iterable with ``len()`` yielding dicts with at least
            'image' and 'depth' (and optionally 'has_valid_depth').
        config: Configuration; ``config.dataset`` selects the KITTI
            sliding-window path and the object is forwarded to
            ``compute_every_err``.

    Returns:
        None. The metrics list is printed.
    """
    # Defined up-front so the post-loop code is valid even when the loader is
    # empty or every sample is skipped.
    seg_test = False
    num_groups = 0

    router_was_training = router.training
    depth_model.eval()
    router.eval()
    metrics_list = [RunningAverageDict() for _ in range(10)]

    try:
        for sample in tqdm(test_loader, total=len(test_loader)):
            if 'has_valid_depth' in sample and not sample['has_valid_depth']:
                continue

            image, depth = sample['image'].cuda(), sample['depth'].cuda()
            # Force the ground truth to (1, 1, ...) regardless of how many
            # singleton axes the loader added.
            depth = depth.squeeze().unsqueeze(0).unsqueeze(0)

            if config.dataset == 'kitti':
                # Split the wide image into three overlapping 480-px windows
                # that are evaluated as one batch.
                bs, _, h, w = image.shape
                assert w > h and bs == 1
                interval_all = w - 480
                shift_size = 3
                interval = interval_all // (shift_size - 1)
                sliding_images = []
                # NOTE(review): the coverage mask is accumulated but the
                # per-window predictions are never merged back to full width
                # in this function -- confirm how compute_every_err is
                # expected to handle the batched windows.
                sliding_masks = torch.zeros((bs, 1, h, w), device=image.device)
                for i in range(shift_size):
                    sliding_images.append(image[..., :, i * interval:i * interval + 480])
                    sliding_masks[..., :, i * interval:i * interval + 480] += 1
                image = torch.cat(sliding_images, dim=0)

            with torch.no_grad():
                preds = depth_model(image)
                depth_maps = torch.cat((preds['near'], preds['middle'], preds['wide'],
                                        preds['ultra'], preds['generic']), dim=1)
                x_blocks, mask_features = preds['x_blocks'][:3], preds['x_blocks'][-1]
                seg_map = router(x_blocks, mask_features)
                # Only the first three depth maps are mixed by the router output.
                focal_depth = torch.sum(seg_map * depth_maps[:, :3, :, :], dim=1).unsqueeze(1)

            Ms = compute_every_err(depth, focal_depth, seg_test=seg_test, config=config)

            if seg_test:
                num_groups = len(Ms)
                for i, group_metrics in enumerate(Ms):
                    if group_metrics != 0:
                        metrics_list[i].update(group_metrics)
            else:
                metrics_list[0].update(Ms)
    finally:
        # Hand the router back in the mode the caller had it in.
        if router_was_training:
            router.train()

    def _rounded(running):
        # get_value() may yield nothing when no sample was accumulated
        # (assumption about RunningAverageDict -- confirm); treat as empty.
        return {k: round(v, 3) for k, v in (running.get_value() or {}).items()}

    if seg_test:
        for i in range(num_groups):
            metrics_list[i] = _rounded(metrics_list[i])
    else:
        metrics_list[0] = _rounded(metrics_list[0])

    print(metrics_list)

    return

def train_router(depth_model, router, train_loader, test_loader, optimizer, criterion_d, epoch, device, focal_length, config):
    """Train the router for one pass over ``train_loader``.

    The depth model is run without gradients; the router turns its feature
    blocks into per-pixel weights that mix the first three depth maps, and
    the mix is trained against the ground truth with ``criterion_d``.
    The learning rate is set manually every step from the module-level
    ``global_step`` counter, and ``validate_router`` is run every 400 steps.
    Training runs in full precision; ``config.use_amp`` is not consulted.

    Args:
        depth_model: Frozen model returning a mapping with 'near', 'middle',
            'wide', 'ultra', 'generic' and 'x_blocks'.
        router: Trainable module called as ``router(x_blocks[:3], x_blocks[-1])``.
        train_loader: Iterable with ``len()`` yielding dicts with 'image',
            'depth' and 'mask'.
        test_loader: Loader forwarded to ``validate_router``.
        optimizer: Optimiser over the router parameters.
        criterion_d: Loss called as ``criterion_d(pred, gt, mask=...,
            interpolate=True, return_interpolated=True)`` returning a pair
            whose first element is the loss.
        epoch: Current epoch number (used for progress display only).
        device: Device the batch tensors are moved to.
        focal_length: Accepted for interface compatibility; not used.
        config: Configuration; ``config.epochs`` drives the LR schedule.

    Returns:
        The loss tensor of the last processed batch, or ``None`` when the
        loader yielded no batches.
    """
    global global_step
    depth_model.eval()
    router.train()
    depth_loss = logging.AverageMeter()

    # Hard-coded number of optimisation steps per epoch.
    # NOTE(review): presumably len(train_loader) for the original dataset /
    # batch-size setup -- confirm before changing either.
    steps = 4847
    # At least one epoch of ramp-up so the schedule never divides by zero
    # when config.epochs < 2.
    half_epoch = max(config.epochs // 2, 1)
    ramp_steps = steps * half_epoch
    loss_d = None

    for batch_idx, batch in enumerate(train_loader):
        global_step += 1

        # Polynomial ramp 2e-5 -> 2e-4 over the first half, then linear decay
        # towards 1e-7 over an equally long second half. Clamped so steps
        # beyond the schedule cannot produce a negative learning rate.
        if global_step < ramp_steps:
            current_lr = (2e-4 - 2e-5) * (global_step / ramp_steps) ** 0.9 + 2e-5
        else:
            current_lr = 2e-4 - ((global_step - ramp_steps) / ramp_steps) * (2e-4 - 1e-7)
            current_lr = max(current_lr, 1e-7)
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        input_RGB = batch['image'].to(device)
        depth_gt = batch['depth'].to(device)
        mask = batch["mask"].to(device).to(torch.bool)

        with torch.no_grad():
            preds = depth_model(input_RGB)
            depth_maps = torch.cat((preds['near'], preds['middle'], preds['wide'],
                                    preds['ultra'], preds['generic']), dim=1)
            x_blocks, mask_features = preds['x_blocks'][:3], preds['x_blocks'][-1]

        seg_map = router(x_blocks, mask_features)
        # NOTE(review): five maps are concatenated but only the first three
        # (near, middle, wide) are mixed here -- confirm this is intended.
        focal_depth = torch.sum(seg_map * depth_maps[:, :3, :, :], dim=1).unsqueeze(1)

        loss_d, _ = criterion_d(focal_depth, depth_gt, mask=mask, interpolate=True, return_interpolated=True)
        depth_loss.update(loss_d.item(), input_RGB.size(0))
        loss_d.backward()
        logging.progress_bar(batch_idx, len(train_loader), config.epochs, epoch,
                            ('Depth Loss: %.4f (%.4f)' %
                            (depth_loss.val, depth_loss.avg)))
        optimizer.step()
        optimizer.zero_grad()

        if global_step % 400 == 0:
            validate_router(depth_model, router, test_loader, config)
            # Validation puts the router in eval mode; resume training mode
            # so the remaining steps are not run with eval-mode layers.
            router.train()

    return loss_d

def infer(model, rgb):
    """Placeholder for single-image inference; not implemented, returns None."""
    return None


if __name__ == '__main__':
    # Script entry point: parse CLI options, build the training config and
    # run router training.
    mp.set_start_method('forkserver')
    parser = argparse.ArgumentParser()
    # Both -m and -p are required, so their defaults are never applied; they
    # are kept only as examples of typical values.
    parser.add_argument("-m", "--model", type=str,required=True, 
                        default='zoedepth_nk_generic',help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str,
                        required=True, default='local::/home/yss/桌面/code-master/Focus/weights/ZoeDepthNKgeneric_15-May_20-43-38d049ca35b3_latest.pt', help="Pretrained resource to use for fetching weights. If not set, default resource from model config is used,  Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")
    parser.add_argument("-v","--version_name", type=str, required=False,
                        default='generic', help="version_name")
    # Default of 5 matches the batch size this script previously hard-coded,
    # so runs without --batch_size behave as before; an explicit value is now
    # honoured instead of being silently overwritten.
    parser.add_argument("--batch_size", type=int, required=False,
                        default=5, help="Training batch size")
    

    args, unknown_args = parser.parse_known_args()
    overwrite = {"pretrained_resource": args.pretrained_resource}
    config = get_config(args.model, "train", args.dataset, **overwrite)
    # Script-specific overrides applied on top of the loaded config.
    config.batch_size = args.batch_size
    config.distributed = False
    config.epochs = 2

    main(config)
    print('done')