diff --git a/ivas_processing_scripts/__init__.py b/ivas_processing_scripts/__init__.py index 418ccfcca42e4b43cc6a618c45e5c1340cdeb554..360148700d3af844bb571bd70715ff151945e6f3 100755 --- a/ivas_processing_scripts/__init__.py +++ b/ivas_processing_scripts/__init__.py @@ -34,8 +34,6 @@ import logging import sys from itertools import product from multiprocessing import Pool -from pathlib import Path -from shutil import rmtree from time import sleep from ivas_processing_scripts.audiotools.metadata import ( @@ -100,33 +98,13 @@ def main(args): # set up processing chains chains.init_processing_chains(cfg) - # set up logging - logger = logging_init(args, cfg) - - if cfg.delete_output: - deletion_list = [d for d in [*cfg.out_dirs, *cfg.tmp_dirs] if Path(d).exists()] - if deletion_list: - logger.warning( - "\nWARNING! The configuration key to delete output directories was specified!" - ) - logger.warning( - f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n" - ) - confirm = input( - "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: " - ) - if confirm == "YES": - for dir in deletion_list: - rmtree(dir) - else: - logger.warning( - "Deletion was canceled. Please remove the output directories manually." - ) - # context manager to create output directories and clean up temporary directories with DirManager( cfg.out_dirs + cfg.tmp_dirs, cfg.tmp_dirs if cfg.delete_tmp else [] ): + # set up logging + logger = logging_init(args, cfg) + # Re-ordering items based on concatenation order if hasattr(cfg, "preprocessing_2"): if ( diff --git a/ivas_processing_scripts/audiotools/__init__.py b/ivas_processing_scripts/audiotools/__init__.py index c8f6b1703b2b259f1b4bfb7b1df194ac7888ef4a..f3b599d19f836086ae2674d02198017d8850fc67 100755 --- a/ivas_processing_scripts/audiotools/__init__.py +++ b/ivas_processing_scripts/audiotools/__init__.py @@ -39,7 +39,7 @@ from ivas_processing_scripts.audiotools.constants import ( BINAURAL_LFE_GAIN, ) from ivas_processing_scripts.audiotools.convert import convert_file -from ivas_processing_scripts.utils import apply_func_parallel +from ivas_processing_scripts.utils import apply_func_parallel, parse_gain def add_processing_args(group, input=True): @@ -51,23 +51,6 @@ def add_processing_args(group, input=True): p = "out" ps = "o" - # validation function(s) - def parse_gain(g: str) -> float: - g = g.strip() - try: - if g.lower().endswith("db"): - g = float(g[:-2].strip()) - g = 10 ** (g / 20) - else: - g = float(g) - - except ValueError: - raise argparse.ArgumentTypeError( - f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB" - ) - - return g - group.add_argument( f"-{ps}", f"--{p}", diff --git a/ivas_processing_scripts/processing/chains.py b/ivas_processing_scripts/processing/chains.py index 111c920475a53082765291fd3a388d089c1d3774..21b20a0e4dbdcf33fccc738d159290f0dd6a1605 100755 --- a/ivas_processing_scripts/processing/chains.py +++ b/ivas_processing_scripts/processing/chains.py @@ -30,7 +30,8 @@ # the United Nations Convention on Contracts on the International Sales of Goods. # -from shutil import copyfile +from pathlib import Path +from shutil import copyfile, rmtree from typing import Optional from warnings import warn @@ -46,7 +47,7 @@ from ivas_processing_scripts.processing.preprocessing_2 import Preprocessing2 from ivas_processing_scripts.processing.processing_splitting_scaling import ( Processing_splitting_scaling, ) -from ivas_processing_scripts.utils import get_abs_path, list_audio +from ivas_processing_scripts.utils import get_abs_path, list_audio, parse_gain def init_processing_chains(cfg: TestConfig) -> None: @@ -104,14 +105,17 @@ def init_processing_chains(cfg: TestConfig) -> None: f"Directory {cfg.input_path} does not exist, contains no audio files or all files were filtered out." ) - # validate input files for correct format and sampling rate - validate_input_files(cfg) - # assemble a list of output and temporary directories to create for chain in cfg.proc_chains: cfg.out_dirs.append(cfg.output_path.joinpath(chain["name"])) cfg.tmp_dirs.append(cfg.output_path.joinpath(f"tmp_{chain['name']}")) + # delete output files if requested + clean_outputs(cfg) + + # validate input files for correct format and sampling rate + validate_input_files(cfg) + def get_preprocessing(cfg: TestConfig) -> dict: """Mapping from test configuration to preprocessing keyword arguments""" @@ -140,8 +144,8 @@ def get_preprocessing(cfg: TestConfig) -> dict: "in_loudness": pre_cfg.get("loudness"), "in_loudness_fmt": pre_cfg.get("loudness_fmt", post_fmt), "in_mask": pre_cfg.get("mask", None), - "in_gain_pre": pre_cfg.get("gain_pre"), - "out_gain_post": pre_cfg.get("gain_post"), + "in_gain_pre": parse_gain(pre_cfg.get("gain_pre")), + "out_gain_post": parse_gain(pre_cfg.get("gain_post")), "multiprocessing": cfg.multiprocessing, } ) @@ -567,8 +571,8 @@ def get_processing_chain( { "in_fs": tmp_in_fs, "in_fmt": tmp_in_fmt, - "in_gain_pre": post_cfg.get("gain_pre"), - "out_gain_post": post_cfg.get("gain_post"), + "in_gain_pre": parse_gain(post_cfg.get("gain_pre")), + "out_gain_post": parse_gain(post_cfg.get("gain_post")), "out_fs": post_cfg.get("fs"), "out_fmt": post_fmt, "out_cutoff": tmp_lp_cutoff, @@ -622,7 +626,7 @@ def validate_input_files(cfg: TestConfig): if input_format.startswith("ISM") or input_format.startswith("MASA"): frame_alignment = "error" - if cfg.input["frame_alignment"] == "padding": + if frame_alignment == "padding": # Create new input directory for padded files output_dir = cfg.output_path / "20ms_aligned_files" try: @@ -703,3 +707,33 @@ def validate_input_files(cfg: TestConfig): if frame_alignment == "padding": # Make the output path as the new input path cfg.input_path = output_dir + + +def clean_outputs(cfg: TestConfig) -> None: + if cfg.delete_output: + deletion_list = [ + d + for d in [ + *cfg.out_dirs, + *cfg.tmp_dirs, + cfg.output_path.joinpath("20ms_aligned_files"), + ] + if Path(d).exists() + ] + if deletion_list: + warn( + "\nWARNING! The configuration key to delete output directories was specified!" + ) + warn( + f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n" + ) + confirm = input( + "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: " + ) + if confirm == "YES": + for dir in deletion_list: + rmtree(dir) + else: + print( + "Deletion was canceled. Please remove the output directories manually." + ) diff --git a/ivas_processing_scripts/utils.py b/ivas_processing_scripts/utils.py index b1104096a1605e01e7ca2c6dbd09ee89f03ebb79..d9af9e4641be413a13e692cc958d2e6633613f24 100755 --- a/ivas_processing_scripts/utils.py +++ b/ivas_processing_scripts/utils.py @@ -320,3 +320,23 @@ def get_abs_path(rel_path): else: abs_path = None return abs_path + + +def parse_gain(g: str) -> float: + if g is None: + return None + + g = g.strip() + try: + if g.lower().endswith("db"): + g = float(g[:-2].strip()) + g = 10 ** (g / 20) + else: + g = float(g) + + except ValueError: + raise ValueError( + f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB" + ) + + return g