diff --git a/generate_test.py b/generate_test.py index 141b11063c611a99668dd3346a5ca4fb693284db..6233b4211cb46bdf3e19974e883a9468765df2a8 100755 --- a/generate_test.py +++ b/generate_test.py @@ -42,8 +42,8 @@ EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)] EXPERIMENTS_BS1534 = [f"BS1534-{i}{x}" for i in range(1, 8) for x in ["a", "b"]] LAB_IDS = ["a", "b", "c", "d"] IN_FMT_FOR_MASA_EXPS = { - "P800-8": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)), - "P800-9": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)), + "P800-8": {"cat1": "FOA", "cat2": "FOA", "cat3": "FOA", "cat4": "FOA", "cat5": "FOA", "cat6": "FOA"}, + "P800-9": {"cat1": "FOA", "cat2": "FOA", "cat3": "FOA", "cat4": "FOA", "cat5": "FOA", "cat6": "FOA"}, "BS1534-7a": {"cat1": "FOA", "cat2": "HOA2"}, "BS1534-7b": {"cat1": "FOA", "cat2": "HOA2"}, } @@ -102,7 +102,7 @@ def create_experiment_setup(experiment, lab) -> list[Path]: input_path = base_path.joinpath("proc_input").joinpath(cat) output_path = base_path.joinpath("proc_output").joinpath(suffix) bg_noise_path = base_path.joinpath("background_noise").joinpath( - f"background_noise_{suffix}.wav" + f"background_noise_{cat}.wav" ) cfg_path = default_cfg_path.parent.joinpath(f"{experiment}{cat}-lab_{lab}.yml") cfgs.append(cfg_path) @@ -113,15 +113,18 @@ def create_experiment_setup(experiment, lab) -> list[Path]: cfg.prerun_seed = seed cfg.input_path = str(input_path) cfg.output_path = str(output_path) + + cat_num = int(cat[-1]) if ( bg_noise_pre_proc_2 := cfg.preprocessing_2.get("background_noise", None) ) is not None: bg_noise_pre_proc_2["background_noise_path"] = str(bg_noise_path) # bg noise SNR only differs from default config for some experiments - cat_num = int(cat[-1]) if experiment in ["P800-5", "P800-9"] and cat_num >= 3: bg_noise_pre_proc_2["snr"] = 15 + if cfg.preprocessing_2.get("concatenate_input", None) is not None: + cfg.preprocessing_2["concatenation_order"] = concatenation_order(lab, experiment, cat_num) # for MASA, the input format can differ between categories if (fmt_for_category := IN_FMT_FOR_MASA_EXPS.get(experiment, None)) is not None: @@ -158,6 +161,11 @@ def exp_lab_pair(arg): return exp, lab +def concatenation_order(lab_id, experiment, category): + exp_id = f"p0{experiment[-1]}" + return [f"{lab_id}{exp_id}a{category}s0{i}.wav" for i in range(1, 8)] + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate config files and process files for selecton experiments. Experiment names and lab ids must be given as comma-separated pairs (e.g. 'P800-5,b BS1534-4a,d ...')" diff --git a/ivas_processing_scripts/processing/processing.py b/ivas_processing_scripts/processing/processing.py index bcc8330fdeb0d2b120748ad612f65290a8eed404..8da8593ee3030e2eb00f6599c515cee672de10e2 100755 --- a/ivas_processing_scripts/processing/processing.py +++ b/ivas_processing_scripts/processing/processing.py @@ -83,6 +83,10 @@ def reorder_items_list(items_list: list, concatenation_order: list) -> list: Re-ordered list of input items """ name_to_full = {Path(full_file).name: full_file for full_file in items_list} + + if set(name_to_full.keys()) != set(concatenation_order): + raise ValueError(f"Items given in concatenation_order {concatenation_order} are not identical to what was found in the input folder {name_to_full.keys()}") + ordered_full_files = [ name_to_full[name] for name in concatenation_order if name in name_to_full ] diff --git a/ivas_processing_scripts/utils.py b/ivas_processing_scripts/utils.py index 1ad9319f7aa3b777cdd119f7ef50ac41cc82f6bd..6d29ba96f1a847b5c1934d5cfca87a07e209e014 100755 --- a/ivas_processing_scripts/utils.py +++ b/ivas_processing_scripts/utils.py @@ -127,6 +127,9 @@ def list_audio(path: str, select_list: list = None) -> list: f for f in audio_list if any([pattern in f.stem for pattern in select_set]) ] + # sort file list alphanumerically by filenames + audio_list = sorted(audio_list, key=lambda p: p.name) + return audio_list diff --git a/tests/test_experiments.py b/tests/test_experiments.py index b8e62114fde2b4867a56dfb7edcc11d6183a9c05..51d50e635c56683043b94f0d6c119eb28b5ca059 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -77,6 +77,7 @@ def setup_input_files_for_config(config): dummy_md_files = FORMAT_TO_METADATA_FILES.get(input_fmt, list()) # copy input files + files_copied = list() for f in dummy_input_files: f_out = input_path.joinpath(f.name).resolve().absolute() # need at least 2s of input files for gen-patt to be happy (can not keep the tolerance for 50 frames only) @@ -86,6 +87,8 @@ def setup_input_files_for_config(config): md_f_out = ".".join([str(f_out), suffix]) shutil.copy(md_f, md_f_out) + files_copied.append(f_out.name) + # create background noise files with white noise if "background_noise" in config.preprocessing_2: # always set the same seed to have reproducible test noises @@ -98,6 +101,8 @@ def setup_input_files_for_config(config): ).absolute() write(bg_noise_path, noise) + return files_copied + def all_lengths_equal(cfg): output_folder = cfg.output_path @@ -132,7 +137,12 @@ def test_generate_test_items(exp_lab_pair): args = Arguments(str(cfg)) config = TestConfig(cfg) - setup_input_files_for_config(config) + input_filenames = setup_input_files_for_config(config) + # patch concatenation order + if config.preprocessing_2.get("concatenate_input", None) is not None: + config.preprocessing_2["concatenation_order"] = sorted(input_filenames) + config.to_file(cfg) + generate_test(args) if not all_lengths_equal(config):