diff --git a/TODO b/TODO
index 505e490301e21e77ee260d3179f3eb97f4ebcb02..494fad71ca41cadf6984d0a73114dad93e2c5356 100644
--- a/TODO
+++ b/TODO
@@ -23,3 +23,6 @@ For the path: in simul_params, there is only simu_data_path: absolute or relative pat
 [ ] num_* rather than *_nb -> actually it depends on whether it's a quantity (num) or an identifier (nb)
 [X] add the num_replicates column to scenario_params_preprocessed, and use it in the DataLoader -> in generate_loaders actually
 [X] dataloader: index = 1 simulation, not 1 scenario.
+
+[ ] sumstats: for LD, average per scenario before saving to disk. Update the plot function accordingly.
+[ ] sumstats: check the selection sumstats
diff --git a/summary_statistics.py b/summary_statistics.py
index ea170f84f598e323b69308c029cdc1bf50bc3d5e..95d6aedcd398da672db50bb12bd1dccab7d1522e 100644
--- a/summary_statistics.py
+++ b/summary_statistics.py
@@ -150,7 +150,7 @@ def sfs(haplotype, ac, nindiv=None, folded=False):
     df_sfs["freq_indiv"] = df_sfs.N_indiv / nindiv
     return df_sfs
 
-def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
+def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None, gaps_type="short", min_SNP_pairs=300):
     """
     Compute LD for a subset of SNPs drawn with different gap sizes in between them.
     Gap sizes follow a power-of-2 distribution.
@@ -174,6 +174,16 @@ def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
         If distance_bins is an int, it defines the number of bins of distances for which to compute the LD
             The bins are created in a logspace
         If distance_bins is a list, they will be used instead
+    gaps_type: str
+        SNP pairs are sampled with a given number (gap) of columns between them; not all pairs are considered.
+        By default (`short`), gaps are powers of 2, up to the largest power of 2 smaller than the number of SNPs,
+        so most comparisons involve nearby SNPs (short distances).
+        To sample more pairs at large distances (for instance to test for circularity), use `long` instead:
+        it adds gaps of the form n_SNP - gap, at the cost of a longer run time (see the illustration below).
+    min_SNP_pairs: int
+        Minimum number of SNP pairs to consider for a given gap size.
+        If the gap is so large that fewer than `min_SNP_pairs` pairs are possible,
+        all possible pairs are considered.
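+
+        For illustration of `gaps_type` (with a hypothetical n_SNP = 1000):
+        >>> n_SNP = 1000
+        >>> short_gaps = (2 ** np.arange(0, np.log2(n_SNP))).astype(int)
+        >>> short_gaps
+        array([  1,   2,   4,   8,  16,  32,  64, 128, 256, 512])
+        >>> long_gaps = np.unique(np.concatenate([short_gaps, n_SNP - short_gaps]))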
 
     Returns
     -------
@@ -181,7 +191,7 @@ def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
         Table with the distance_bins as index, and the mean value of
     """
 
-    if distance_bins == None or isinstance(distance_bins, int):
+    if distance_bins is None or isinstance(distance_bins, int):
         if isinstance(distance_bins, int):
             n_bins = distance_bins - 1
         else:
@@ -194,43 +204,58 @@ def LD(haplotype, pos_vec, size_chr, circular=True, distance_bins=None):
             distance_bins = np.insert(distance_bins, 0, [0])
 
     n_SNP, n_samples = haplotype.shape
-    gaps = (2 ** np.arange(0, np.log2(n_SNP), 1)).astype(int) # log2 scales of intervals
-    #gaps = np.arange(1, n_SNP-1)
-    selected_snps = []
-    for gap in gaps:
-
-        snps = np.arange(0, n_SNP, gap) + np.random.randint(0, (n_SNP - 1) % gap + 1)  # adding a random start (+1, bc 2nd bound in randint is exlusive)
-
-        # non overlapping contiguous pairs
-        # snps=[ 196, 1220, 2244] becomes
-        # snp_pairs=[(196, 1220), (1221, 2245)]
-        snp_pairs = np.unique([((snps[i] + i) % n_SNP, (snps[i + 1] + i) % n_SNP) for i in range(len(snps) - 1)], axis=0)
 
-        # If we don't have enough pairs (typically when gap is large), we add a random rotation until we have at least 300)
-        #count = 0
+    # gaps are distances between SNPs in terms of index in the SNP matrix (not in bp)
+    gaps_interval = (2 ** np.arange(0, np.log2(n_SNP), 1)).astype(int)  # log2 scale of intervals
+    if gaps_type.lower() == "long":
+        gaps_interval = np.unique(np.concatenate([gaps_interval,
+                                                  (n_SNP - gaps_interval)[::-1]])).astype(int)
+    elif gaps_type.lower() != "short":
+        logging.warning(f"gaps_type should be either `short` or `long`. Using `short` instead of `{gaps_type}`")
 
-        if not circular:
-            snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
-        last_pair = snp_pairs[-1]
+    selected_snps = []
+    for gap in gaps_interval:
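+        # max_value below = number of distinct SNP pairs available for this gap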
 
         if circular:
-            max_value = n_SNP - 1
+            max_value = n_SNP
         else:
-            max_value = n_SNP - gap - 1
-
-        while len(snp_pairs) <= min(300, max_value):
-            #count += 1
-            #if count % 10 == 0:
-                #print(">>  " + str(gap) + " - " + str(len(np.unique(snp_pairs, axis=0))) + " -- "+ str(len(snps) - 1) + "#" + str(count))
-            #remainder = (n_SNP - 1) % gap if (n_SNP - 1) % gap != 0 else (n_SNP - 1) // gap
-            random_shift =  np.random.randint(1, n_SNP) % n_SNP
-            new_pair = (last_pair + random_shift) % n_SNP
-            snp_pairs = np.unique(np.concatenate([snp_pairs,
-                                                  new_pair.reshape(1, 2) ]), axis=0)
-            last_pair = new_pair
+            max_value = n_SNP - gap
+
+        if max_value < min_SNP_pairs:
+            # if only a few pairs are possible, take them all directly
+            # instead of reaching that number through many random trials
+            snps = np.arange(0, n_SNP, gap)
+            snp_pairs = np.unique([((snps[i] + i) % n_SNP, (snps[i + 1] + i) % n_SNP) for i in range(len(snps) - 1)], axis=0)
+            snp_pairs = np.concatenate([(snp_pairs + i) % n_SNP for i in range(max_value)], axis=0)
+            if not circular:
+                # drop pairs that wrap around the chromosome edge
+                snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
+        else:
+            snps = np.arange(0, n_SNP, gap) + np.random.randint(0, (n_SNP - 1) % gap + 1)  # random start (+1 because randint's upper bound is exclusive)
+            # non-overlapping contiguous pairs
+            # snps=[ 196, 1220, 2244] becomes
+            # snp_pairs=[(196, 1220), (1221, 2245)]
+            snp_pairs = np.unique([((snps[i] + i) % n_SNP, (snps[i + 1] + i) % n_SNP) for i in range(len(snps) - 1)], axis=0)
+
+            # If we don't have enough pairs (typically when gap is large),
+            # we add random rotations until we have at least min_SNP_pairs
 
             if not circular:
+                # remove pairs that are over the edges
                 snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
+            last_pair = snp_pairs[-1]
+
+            while len(snp_pairs) < min(min_SNP_pairs, max_value):
+                shift = np.random.randint(1, n_SNP)
+                new_pair = (last_pair + shift) % n_SNP
+                snp_pairs = np.unique(np.concatenate([snp_pairs,
+                                                      new_pair.reshape(1, 2)]), axis=0)
+                last_pair = new_pair
+
+                if not circular:
+                    snp_pairs = snp_pairs[snp_pairs[:, 0] < snp_pairs[:, 1]]
 
         selected_snps.append(snp_pairs)
 
@@ -323,16 +348,56 @@ def nsl(haplotype, pos_vec=None, window=None):
 def worker_do_sum_stats(param):
     do_sum_stats(**param)
 
-def do_sum_stats(scenario_dir, name_id, size_chr=2e6, circular=True, label="", nrep="all"):
+def do_sum_stats(scenario_dir, name_id, size_chr=2e6,
+                 ld_kws=None, sfs_kws=None,
+                 label="", nrep="all", overwrite=False):
+    """Compute sfs and LD for a set of replicates.
+
+    Parameters
+    ----------
+    scenario_dir : str
+        Path to the directory containing the replicate outputs (npz files) for a given scenario.
+    name_id : str
+        Identifier for the scenario.
+    size_chr : int
+        Size of the chromosome.
+    ld_kws : dict
+        Keyword arguments passed to the LD function.
+        Available kws are: circular [True], distance_bins, gaps_type ["short"], min_SNP_pairs [300].
+    sfs_kws : dict
+        Keyword arguments passed to the sfs function.
+        Available kws are: nindiv, folded [False].
+    label : str
+        Label for the scenario.
+    nrep : int or "all"
+        Number of replicates to use: "all" (default) uses every replicate;
+        an int restricts to a subset (for testing purposes).
+    overwrite : bool
+        If False (default), the output is appended to the end of the existing file.
+        Otherwise, the existing file is overwritten.
+
+    Returns
+    -------
+    None
+        Nothing is returned. Files are written to disk:
+        name_id.sfs
+        name_id.ld
+        name_id.sel (currently off; contains data for Tajima's D, ihs and nsl)
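+
+    Examples
+    --------
+    Illustrative call (the directory and scenario name are hypothetical):
+
+    >>> do_sum_stats("simu/scenario_0", "scenario_0",
+    ...              ld_kws={"circular": False, "gaps_type": "long"},
+    ...              sfs_kws={"folded": True})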
+    """
     """
     For a file in npz format (from a ms file format), compute different summary statistics
     and output them in different output files:
-    name_id.mut1
     name_id.sfs
     name_id.ld
-    name_id.sel
-    The latter one contains data for Tajima's D, ihs and nsl
+    name_id.sel (off) The latter one contains data for Tajima's D, ihs and nsl
     """
+
+    if ld_kws is None:
+        ld_kws = {}
+    ld_kws["size_chr"] = size_chr  # size_chr is passed through ld_kws to LD
+    if sfs_kws is None:
+        sfs_kws = {}
+
     all_sfs = pd.DataFrame()
     all_ld = pd.DataFrame()
     npzfiles = [i for i in os.listdir(scenario_dir) if i.endswith("npz")]
@@ -353,13 +418,13 @@ def do_sum_stats(scenario_dir, name_id, size_chr=2e6, circular=True, label="", n
             model, scenario, replicat, run_id = split_simid(sim_id)
 
             try:
-                df_sfs = sfs(haplotype, allel_count)
+                df_sfs = sfs(haplotype, allel_count, **sfs_kws)
                 all_sfs = pd.concat([all_sfs, df_sfs])
                 # df_sfs.to_csv(os.path.join(name_id + ".sfs"), sep="\t", index=False, mode="a", header=False)
             except Exception as e:
                 logging.error("While computing SFS for {}\n>>> Error: {}".format(sim_id, e))
             try:
-                ld = LD(haplotype, pos_vec, size_chr=size_chr, circular=circular)
+                ld = LD(haplotype, pos_vec, **ld_kws)
                 ld["sim_id"] = sim_id
                 ld["scenario"] = scenario
                 ld["run_id"] = run_id
@@ -382,7 +447,7 @@ def do_sum_stats(scenario_dir, name_id, size_chr=2e6, circular=True, label="", n
         all_sfs2.to_csv(sfsfile,
                         sep="\t",
                         index=False,
-                        header=False if (sfsfile.tell()==0 and not overwrite) else True)
+                        header=(sfsfile.tell() == 0 or overwrite))
 
     all_ld2 = all_ld.groupby("dist_group").mean()
     all_ld2["sim_id"] = sim_id
@@ -393,7 +458,7 @@ def do_sum_stats(scenario_dir, name_id, size_chr=2e6, circular=True, label="", n
         all_ld2.to_csv(ldfile,
                        sep="\t",
                        index=False,
-                       header=False if (ldfile.tell()==0 and not overwrite) else True)
+                       header=(ldfile.tell() == 0 or overwrite))
 
 
     #