Mentions légales du service

Skip to content
Snippets Groups Projects
Unverified Commit 94390b64 authored by Max Zhaoshuo Li 李赵硕's avatar Max Zhaoshuo Li 李赵硕 Committed by GitHub
Browse files

Merge pull request #133 from NVlabs/fix-sch

fix sch
parents e954bfbf 8e0ecf5b
No related branches found
No related tags found
No related merge requests found
...@@ -51,10 +51,14 @@ class Trainer(BaseTrainer): ...@@ -51,10 +51,14 @@ class Trainer(BaseTrainer):
self.losses["render"] = self.criteria["render"](data["rgb_map"], data["image"]) self.losses["render"] = self.criteria["render"](data["rgb_map"], data["image"])
self.metrics["psnr"] = -10 * torch_F.mse_loss(data["rgb_map"], data["image"]).log10() self.metrics["psnr"] = -10 * torch_F.mse_loss(data["rgb_map"], data["image"]).log10()
def get_curvature_weight(self, current_iteration, init_weight, decay_factor): def get_curvature_weight(self, current_iteration, init_weight):
if "curvature" in self.weights: if "curvature" in self.weights:
weight = (min(current_iteration / self.warm_up_end, 1.) if self.warm_up_end > 0 else 1.) * init_weight if current_iteration <= self.warm_up_end:
self.weights["curvature"] = weight / decay_factor self.weights["curvature"] = current_iteration / self.warm_up_end * init_weight
else:
model = self.model_module
decay_factor = model.neural_sdf.growth_rate ** (model.neural_sdf.anneal_levels - 1)
self.weights["curvature"] = init_weight / decay_factor
def _start_of_iteration(self, data, current_iteration): def _start_of_iteration(self, data, current_iteration):
model = self.model_module model = self.model_module
...@@ -63,8 +67,7 @@ class Trainer(BaseTrainer): ...@@ -63,8 +67,7 @@ class Trainer(BaseTrainer):
model.neural_sdf.set_active_levels(current_iteration) model.neural_sdf.set_active_levels(current_iteration)
if self.cfg_gradient.mode == "numerical": if self.cfg_gradient.mode == "numerical":
model.neural_sdf.set_normal_epsilon() model.neural_sdf.set_normal_epsilon()
decay_factor = model.neural_sdf.growth_rate ** model.neural_sdf.add_levels # TODO: verify? self.get_curvature_weight(current_iteration, self.cfg.trainer.loss_weight.curvature)
self.get_curvature_weight(current_iteration, self.cfg.trainer.loss_weight.curvature, decay_factor)
elif self.cfg_gradient.mode == "numerical": elif self.cfg_gradient.mode == "numerical":
model.neural_sdf.set_normal_epsilon() model.neural_sdf.set_normal_epsilon()
......
...@@ -95,18 +95,16 @@ class NeuralSDF(torch.nn.Module): ...@@ -95,18 +95,16 @@ class NeuralSDF(torch.nn.Module):
return points_enc return points_enc
def set_active_levels(self, current_iter=None): def set_active_levels(self, current_iter=None):
add_levels = (current_iter - self.warm_up_end) // self.cfg_sdf.encoding.coarse2fine.step anneal_levels = max((current_iter - self.warm_up_end) // self.cfg_sdf.encoding.coarse2fine.step, 1)
self.add_levels = min(self.cfg_sdf.encoding.levels - self.cfg_sdf.encoding.coarse2fine.init_active_level, self.anneal_levels = min(self.cfg_sdf.encoding.levels, anneal_levels)
add_levels) self.active_levels = max(self.cfg_sdf.encoding.coarse2fine.init_active_level, self.anneal_levels)
self.active_levels = self.cfg_sdf.encoding.coarse2fine.init_active_level + self.add_levels
assert self.active_levels <= self.cfg_sdf.encoding.levels
def set_normal_epsilon(self): def set_normal_epsilon(self):
if self.cfg_sdf.encoding.coarse2fine.enabled: if self.cfg_sdf.encoding.coarse2fine.enabled:
epsilon_res = self.resolutions[self.active_levels - 1] epsilon_res = self.resolutions[self.anneal_levels - 1]
self.normal_eps = 1. / epsilon_res
else: else:
self.normal_eps = 1. / self.resolutions[-1] epsilon_res = self.resolutions[-1]
self.normal_eps = 1. / epsilon_res
@torch.no_grad() @torch.no_grad()
def _get_coarse2fine_mask(self, points_enc, feat_dim): def _get_coarse2fine_mask(self, points_enc, feat_dim):
......
0% — Loading, or the content failed to load.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment