import torch
from scene_loaders.ibr_scene import Scene
import random
import os
from tqdm import tqdm
from os import makedirs
from renderer.render_point_cloud import render_scene, render_scene_diffuse_only
import torchvision
from arguments.parse_args import get_args
from utils.camera_utils import get_test_cameras

args = get_args()

# Set to True to additionally write, per frame, a debug collage
# (gt / render / static-camera render / diffuse / specular / mask) into the
# "debug" output folder. Only supported for the full (non diffuse-only) renderer.
SAVE_DEBUG_COLLAGE = False
# Index into the test camera path of the fixed camera used for the debug collage.
STATIC_CAM_INDEX = 10

with torch.no_grad():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    # Fixed seeds so repeated exports are reproducible.
    torch.manual_seed(0)
    random.seed(0)

    # NOTE(review): overrides any value parsed from the command line; the
    # meaning of 1000 is defined by Scene / the loaders — confirm.
    args.global_downscale = 1000

    scene = Scene(args, resolution_scales=[1.0], load_input_data=False)
    cameras_path = get_test_cameras(args)

    # All outputs go under <scene_representation_folder>/test_path_renders/iter_<N>/...
    out_root = os.path.join(args.scene_representation_folder, "test_path_renders",
                            "iter_{}".format(scene.load_iter))
    debug_path = os.path.join(out_root, "debug")
    render_path = os.path.join(out_root, "renders")
    gts_path = os.path.join(out_root, "gt")
    polytope_mask_path = os.path.join(out_root, "polytope_masks")

    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(polytope_mask_path, exist_ok=True)

    # The debug collage needs the specular render and mask, which the
    # diffuse-only renderer does not produce, and at least one camera.
    save_debug = SAVE_DEBUG_COLLAGE and not args.diffuse_only and len(cameras_path) > 0
    static_cam = None
    if save_debug:
        makedirs(debug_path, exist_ok=True)
        # Clamp so short camera paths do not raise IndexError.
        static_cam = cameras_path[min(STATIC_CAM_INDEX, len(cameras_path) - 1)]

    for idx, cam in tqdm(enumerate(cameras_path), total=len(cameras_path),
                         desc="Exporting images"):
        torch.cuda.empty_cache()
        frame_name = "path_{:04d}.png".format(idx)

        # Ground truth: first three channels of the reference image concatenated
        # with its alpha mask along dim=1 (assumes batched NCHW tensors — confirm).
        gt = torch.cat((cam.original_image[:, 0:3, :, :].cuda(),
                        cam.gt_alpha_mask.cuda()), dim=1)
        polytope_mask = cam.polytope_mask.cuda()

        if not args.diffuse_only:
            view, diffuse_render, specular_render, mask, _, _, _, _ = render_scene(
                cam, scene, cata_camera=None)
        else:
            view, diffuse_render, _ = render_scene_diffuse_only(cam, scene)

        if save_debug:
            # Render the fixed static camera with the current camera as cata_camera.
            view_static = render_scene(static_cam, scene, cata_camera=cam)[0]
            # Keep only three channels everywhere so the tensors can be stacked
            # along dim=0 (gt carries an extra alpha channel).
            collage = torch.cat((gt[:, :3, ...],
                                 view[:, :3, ...],
                                 view_static[:, :3, ...],
                                 diffuse_render[:, :3, ...],
                                 specular_render[:, :3, ...],
                                 mask.repeat(1, 3, 1, 1)), dim=0)
            torchvision.utils.save_image(torchvision.utils.make_grid(collage, nrow=3),
                                         os.path.join(debug_path, frame_name))

        torchvision.utils.save_image(view, os.path.join(render_path, frame_name))
        torchvision.utils.save_image(gt, os.path.join(gts_path, frame_name))
        torchvision.utils.save_image(polytope_mask, os.path.join(polytope_mask_path, frame_name))