Commit 54de0d43 authored by PaulWawerek-L

Merge branch 'paul' into context_model

parents eab41efa 760865fe
1 merge request: !2 OSLO-IC
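This merge threads a new single_conv option through the OSLO-IC code: a --single-conv CLI flag, the Sphere* model constructors, and the spherical layer blocks. When enabled, each hop-2 SLB_Downsample/SLB_Upsample builds one SDPAConv over the full 2-hop neighborhood (a 25-point kernel) instead of chaining two 1-hop convolutions (two 9-point kernels), and train_epoch/test_epoch request the matching structs with num_hops=2. SDPAConv.forward is relaxed accordingly to accept neighbor tables wider than its kernel and slice them down to kernel_size-1 columns.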
@@ -90,7 +90,7 @@ class RateDistortionLoss(torch.nn.Module):
return out
-def train_epoch(train_dataloader, struct_loader, model, criterion, optimizer, optimizer_aux, patch_res, n_patch_per_sample, print_freq, epoch, clip_max_norm, folder_plot_grad):
+def train_epoch(train_dataloader, struct_loader, model, criterion, optimizer, optimizer_aux, patch_res, n_patch_per_sample, print_freq, epoch, clip_max_norm, folder_plot_grad, single_conv:bool=False):
model.train()
device = next(model.parameters()).device
batch_time = common_utils.AverageMeter('Batch processing time', ':6.3f')
@@ -126,7 +126,7 @@ def train_epoch(train_dataloader, struct_loader, model, criterion, optimizer, op
dict_weight = dict()
for r in list_res:
if struct_loader.__class__.__name__ == "HealpixSdpaStructLoader":
-dict_index[r], dict_weight[r], _, _, _ = struct_loader.getStruct(sampling_res=r if noPatching else r[0], num_hops=1, patch_res=None if noPatching else r[1], patch_id=patch_id)
+dict_index[r], dict_weight[r], _, _, _ = struct_loader.getStruct(sampling_res=r if noPatching else r[0], num_hops=2 if single_conv else 1, patch_res=None if noPatching else r[1], patch_id=patch_id)
else:
dict_index[r], dict_weight[r], _, _ = struct_loader.getGraph(sampling_res=r if noPatching else r[0], patch_res=None if noPatching else r[1], num_hops=0, patch_id=patch_id)
@@ -177,7 +177,7 @@ def compute_actual_bits(compressed_stream):
total_bits_per_image = torch.sum(torch.stack(list_latent_bits, dim=0), dim=0).detach().cpu().long()
return total_bits_per_image
-def test_epoch(test_dataloader, struct_loader, model, criterion, patch_res=None, visFolder=None, print_freq=None, epoch=None, checkWithActualCompression=False, only_npy=False):
+def test_epoch(test_dataloader, struct_loader, model, criterion, patch_res=None, visFolder=None, print_freq=None, epoch=None, checkWithActualCompression=False, only_npy=False, single_conv:bool=False):
model.eval()
device = next(model.parameters()).device
@@ -222,9 +222,9 @@ def test_epoch(test_dataloader, struct_loader, model, criterion, patch_res=None,
for r in list_res:
if struct_loader.__class__.__name__ == "HealpixSdpaStructLoader":
if noPatching:
-dict_index[r], dict_weight[r], _ = struct_loader.getStruct(sampling_res=r, num_hops=1, patch_res=None, patch_id=patch_id)
+dict_index[r], dict_weight[r], _ = struct_loader.getStruct(sampling_res=r, num_hops=2 if single_conv else 1, patch_res=None, patch_id=patch_id)
else:
-dict_index[r], dict_weight[r], _, _, _ = struct_loader.getStruct(sampling_res=r[0], num_hops=1, patch_res=r[1], patch_id=patch_id)
+dict_index[r], dict_weight[r], _, _, _ = struct_loader.getStruct(sampling_res=r[0], num_hops=2 if single_conv else 1, patch_res=r[1], patch_id=patch_id)
else:
if noPatching:
dict_index[r], dict_weight[r] = struct_loader.getGraph(sampling_res=r, patch_res=None, num_hops=0, patch_id=patch_id)
@@ -421,6 +421,7 @@ parser.add_argument('--gpu', '-g', action='store_true', help='enables cuda')
parser.add_argument('--gpu_id', '-gid', type=int, default=-1, help='select cuda device by index. Default is -1 (cuda).')
parser.add_argument('--conv', '-c', type=str, default='SDPAConv', help="Graph convolution method")
parser.add_argument('--skip-connection-aggregation', '-sc', type=str, default='sum', help="Mode for jumping knowledge")
+parser.add_argument('--single-conv', action='store_true', help="Use only one convolution for SDPAConv")
parser.add_argument('--pool-func', '-pf', type=str, default="stride", help="Pooling function.")
parser.add_argument('--unpool-func', '-upf', type=str, default="pixel_shuffle", help="Unpooling function.")
# HealPix
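A hypothetical invocation enabling the new mode (the script name is an assumption; only flags visible in this diff are shown):

python train.py --gpu --conv SDPAConv --skip-connection-aggregation sum --single-conv --pool-func stride --unpool-func pixel_shuffle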
@@ -469,9 +470,9 @@ def main():
model = getattr(spherical_models, model_name)
if args.skip_connection_aggregation == "cat": # To have almost the same number of parameters as sum or max aggregation
N, M = quality_cfgs[model_name][args.quality]
-net = model(2*N, M, args.conv, args.skip_connection_aggregation, args.pool_func, args.unpool_func, args.attention, args.nonlinearity, args.context_model, args.autoregressive_relu, args.autoregressive_abs, args.autoregressive_hops).to(device)
+net = model(2*N, M, args.conv, args.skip_connection_aggregation, args.pool_func, args.unpool_func, args.attention, args.nonlinearity, args.single_conv, args.context_model, args.autoregressive_relu, args.autoregressive_abs, args.autoregressive_hops).to(device)
else:
-net = model(*quality_cfgs[model_name][args.quality], args.conv, args.skip_connection_aggregation, args.pool_func, args.unpool_func, args.attention, args.nonlinearity, args.context_model, args.autoregressive_relu, args.autoregressive_abs, args.autoregressive_hops).to(device)
+net = model(*quality_cfgs[model_name][args.quality], args.conv, args.skip_connection_aggregation, args.pool_func, args.unpool_func, args.attention, args.nonlinearity, args.single_conv, args.context_model, args.autoregressive_relu, args.autoregressive_abs, args.autoregressive_hops).to(device)
# Use list of tuples instead of dict to be able to later check the elements are unique and there is no intersection
parameters = [(n, p) for n, p in net.named_parameters() if not n.endswith(".quantiles")]
@@ -568,7 +569,7 @@ def main():
assert test_dataset.resolution == args.healpix_res, "resolution of test dataset doesn't match with input dataset"
valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if args.foldername_valtest is not None else None
-loss_test = test_epoch(test_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, checkWithActualCompression=True, only_npy=args.only_npy_valtest)
+loss_test = test_epoch(test_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, checkWithActualCompression=True, only_npy=args.only_npy_valtest, single_conv=args.single_conv)
print("Loss test:", ', '.join([f'{k}={v:6.4f}' for k, v in loss_test.items()]), flush=True)
os.makedirs(args.out_dir, exist_ok=True)
output_fileAddr = os.path.join(args.out_dir, args.filename_test_results+'.npz')
@@ -596,7 +597,7 @@ def main():
os.makedirs(plotGradFolder, exist_ok=True)
perform_validation = False
for epoch in range(last_epoch + 1, args.max_epochs): # epoch=0...max_epochs-1, printing and checkpoint saving in 1...max_epochs
-loss_train = train_epoch(train_dataloader, struct_loader, net, criterion, optimizer, optimizer_aux, args.patch_res_train, args.n_patch_per_sample, args.print_freq, epoch, args.clip_max_norm, plotGradFolder)
+loss_train = train_epoch(train_dataloader, struct_loader, net, criterion, optimizer, optimizer_aux, args.patch_res_train, args.n_patch_per_sample, args.print_freq, epoch, args.clip_max_norm, plotGradFolder, args.single_conv)
if epoch == 0:
print(f"saving config.txt file in {args.out_dir}", flush=True)
@@ -614,7 +615,7 @@ def main():
if perform_validation:
saveVis = (((epoch+1) % args.interval_save_valtest == 0) or ((epoch+1) == args.max_epochs)) and (args.foldername_valtest is not None)
valTestVisFolder = os.path.join(args.out_dir, args.foldername_valtest) if saveVis else None
-loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch, only_npy=args.only_npy_valtest)
+loss_validation = test_epoch(validation_dataloader, struct_loader, net, criterion, args.patch_res_valtest, valTestVisFolder, args.print_freq, epoch, only_npy=args.only_npy_valtest, single_conv=args.single_conv)
list_mean_losses_validation[epoch] = loss_validation
loss_str = f"Loss validation: Epoch [{epoch+1:04n}/{args.max_epochs:04n}]: " + ', '.join([f'{k}={v:6.4f}' for k, v in loss_validation.items()])
print(loss_str, flush=True)
......
@@ -27,7 +27,7 @@ class SphereFactorizedPrior(SphereCompressionModel):
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
-def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
+def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', single_conv:bool=False, context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
super().__init__(entropy_bottleneck_channels=M, **kwargs)
self.attention = attention
self.arrelu = arrelu
......@@ -37,33 +37,33 @@ class SphereFactorizedPrior(SphereCompressionModel):
self.activation = activation
####################### g_a #######################
self.g_a = torch.nn.ModuleList()
-self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2))
-self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
-self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2))
# For the last layer there is no GDN anymore:
-self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
####################### g_s #######################
self.g_s = torch.nn.ModuleList()
a_pos = (0 if attention else -1)
-self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse":True},
unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
a_pos = (2 if attention else -1)
-self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse": True},
unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
-self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
# For the last layer there is no GDN anymore:
-self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
unpool_func=unpool_func, unpool_size_sqrt=2))
################## Context Model ##################
self.context_model = context_model
......
@@ -15,21 +15,21 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the encoder)
"""
-def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
+def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', single_conv:bool=False, context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
if arabs:
warnings.warn("arabs=True (input abs values to context model) is not recommended for SphereMeanScaleHyperprior, setting it to False")
arabs = False
-super().__init__(N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention, activation, context_model, arrelu, arabs, arhops, **kwargs)
+super().__init__(N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention, activation, single_conv, context_model, arrelu, arabs, arhops, **kwargs)
####################### h_s #######################
self.h_s = torch.nn.ModuleList()
-self.h_s.append(SLB_Upsample(conv_name, N, M, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_s.append(SLB_Upsample(conv_name, N, M, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
-self.h_s.append(SLB_Upsample(conv_name, M, M*3//2, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_s.append(SLB_Upsample(conv_name, M, M*3//2, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
# effective hop=1 => num_conv = 1, and there is no Upsampling
-self.h_s.append(SLB_Upsample(conv_name, M*3//2, M*2, hop=1, skip_conn_aggr=None,
+self.h_s.append(SLB_Upsample(conv_name, M*3//2, M*2, hop=1, single_conv=single_conv, skip_conn_aggr=None,
activation="ReLU", activation_args={"inplace": True}))
################## Context Model ##################
if context_model:
......
@@ -27,7 +27,7 @@ class SphereScaleHyperprior(SphereCompressionModel):
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
-def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
+def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention:bool=False, activation:str='GDN', single_conv:bool=False, context_model:str='', arrelu:bool=False, arabs:bool=False, arhops:int=2, **kwargs):
super().__init__(entropy_bottleneck_channels=N, **kwargs)
self.attention = attention
self.arrelu = arrelu
@@ -37,55 +37,55 @@ class SphereScaleHyperprior(SphereCompressionModel):
self.activation = activation
####################### g_a #######################
self.g_a = torch.nn.ModuleList()
-self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2))
-self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
-self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation,
pool_func=pool_func, pool_size_sqrt=2))
# For the last layer there is no GDN anymore:
-self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
####################### g_s #######################
self.g_s = torch.nn.ModuleList()
a_pos = (0 if attention else -1)
-self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse":True},
unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
a_pos = (2 if attention else -1)
-self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse": True},
unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
-self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation=activation, activation_args={"inverse": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
# For the last layer there is no GDN anymore:
-self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
unpool_func=unpool_func, unpool_size_sqrt=2))
####################### h_a #######################
self.h_a = torch.nn.ModuleList()
# effective hop=1 => num_conv = 1, and there is no downsampling
-self.h_a.append(SLB_Downsample(conv_name, M, N, hop=1, skip_conn_aggr=None,
+self.h_a.append(SLB_Downsample(conv_name, M, N, hop=1, single_conv=single_conv, skip_conn_aggr=None,
activation="ReLU", activation_args={"inplace": True}))
-self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True},
pool_func=pool_func, pool_size_sqrt=2))
# For the last layer there is no ReLu anymore:
-self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
pool_func=pool_func, pool_size_sqrt=2))
####################### h_s #######################
self.h_s = torch.nn.ModuleList()
-self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
-self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, skip_conn_aggr=skip_conn_aggr,
+self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True},
unpool_func=unpool_func, unpool_size_sqrt=2))
# effective hop=1 => num_conv = 1, and there is no Upsampling
-self.h_s.append(SLB_Upsample(conv_name, N, M, hop=1, skip_conn_aggr=None,
+self.h_s.append(SLB_Upsample(conv_name, N, M, hop=1, single_conv=single_conv, skip_conn_aggr=None,
activation="ReLU", activation_args={"inplace": True}))
################## Context Model ##################
self.context_model = context_model
......
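Note that single_conv is inserted mid-signature, between activation and context_model, rather than appended, so positional call sites (as in main() above) must be updated in lockstep. Below is a minimal keyword-style sketch of the same construction; the channel counts and flag values are placeholders, not taken from this diff:

import spherical_models

# Hypothetical N/M; the training script reads the real values from quality_cfgs.
net = spherical_models.SphereScaleHyperprior(
    N=192, M=320,
    conv_name="SDPAConv",
    skip_conn_aggr="sum",
    pool_func="stride",
    unpool_func="pixel_shuffle",
    attention=False,
    activation="GDN",
    single_conv=True,  # new in this commit
    context_model="",
)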
@@ -40,14 +40,20 @@ class SDPAConv (torch.nn.Module):
def forward(self, x, neighbors_indices=None, neighbors_weights=None, valid_index=None):
# in case of 1x1 convolution, the neighbors_indices and neighbors_weights are not needed
-assert (self.kernel_size==1) or ((self.kernel_size-1)==neighbors_weights.size(1)), "size does not match"
+assert (self.kernel_size==1) or ((self.kernel_size-1)<=neighbors_weights.size(1)), "size does not match"
+if self.kernel_size > 1:
+assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for 2-hop convolution"
+neighbors_indices = neighbors_indices[:, :self.kernel_size-1]
+neighbors_weights = neighbors_weights[:, :self.kernel_size-1]
+if valid_index is not None: valid_index = valid_index[:, :self.kernel_size-1]
if self.mask in ['A', 'B']: # set future neighbors to zero
+assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for masked convolution"
neighbors_weights = torch.mul(neighbors_weights, (neighbors_indices < torch.arange(x.size(1), device=neighbors_weights.device).view(-1, 1)))
if self.mask != 'A': # current node included in convolution
out = torch.matmul(x, self.weight[0])
else:
out = torch.zeros(x.size(0), x.size(1), self.out_channels, dtype=x.dtype, device=x.device)
# test_out = torch.zeros(x.size(), dtype=x.dtype)
# for k in range(neighbors_weights.size(1)):
......
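The masking logic above can be checked in isolation: for masks 'A' and 'B', a neighbor's weight survives only if its index is strictly smaller than the current node's, so the convolution sees only already-decoded nodes (mask 'A' additionally skips the center node's own weight). A minimal sketch with toy sizes, shapes inferred from the surrounding code (x is (batch, nodes, channels); the neighbor tables are (nodes, kernel_size-1)):

import torch

nodes = 4
# Each row lists one node's neighbor indices; unit weights for clarity.
neighbors_indices = torch.tensor([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
neighbors_weights = torch.ones(nodes, 3)
# Same expression as in SDPAConv.forward: zero out "future" neighbors.
causal = neighbors_indices < torch.arange(nodes).view(-1, 1)
print(torch.mul(neighbors_weights, causal))
# tensor([[0., 0., 0.],   node 0 has no decoded neighbors yet
#         [1., 0., 0.],   node 1 sees node 0
#         [1., 1., 0.],   node 2 sees nodes 0 and 1
#         [1., 1., 1.]])  node 3 sees nodes 0, 1 and 2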
@@ -29,6 +29,7 @@ class SLB_Downsample(torch.nn.Module):
pool_size_sqrt:int=1,
attention_pos:int=-1,
n_rbs:int=3,
+single_conv:bool=False,
mask:str='full',
conv1x1:bool=False):
if (conv_name != 'SDPAConv') and (mask not in ['A', 'B', 'full'] or conv1x1):
@@ -38,6 +39,7 @@ class SLB_Downsample(torch.nn.Module):
self.node_dim = 1
self.list_conv = torch.nn.ModuleList()
num_conv = hop if conv_name in ["GraphConv", "SDPAConv"] else 1
+if single_conv: num_conv = 1
if skip_conn_aggr=='cat':
out_channels //= num_conv
@@ -54,8 +56,7 @@ class SLB_Downsample(torch.nn.Module):
elif conv_name == "SDPAConv":
conv = SDPAConv
n_firstHopNeighbors = 8
-# n_neighbors = 8
-n_neighbors = util_common.sumOfAP(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=1)
+n_neighbors = util_common.sumOfAP(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=1 if not single_conv else hop)
self.list_conv.append(conv(in_channels=in_channels, out_channels=out_channels, kernel_size=n_neighbors+1 if not conv1x1 else 1, bias=bias, node_dim=self.node_dim, mask=mask))
# mask only for first convolution
self.list_conv.extend(torch.nn.ModuleList(
@@ -198,7 +199,8 @@ class SLB_Upsample(torch.nn.Module):
unpool_func:str=None,
unpool_size_sqrt:int=1,
attention_pos:int=-1,
-n_rbs:int=3,):
+n_rbs:int=3,
+single_conv:bool=False):
super().__init__()
self.node_dim = 1
@@ -219,6 +221,7 @@ class SLB_Upsample(torch.nn.Module):
# 2- Setting convolution
self.list_conv = torch.nn.ModuleList()
num_conv = hop if conv_name in ["GraphConv", "SDPAConv"] else 1
+if single_conv: num_conv = 1
if skip_conn_aggr == 'cat':
out_channels //= num_conv
@@ -235,7 +238,7 @@ class SLB_Upsample(torch.nn.Module):
elif conv_name == "SDPAConv":
conv = SDPAConv
n_firstHopNeighbors = 8
-n_neighbors = util_common.sumOfAP(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=1)
+n_neighbors = util_common.sumOfAP(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=1 if not single_conv else hop)
self.list_conv.append(conv(in_channels=in_channels, out_channels=out_channels, kernel_size=n_neighbors + 1, bias=bias, node_dim=self.node_dim))
self.list_conv.extend(torch.nn.ModuleList([conv(in_channels=out_channels, out_channels=out_channels, kernel_size=n_neighbors + 1, bias=bias, node_dim=self.node_dim) for _ in range(num_conv - 1)]))
else:
......
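The kernel-size arithmetic behind the two n_neighbors hunks can be sanity-checked numerically. util_common.sumOfAP is not shown in this diff; the helper below assumes the standard closed form n(2a + (n-1)d)/2 for the sum of an arithmetic progression. With 8 first-hop neighbors on the HEALPix grid, a hop-2 block therefore uses a single 25-point kernel when single_conv is set, instead of two chained 9-point kernels, which is also why the struct loaders are queried with num_hops=2 if single_conv else 1:

def sum_of_ap(a: int, d: int, n: int) -> int:
    # Assumed equivalent of util_common.sumOfAP: sum of the first n terms
    # of the arithmetic progression a, a+d, a+2d, ...
    return n * (2 * a + (n - 1) * d) // 2

n_firstHopNeighbors = 8  # as in SLB_Downsample / SLB_Upsample
hop = 2

n_default = sum_of_ap(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=1)   # 8
n_single = sum_of_ap(a=n_firstHopNeighbors, d=n_firstHopNeighbors, n=hop)  # 8 + 16 = 24

print("default kernel_size:", n_default + 1)     # 9, applied hop times
print("single_conv kernel_size:", n_single + 1)  # 25, applied once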