# __main__.py
"""Benchmark evaluations from compressai.utils.bench with S-PSNR and WS-PSNR evaluation."""
import argparse
import json
import multiprocessing as mp
import sys
from collections import defaultdict
from itertools import starmap
from math import nan
from pathlib import Path
from typing import List, Optional

from .codecs import AV1, BPG, HM, JPEG, JPEG2000, TFCI, VTM, Codec, WebP
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]
# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
rv = codec.run(*args)
return i, rv
def collect(
codec: Codec,
dataset: str,
qps: List[int],
metrics: List[str],
num_jobs: int = 1,
out_dir: str = None,
):
filepaths = []
if Path(dataset).is_dir(): # take all images in the directory
for ext in IMG_EXTENSIONS:
filepaths.extend(Path(dataset).rglob(f"*{ext}"))
elif Path(dataset).is_file():
with open(dataset, "r") as f: # take images given in file
lines = f.read().splitlines()
root_dir = lines[1]
root_dir = Path('').joinpath(*root_dir.split('/'))
img_names = lines[2:]
for img_name in img_names:
if img_name.endswith(IMG_EXTENSIONS):
filepaths.append(root_dir.joinpath(img_name))
else: # neither file nor directory
raise OSError(f"No such directory: {dataset}")
pool = mp.Pool(num_jobs) if num_jobs > 1 else None
if len(filepaths) == 0:
print("No images found in the dataset directory")
sys.exit(1)
args = [
(codec, i, f, q, metrics) for i, q in enumerate(qps) for f in sorted(filepaths)
]
if pool:
rv = pool.starmap(func, args)
else:
rv = list(starmap(func, args))
results = [defaultdict(float) for _ in range(len(qps))]
if out_dir:
out_dir_codec = Path(out_dir).joinpath(codec.name)
if not out_dir_codec.is_dir():
out_dir_codec.mkdir(parents=True, exist_ok=True)
# list of shape (len(qp), len(images))
rv_resort = [[] for _ in range(len(qps))]
for i, metrics in rv:
rv_resort[i].append(metrics)
psnr_list = [
"s-psnr-rgb",
"ws-psnr-rgb",
"psnr-rgb",
]
nameWidth, metricWidth, rateWidth = len(filepaths[0].stem), 22, 30
for i, (q, img_metrics) in enumerate(zip(qps, rv_resort)):
fp = out_dir_codec.joinpath(f"rate_metrics_rec_q_{q:02d}.txt")
if not fp.is_file(): # create file with header
with open(fp, "w") as f:
fmt_metric = '{:^' + str(metricWidth) + 's}'
fmt_rate = '{:<' + str(rateWidth) + 's}'
header = "#" + "name".center(nameWidth - 1) + \
(fmt_metric * 4).format("SPSNR Color", "WSPSNR Color", "PSNR Color", "MSSSIM Color") +\
(fmt_rate).format("Rate")
f.write(header+"\n")
header_units = "#" + " " * (nameWidth - 1) + \
(fmt_metric * 4).format("(dB)", "(dB)", "(dB)", "") + \
(fmt_rate).format("(bpp)")
f.write(header_units+"\n")
fmt_psnr = '{:^' + str(metricWidth) + '.2f}'
fmt_ssim = '{:^' + str(metricWidth) + '.4f}'
fmt_rate = '{:<' + str(rateWidth) + '}'
# check existing files
imgs_tested = []
with open(fp, "r") as f:
for line in f:
if line.startswith("#"): continue
img_name = line.split()[0] # caution: does not check if all metrics were calculated
imgs_tested.append(img_name)
with open(fp, "a") as f:
for img_fp, metrics in zip(filepaths, img_metrics):
img_name = img_fp.stem
if img_name in imgs_tested: continue
f.write(
img_name.center(nameWidth) + (fmt_psnr * len(psnr_list)).format(*[metrics.get(m, nan) for m in psnr_list]) + \
fmt_ssim.format(metrics.get("ms-ssim-rgb", nan)) + fmt_rate.format(metrics["bpp"]) + "\n"
)
# aggregate results for all images
for i, metrics in rv:
for k, v in metrics.items():
if isinstance(v, float):
results[i][k] += v
for i, _ in enumerate(results):
for k, v in results[i].items():
results[i][k] = v / len(filepaths)
# list of dict -> dict of list
out = defaultdict(list)
for r in results:
for k, v in r.items():
out[k].append(v)
return out
def setup_args():
description = "Collect codec metrics."
parser = argparse.ArgumentParser(description=description)
subparsers = parser.add_subparsers(dest="codec", help="Select codec")
subparsers.required = True
return parser, subparsers
def setup_common_args(parser):
parser.add_argument(
"-d",
"--dataset",
type=str,
help="Path to the dataset directory or text file containing the directory path and image file names",
)
parser.add_argument(
"-o",
"--out-dir",
dest="out_dir",
type=str,
help="Path to directory where to store rates and metrics for each image (default: %(default)s)",
)
parser.add_argument(
"-j",
"--num-jobs",
type=int,
metavar="N",
default=1,
help="number of parallel jobs (default: %(default)s)",
)
parser.add_argument(
"-q",
"--qps",
dest="qps",
type=str,
default="75",
help="list of quality/quantization parameter (default: %(default)s)",
)
parser.add_argument(
"--metrics",
dest="metrics",
default=["psnr-rgb", "ms-ssim-rgb", "s-psnr-rgb", "ws-psnr-rgb"],
nargs="+",
help="choose metrics from [psnr-rgb, ms-ssim-rgb, s-psnr-rgb, ws-psnr-rgb] (use for very small images)",
)
def main(argv):
parser, subparsers = setup_args()
for c in codecs:
cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
setup_common_args(cparser)
c.setup_args(cparser)
args = parser.parse_args(argv)
codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
codec = codec_cls(args)
qps = [int(q) for q in args.qps.split(",") if q]
results = collect(
codec,
args.dataset,
sorted(qps),
args.metrics,
args.num_jobs,
args.out_dir,
)
output = {
"name": codec.name,
"description": codec.description,
"results": results,
}
print(json.dumps(output, indent=2))
if __name__ == "__main__":
main(sys.argv[1:])