Commit 4d804d2e authored by ZHANG Haoying

add auxiliary knowledge into csp

parent c1dfe8cf
# GLOBAL IMPORTS
import os, sys, glob, json, argparse, time
from typing import Tuple
import numpy as np
# from pynput.keyboard import GlobalHotKeys
# LOCAL IMPORTS
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
) # Add root directory to path
from src.rmpi.constants import (
default_m,
n_min,
ermpi_log_dir,
ermpi_log_name,
ermpi_log_suffix,
pearson_threshold,
m_min,
remarkable_rmpi_dir,
ermpi_results_dir,
results_prefix,
)
from src.toolbox import (
save_res,
pearson_correlation,
find_two_closest_elements,
nb_znormalized_euclidian_distance,
nb_euclidian_distance,
nb_manhattan_distance,
)
from src.rmpi.rmpi_ts_partial import solve_rmpi, DEFAULT_TIME_LIMIT
from src.rmpi.ermpi_data_exploiter import process_ermpi_results, store_rmpi
from src.impi.usecase.real_ts_exploiter import get_energy_ts
from src.impi.usecase.find_valid_dates import load_dates
from joblib import Parallel, delayed
# ----------------------------------------------------------------------------------------------------
"""
"""
# ----------------------------------------------------------------------------------------------------
INTERRUPTED = False # Global variable to handle keyboard interruption
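# Maximum number of bins kept in the loss distribution; beyond this, the two closest keys are merged (see the synthesis step at the end of compute_ermpi)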
MAX_DISTRIB_SIZE = 20
# ----------------------------------------------------------------------------------------------------
# def interrupt_hotkey():
# global INTERRUPTED
# print("\nInterrupting...", flush=True)
# INTERRUPTED = True
# hotkey.stop() # Stop the keyboard listener
#
#
def info_hotkey():
if os.path.exists(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}"):
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "r") as f:
print(f.read(), flush=True)
# ----------------------------------------------------------------------------------------------------
def theoretical_ts_generator(n, n_ts, random_generator):
""" """
metadata = dict()
for i_ts in range(n_ts):
# metadata["enum_index"] = i_ts
time_series = random_generator.random(n)
yield metadata, time_series
def theoretical_initializer(n: int, m: int, average: int, random_seed: int):
rng = np.random.default_rng(random_seed)
n_ts = average
ts_generator = theoretical_ts_generator(n, n_ts, rng)
# if remarkable_pc[0] == -1.0:
# remarkable_pc = np.array([0.1, 0.9], dtype=np.float64)
# else:
# remarkable_pc = np.asarray(remarkable_pc, dtype=np.float64)
return ts_generator, m, n_ts, random_seed
def theoretical_metadata_logger(f, metadata):
# f.write(f"Enum index: {metadata['enum_index']}, ")
return
# ----------------------------------------------------------------------------------------------------
def energy_ts_generator(n, n_ts, dates, ts_dict):
metadata = dict()
i = 0
while i < n_ts:
date = dates.pop()
print(date)
time_series = ts_dict[date]
if len(time_series) >= n:
metadata["date"] = date
i += 1
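# Keep a window of length n starting at sample 600 (the series is assumed to have at least 600 + n samples), then min-max normalize to [0, 1]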
time_series = time_series[600 : 600 + n]
time_series = (time_series - time_series.min()) / (
time_series.max() - time_series.min()
)
yield metadata, time_series
def energy_initializer(n: int, m: int, average: int, random_seed: int):
rng = np.random.default_rng(random_seed)
n_ts = average
dates = load_dates()
ts_dict = get_energy_ts(dates, verbose=False)
# Check that there are enough time series of length at least n
if len([ts for ts in ts_dict.values() if len(ts) >= n]) < n_ts:
raise ValueError(
f"Not enough time series of length at least {n} in the dataset. Please choose a smaller value for n or average."
)
rng.shuffle(dates)
ts_generator = energy_ts_generator(n, n_ts, dates, ts_dict)
# if remarkable_pc[0] == -1.0:
# remarkable_pc = np.array([0.1, 0.9], dtype=np.float64)
# else:
# remarkable_pc = np.asarray(remarkable_pc, dtype=np.float64)
return ts_generator, m, n_ts, random_seed
def energy_metadata_logger(f, metadata):
f.write(f"Date: {metadata['date']}, ")
return
# ----------------------------------------------------------------------------------------------------
ECG_TS_LEN = 5000
ACM_TS_LEN = 80
def ecg_ts_generator(n, n_ts, ecg_path_list):
metadata = dict()
for i_ts in range(n_ts):
ecg_path = ecg_path_list.pop()
ecg = np.load(ecg_path)
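# The file name ends with the global ECG index; presumably 10 recordings per person, given the // 10 and % 10 split below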
i_ecg = int(ecg_path.split("_")[-1].split(".")[0])
metadata["i_ecg"] = i_ecg
metadata["i_person"] = i_ecg // 10
metadata["i_ecg_in_person"] = i_ecg % 10
yield metadata, ecg[:n]
def acm_ts_generator(n, n_ts, acm_path_list):
metadata = dict()
for i_ts in range(n_ts):
acm_path = acm_path_list.pop()
acm = np.load(acm_path)
acm = (acm - acm.min()) / (acm.max() - acm.min())
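# The file name ends with the global ACM index; presumably 20 recordings per person, given the // 20 and % 20 split below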
i_acm = int(acm_path.split("_")[-1].split(".")[0])
metadata["i_acm"] = i_acm
metadata["i_person"] = i_acm // 20
metadata["i_acm_in_person"] = i_acm % 20
yield metadata, acm[:n]
def ecg_initializer(n: int, m: int, average: int, random_seed: int):
rng = np.random.default_rng(random_seed)
n_ts = average
ecg_path_list = glob.glob("data/ecg_test/*.npy")
if n > ECG_TS_LEN:
raise ValueError(
f"ECG time series have a length of {ECG_TS_LEN}. Please choose a smaller value for n."
)
if len(ecg_path_list) < n_ts:
raise ValueError(
f"Not enough ECG time series in the dataset. Please choose a smaller value for average."
)
rng.shuffle(ecg_path_list)
ts_generator = ecg_ts_generator(n, n_ts, ecg_path_list)
return ts_generator, m, n_ts, random_seed
def acm_initializer(n: int, m: int, average: int, random_seed: int):
rng = np.random.default_rng(random_seed)
n_ts = average
acm_path_list = glob.glob("data/acm/*.npy")
if n > ACM_TS_LEN:
raise ValueError(
f"ACM time series have a length of {ACM_TS_LEN}. Please choose a smaller value for n."
)
if len(acm_path_list) < n_ts:
raise ValueError(
f"Not enough ACM time series in the dataset. Please choose a smaller value for average."
)
rng.shuffle(acm_path_list)
ts_generator = acm_ts_generator(n, n_ts, acm_path_list)
return ts_generator, m, n_ts, random_seed
def ecg_metadata_logger(f, metadata):
f.write(
f"Person index: {metadata['i_person']}, ECG index in person: {metadata['i_ecg_in_person']}, "
)
return
def acm_metadata_logger(f, metadata):
f.write(
f"Person index: {metadata['i_person']}, ACM index in person: {metadata['i_acm_in_person']}, "
)
return
# ----------------------------------------------------------------------------------------------------
def compute_ermpi(
initializer: callable,
metadata_logger: callable,
distance_fun: callable,
time_limit: float,
genetic: bool,
save_folder: str = None,
verbose: bool = False,
auxiliary_indices=[],
):
global INTERRUPTED
start_time = time.perf_counter()
# Create the log file
os.makedirs(ermpi_log_dir, exist_ok=True)
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "w") as f:
if not verbose:
f.write("no verbose mode")
else:
f.write(
"Index: 0/0 (0.0%), Loss: 0.0, Time to compute: 0.00 sec, Time left: 0h 00min 00sec"
)
# Display the computing message
if verbose:
print(
f"Computing... (Press Enter to show info. Press Ctrl+Alt+I to safely interrupt)",
flush=True,
)
# Specific initialization
ts_generator, m, n_ts, random_seed = initializer()
# This dictionary stores, for each loss value, the number of RMPIs leading to this loss
loss_distribution = dict()
# This dictionary stores, for each pearson correlation, the number of RMPIs leading to this correlation
pc_distribution = dict()
# Average time to compute the RMPI for a given time series
average_ttc = 0
# Main loop
for i_ts in range(n_ts):
# iteration start time
i_start_time = time.perf_counter()
metadata, time_series = next(ts_generator)
rmpi_instance, elapsed_time = solve_rmpi(
time_series,
m=m,
distance_fun=distance_fun,
time_limit=time_limit,
genetic=genetic,
seed=random_seed,
verbose=False,
timer=False,
auxiliary_indices=auxiliary_indices,
)
# Get statistics on the RMPI
loss = rmpi_instance.objective_values[0]
pearson_corr = round(
pearson_correlation(time_series, rmpi_instance.solutions[0]), 1
)
if loss not in loss_distribution:
loss_distribution[loss] = 0
loss_distribution[loss] += 1
if pearson_corr not in pc_distribution:
pc_distribution[pearson_corr] = 0
pc_distribution[pearson_corr] += 1
if save_folder is not None:
id_key = ""
if "i_ecg" in metadata:
id_key = "i_ecg"
elif "i_acm" in metadata:
id_key = "i_acm"
else:
id_key = "date"
store_rmpi(
time_series,
rmpi_instance,
elapsed_time,
loss,
pearson_corr,
os.path.join(save_folder, remarkable_rmpi_dir),
metadata[id_key],
)
# Clear the memory
del time_series
del rmpi_instance
if verbose:
ttc = time.perf_counter() - i_start_time
average_ttc = (average_ttc * i_ts + ttc) / (i_ts + 1)
time_left = (n_ts - i_ts - 1) * average_ttc
text_time_left = f"{int(time_left // 3600)}h {int((time_left % 3600) // 60):02}min {int(time_left % 60):02}sec"
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "wb") as f:
f.write(b"\r")
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "w") as f:
metadata_logger(f, metadata)
f.write(
f"Index: {i_ts + 1}/{n_ts} ({round(100 * (i_ts + 1) / n_ts, 1)}%), Loss: {loss:.2f}, Pearson correlation: {pearson_corr}, Time to compute: {ttc:.2f} sec, "
f"Time left: {text_time_left}"
)
# Check for interruption
if INTERRUPTED:
break
# Auto-display the log at each iteration; remove this call once the hotkey listeners are reactivated
info_hotkey()
# Synthesize the loss distribution
if verbose:
print("Synthesizing the loss distribution...", flush=True)
while len(loss_distribution) > MAX_DISTRIB_SIZE:
key1, key2 = find_two_closest_elements(
np.asarray(list(loss_distribution.keys()), dtype=np.float64)
)
avg_key = (key1 * loss_distribution[key1] + key2 * loss_distribution[key2]) / (
loss_distribution[key1] + loss_distribution[key2]
)
loss_distribution[avg_key] = loss_distribution.pop(
key1
) + loss_distribution.pop(key2)
# Sort both distributions by key
loss_distribution = dict(sorted(loss_distribution.items()))
pc_distribution = dict(sorted(pc_distribution.items()))
return (
loss_distribution,
pc_distribution,
time.perf_counter() - start_time,
random_seed,
)
def compute_ermpi_parallel(
initializer: callable,
metadata_logger: callable,
distance_fun: callable,
time_limit: float,
genetic: bool,
save_folder: str = None,
verbose: bool = False,
nb_jobs=os.cpu_count(),
nb_solutions_per_ts=1,
auxiliary_indices=[],
):
global INTERRUPTED
start_time = time.perf_counter()
# Create the log file
os.makedirs(ermpi_log_dir, exist_ok=True)
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "w") as f:
if not verbose:
f.write("no verbose mode")
else:
f.write(
"Index: 0/0 (0.0%), Loss: 0.0, Time to compute: 0.00 sec, Time left: 0h 00min 00sec"
)
# Display the computing message
if verbose:
print(
f"Computing... (Press Enter to show info. Press Ctrl+Alt+I to safely interrupt)",
flush=True,
)
# Specific initialization
ts_generator, m, n_ts, random_seed = initializer()
# This dictionary stores, for each loss value, the number of RMPIs leading to this loss
loss_distribution = dict()
# This dictionary stores, for each pearson correlation, the number of RMPIs leading to this correlation
pc_distribution = dict()
# Average time to compute the RMPI for a given time series
average_ttc = 0
# Iteration counter
itr = 0
# Divide the time series indices into groups of nb_jobs
nb_jobs = os.cpu_count() if nb_jobs is None else nb_jobs  # a value of 0 would break the grouping below
i_ts_list = [i for i in range(n_ts)]
i_ts_list_grouped = [i_ts_list[i : i + nb_jobs] for i in range(0, n_ts, nb_jobs)]
print(i_ts_list_grouped)
# Main loop for parallel execution
for i_ts_l in i_ts_list_grouped:
# iteration start time
i_start_time = time.perf_counter()
time_series_list = []
for _ in range(len(i_ts_l)):
elem_1, elem_2 = next(ts_generator)
print(elem_1.copy())
time_series_list.append((elem_1.copy(), elem_2.copy()))
# Solve one RMPI per series in the group (one joblib worker each); res holds one dict per series with keys 'metadata', 'loss', 'pearson_correlation'
res = Parallel(n_jobs=len(i_ts_l), backend="loky")(
delayed(solve_and_save_rmpi_thread)(
time_series_list[i][0],
time_series_list[i][1],
m,
distance_fun,
time_limit,
genetic,
random_seed,
save_folder,
remarkable_rmpi_dir,
nb_solutions_per_ts,
auxiliary_indices,
)
for i in range(len(i_ts_l))
)
for j in range(len(i_ts_l)):
loss = res[j]["loss"]
pearson_corr = res[j]["pearson_correlation"]
if loss not in loss_distribution:
loss_distribution[loss] = 0
loss_distribution[loss] += 1
if pearson_corr not in pc_distribution:
pc_distribution[pearson_corr] = 0
pc_distribution[pearson_corr] += 1
if verbose:
ttc = time.perf_counter() - i_start_time
average_ttc = (average_ttc * itr + ttc) / (itr + 1)
time_left = (len(i_ts_list_grouped) - itr - 1) * average_ttc
text_time_left = f"{int(time_left // 3600)}h {int((time_left % 3600) // 60):02}min {int(time_left % 60):02}sec"
with open(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "wb") as f:
f.write(b"\r")
for i_ts in range(len(res)):
metadata = res[i_ts]["metadata"]
with open(
f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}", "a"
) as f:
metadata_logger(f, metadata)
loss = res[i_ts]["loss"]
pc = res[i_ts]["pearson_correlation"]
f.write(
f"Index: {itr*len(i_ts_l) + i_ts+1}/{n_ts} ({round(100 * (i_ts + 1) / n_ts, 1)}%), Loss: {loss:.2f}, Pearson correlation: {pc}, Time to compute: {ttc:.2f} sec, "
f"Time left: {text_time_left}\n"
)
# Clear the memory
del time_series_list, res
# Check for interruption
if INTERRUPTED:
break
# Auto-display the log at each iteration; remove this call once the hotkey listeners are reactivated
info_hotkey()
itr += 1
# Synthesize the loss distribution
if verbose:
print("Synthesizing the loss distribution...", flush=True)
while len(loss_distribution) > MAX_DISTRIB_SIZE:
key1, key2 = find_two_closest_elements(
np.asarray(list(loss_distribution.keys()), dtype=np.float64)
)
avg_key = (key1 * loss_distribution[key1] + key2 * loss_distribution[key2]) / (
loss_distribution[key1] + loss_distribution[key2]
)
loss_distribution[avg_key] = loss_distribution.pop(
key1
) + loss_distribution.pop(key2)
# Sort both distributions by key
loss_distribution = dict(sorted(loss_distribution.items()))
pc_distribution = dict(sorted(pc_distribution.items()))
return (
loss_distribution,
pc_distribution,
time.perf_counter() - start_time,
random_seed,
)
def solve_and_save_rmpi_thread(
metadata,
time_series,
m,
distance_fun,
time_limit,
genetic,
random_seed,
save_folder,
remarkable_rmpi_dir,
n_pop,
auxiliary,
):
rmpi_instance, elapsed_time = solve_rmpi(
time_series,
m=m,
distance_fun=distance_fun,
time_limit=time_limit,
genetic=genetic,
seed=random_seed,
verbose=False,
timer=False,
single_solution=True,
nb_pop=n_pop,
auxiliary_indices=auxiliary,
)
# Get statistics on the RMPI
loss = rmpi_instance.objective_values[0]
pearson_corr = round(
pearson_correlation(time_series, rmpi_instance.solutions[0]), 1
)
if save_folder is not None:
id_key = ""
if "i_ecg" in metadata:
id_key = "i_ecg"
elif "i_acm" in metadata:
id_key = "i_acm"
else:
id_key = "date"
store_rmpi(
time_series,
rmpi_instance,
elapsed_time,
loss,
pearson_corr,
os.path.join(save_folder, remarkable_rmpi_dir),
metadata[id_key],
)
return {"metadata": metadata, "loss": loss, "pearson_correlation": pearson_corr}
# ----------------------------------------------------------------------------------------------------
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-n", type=int, required=True, help="Length of the considered time series"
)
parser.add_argument(
"-m", type=int, required=True, help="Subsequence length for the Matrix Profile"
)
parser.add_argument(
"-l",
"--time_limit",
type=float,
default=DEFAULT_TIME_LIMIT,
help=f"Time limit for the optimization problem (default: {DEFAULT_TIME_LIMIT} sec)",
)
parser.add_argument(
"-g",
"--genetic",
action="store_true",
help="Solve the problem on a population, which size is equal to the number of cores (default: False)",
)
parser.add_argument(
"-z",
"--znormalized",
action="store_true",
help="Use the z-normalized Euclidian distance instead of the Euclidian distance (default: False)",
)
parser.add_argument(
"-man", "--manhattan", action="store_true", help="Use the Manhattan distance"
)
parser.add_argument(
"-r",
"--random",
type=int,
default=None,
help="Random seed for the random number generator (default: None)",
)
parser.add_argument(
"-a",
"--average",
type=int,
default=100,
help="Number of time series to average the results (default: 100)",
)
parser.add_argument(
"-p",
"--percentage",
type=int,
default=95,
help="Percentage for data exploitation (default: 95)",
)
parser.add_argument(
"-pop",
"--nb_pop",
type=int,
default=1,
help="Number of solutions per time series",
)
parser.add_argument(
"-parallel",
"--nb_parallel",
type=int,
default=12,
help="Number of parallel tasks of solving rmpi",
)
parser.add_argument(
"-aux",
"--auxiliary",
type=int,
default=0,
help="Number of points in the auxiliary knowledge",
)
parser.add_argument(
"-c",
"--category",
type=str,
default="theoretical",
help="Category of the time series: 'theoretical', 'energy', or 'ecg' (default: 'theoretical')",
)
parser.add_argument(
"-s",
"--save",
action="store_true",
help="Save the results to a file, including th plot if --plot is specified (default: False)",
)
parser.add_argument(
"-i",
"--important",
action="store_true",
help="Save the results in the 'important' directory (default: False)",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Display information during the process (default: False)",
)
args = parser.parse_args()
if args.n < n_min:
parser.error(f"n must be at least {n_min}")
if args.m < m_min:
parser.error(f"m must be at least {m_min}")
if args.n - args.m < 0:
parser.error("n must be greater or equal to m")
if args.time_limit <= 0:
parser.error("time_limit must be greater than 0")
if args.category not in ["theoretical", "energy", "ecg", "acm"]:
parser.error("category must be 'theoretical', 'energy', 'acm' or 'ecg'")
if args.important and not args.save:
parser.error("--important is useless if --save is not specified")
if args.average <= 0:
parser.error("average must be greater than 0")
if args.nb_pop * args.nb_parallel > os.cpu_count():
parser.error(f"The performance will be reduced because of not enough cpus")
##############################################################
# # Keyboard listener to handle interruption
# hotkey = GlobalHotKeys({
# '<ctrl>+<alt>+i': interrupt_hotkey,
# '<enter>': info_hotkey
# })
# hotkey.start()
if args.manhattan:
distance_fun = nb_manhattan_distance
else:
distance_fun = (
nb_znormalized_euclidian_distance
if args.znormalized
else nb_euclidian_distance
)
# Precompile the Numba functions by running them once on a small time series
if args.verbose:
print("Precompiling Numba functions...", flush=True)
solve_rmpi(
np.random.rand(10),
m=3,
distance_fun=distance_fun,
time_limit=1,
genetic=args.genetic,
seed=args.random,
verbose=False,
timer=False,
)
if args.save:
save_folder = os.path.join("results", ermpi_results_dir, args.category)
if args.important:
save_folder = os.path.join(save_folder, "important")
current_time = time.strftime("%Y%m%d_%H%M%S")
save_folder = os.path.join(
save_folder,
f"ermpi_n_{args.n}_m_{args.m}_a_{args.average}_t_{args.time_limit}__{current_time}",
)
os.makedirs(os.path.join(save_folder, remarkable_rmpi_dir), exist_ok=True)
else:
save_folder = None
if args.category == "theoretical":
(
main_loss_distribution,
main_pc_distribution,
main_elapsed_time,
main_random_seed,
) = compute_ermpi_parallel(
initializer=lambda: theoretical_initializer(
args.n, args.m, args.average, args.random
),
metadata_logger=theoretical_metadata_logger,
distance_fun=distance_fun,
time_limit=args.time_limit,
genetic=args.genetic,
save_folder=save_folder,
verbose=args.verbose,
nb_jobs=args.nb_parallel,
nb_solutions_per_ts=args.nb_pop,
auxiliary_indices=[i for i in range(args.auxiliary)],
)
elif args.category == "energy":
(
main_loss_distribution,
main_pc_distribution,
main_elapsed_time,
main_random_seed,
) = compute_ermpi_parallel(
initializer=lambda: energy_initializer(
args.n, args.m, args.average, args.random
),
metadata_logger=energy_metadata_logger,
distance_fun=distance_fun,
time_limit=args.time_limit,
genetic=args.genetic,
save_folder=save_folder,
verbose=args.verbose,
nb_jobs=args.nb_parallel,
nb_solutions_per_ts=args.nb_pop,
auxiliary_indices=[i for i in range(args.auxiliary)],
)
elif args.category == "ecg":
## If Si Ying is reading this, it means she is procrastinating on her paper (:D)
(
main_loss_distribution,
main_pc_distribution,
main_elapsed_time,
main_random_seed,
) = compute_ermpi_parallel(
initializer=lambda: ecg_initializer(
args.n, args.m, args.average, args.random
),
metadata_logger=ecg_metadata_logger,
distance_fun=distance_fun,
time_limit=args.time_limit,
genetic=args.genetic,
save_folder=save_folder,
verbose=args.verbose,
nb_jobs=args.nb_parallel,
nb_solutions_per_ts=args.nb_pop,
auxiliary_indices=[i for i in range(args.auxiliary)],
)
elif args.category == "acm":
(
main_loss_distribution,
main_pc_distribution,
main_elapsed_time,
main_random_seed,
) = compute_ermpi_parallel(
initializer=lambda: acm_initializer(
args.n, args.m, args.average, args.random
),
metadata_logger=acm_metadata_logger,
distance_fun=distance_fun,
time_limit=args.time_limit,
genetic=args.genetic,
save_folder=save_folder,
verbose=args.verbose,
nb_jobs=args.nb_parallel,
nb_solutions_per_ts=args.nb_pop,
auxiliary_indices=[i for i in range(args.auxiliary)],
)
if args.verbose:
print("Loss distribution:", main_loss_distribution)
print("Pearson correlation distribution:", main_pc_distribution)
# Save the results to a file
if args.save:
file_name = "results"
data = {
"args": vars(args),
"random_seed": main_random_seed,
"n_ts": sum(main_loss_distribution.values()),
"interrupted": INTERRUPTED,
"n_processes": os.cpu_count(),
"elapsed_time": int(main_elapsed_time),
"category": args.category,
"distance_fun": distance_fun.__name__,
"loss_distribution": main_loss_distribution,
"pc_distribution": main_pc_distribution,
}
file_path = os.path.join(save_folder, f"{results_prefix}.json")
with open(file_path, "w") as f:
data["date"] = current_time
json.dump(data, f, indent=4)
# Free the memory
del data
del main_loss_distribution
del main_pc_distribution
if args.verbose:
print(f"Results saved to {file_path}", flush=True)
process_ermpi_results(
folder=save_folder,
prudent=False,
threshold=args.percentage,
verbose=args.verbose,
)
# Clear the logs
if args.verbose:
os.remove(f"{ermpi_log_dir}/{ermpi_log_name}{ermpi_log_suffix}")
print("Done.", flush=True)
# ----------------------------------------------------------------------------------------------------
...
@@ -28,12 +28,17 @@ def precompute_distances(x, m, n_mp, distance_function):
     return distance_cache
 
-def objective_function(xdict, mp, m, distance_function, coeff_dist, coeff_identity):
+def objective_function(xdict, mp, m, distance_function, coeff_dist, coeff_identity, auxiliary=[]):
     x = xdict['x']
     n_mp = len(mp)
     distance_cache = precompute_distances(x, m, n_mp, distance_function)
+    auxiliary_loss = 0
+    if len(auxiliary)>0:
+        for index,value in auxiliary:
+            auxiliary_loss += (x[index]-value) ** 2
+
     # Compute distance_loss using vectorized operations
     distance_indices = mp[:, 1].astype(int)
     distance_losses = mp[:, 0] - distance_cache[np.arange(n_mp), distance_indices]
@@ -45,7 +50,7 @@ def objective_function(xdict, mp, m, distance_function, coeff_dist, coeff_identity):
     identity_loss = np.sum(identity_diff)
     # print(f"distance_loss: {distance_loss}, identity_loss: {identity_loss}")
     fail = False  # Set this to True if there's an issue in computation
-    return {'f': coeff_dist*distance_loss + coeff_identity*identity_loss}, fail
+    return {'f': coeff_dist*distance_loss + coeff_identity*identity_loss + 100 * auxiliary_loss}, fail
 
 def objective_function_bis(xdict, mp, m, distance_function):
     x = xdict['x']
...
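A minimal, hypothetical sketch of how the new auxiliary penalty behaves: the series x, the (index, value) pairs, and the printed number are made-up illustration values; only the squared-gap formula and the weight of 100 come from the diff above.

import numpy as np

# Hypothetical candidate time series and auxiliary knowledge given as (index, value) pairs
x = np.array([0.2, 0.5, 0.9, 0.4])
auxiliary = [(0, 0.25), (2, 1.0)]

# Squared gap between the candidate and each known point, as in the modified objective_function
auxiliary_loss = sum((x[index] - value) ** 2 for index, value in auxiliary)

# The objective adds this term with a fixed weight of 100
penalty = 100 * auxiliary_loss
print(penalty)  # about 1.25 for these made-up values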