
Overrides with !new that also require arguments are not working #28

@egaznep

Description


Context: I want to instantiate a derivative of a SpeechBrain pretrained model. According to the Overrides section of the HyperPyYAML documentation, this should be possible:

Overrides
In order to run experiments with various values for a hyperparameter, we have a system for overriding the values that are listed in the yaml file.

```python
from hyperpyyaml import load_hyperpyyaml

overrides = {"foo": 7}
fake_file = """
foo: 2
bar: 5
"""
load_hyperpyyaml(fake_file, overrides)
```
As shown in this example, overrides can take an ordinary python dictionary. However, this form does not support python objects. To override a python object, overrides can also take a yaml-formatted string with the HyperPyYAML syntax.

```python
load_hyperpyyaml(fake_file, "foo: !new:collections.Counter")
```
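The override semantics described above can be modeled with plain dictionaries. This is a sketch under the assumption that overrides are merged key by key, as the `recursive_update` loop quoted further down in this issue suggests; `nested_merge` is a hypothetical helper, not part of HyperPyYAML.

```python
import collections.abc


def nested_merge(d, u):
    # Hypothetical helper: nested mappings are merged key by key,
    # every other value is simply replaced.
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and isinstance(
            d.get(k), collections.abc.Mapping
        ):
            nested_merge(d[k], v)
        else:
            d[k] = v
    return d


# Flat override from the quoted example: foo 2 -> 7, bar untouched.
flat = nested_merge({"foo": 2, "bar": 5}, {"foo": 7})
print(flat)  # {'foo': 7, 'bar': 5}

# Nested override: only lin_neurons changes, input_size is kept.
nested = nested_merge(
    {"model": {"input_size": 80, "lin_neurons": 192}},
    {"model": {"lin_neurons": 256}},
)
print(nested)  # {'model': {'input_size': 80, 'lin_neurons': 256}}

# By contrast, a shallow dict.update replaces the whole nested mapping.
shallow = {"model": {"input_size": 80, "lin_neurons": 192}}
shallow.update({"model": {"lin_neurons": 256}})
print(shallow)  # {'model': {'lin_neurons': 256}}
```

Merging key by key is what lets an override change a single nested argument, but it also means a mapping override is never substituted for the original entry as a whole.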


Minimal example:

```python
import torch
from speechbrain.pretrained import EncoderClassifier

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier: EncoderClassifier = EncoderClassifier.from_hparams( # type: ignore
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device":device},
    overrides=
'''
embedding_model: !new:ecapa.ECAPA_TDNN
    input_size: !ref <n_mels>
    channels: [1024, 1024, 1024, 1024, 3072]
    kernel_sizes: [5, 3, 3, 3, 1]
    dilations: [1, 2, 3, 4, 1]
    attention_channels: 128
    lin_neurons: 192
'''
)
```

This code block should instantiate the ECAPA_TDNN class defined in a local ecapa.py, or raise an error if no such module exists. However, whether or not a local ECAPA_TDNN definition exists, the override is silently ignored and the result is identical to that of the following call:

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
classifier: EncoderClassifier = EncoderClassifier.from_hparams( # type: ignore
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device":device},
)
```

Tracing the issue, I think the problem is in recursive_update: when the override for hparams['embedding_model'] is a mapping, it recursively updates the existing entries (the constructor arguments) but does not copy the override's tag (here !new:ecapa.ECAPA_TDNN), if one is present.

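```python
# Excerpt from recursive_update: merges the override mapping `u` into `d`.
for k, v in u.items():
    if isinstance(v, collections.abc.Mapping) and k in d:
        # Existing key with a mapping override: recurse and merge the
        # arguments only; any tag on `v` (e.g. !new:...) is not copied.
        recursive_update(d.get(k, {}), v)
    elif must_match and k not in d:
        raise KeyError(f"Override '{k}' not found in: {[key for key in d.keys()]}")
    else:
        # New key or non-mapping value: plain assignment.
        d[k] = v
```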
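A self-contained reproduction of this diagnosis, plus one possible fix, can be sketched in plain Python. Assumptions: `TaggedMap` is a hypothetical stand-in for a loaded YAML mapping that keeps its tag (how ruamel.yaml actually stores tags on nodes differs, and may differ between versions, so the tag check below only applies to this stand-in), the pretrained tag and argument values are illustrative placeholders, and `recursive_update_tag_aware` is a hypothetical variant, not HyperPyYAML code.

```python
import collections.abc


class TaggedMap(dict):
    """Hypothetical stand-in for a YAML mapping that carries a tag."""

    def __init__(self, tag, **kwargs):
        super().__init__(**kwargs)
        self.tag = tag


def recursive_update(d, u, must_match=False):
    # Same loop as the excerpt above: a mapping override is merged into
    # the existing entry, so the existing entry (and its tag) survives.
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and k in d:
            recursive_update(d.get(k, {}), v)
        elif must_match and k not in d:
            raise KeyError(f"Override '{k}' not found in: {list(d.keys())}")
        else:
            d[k] = v


def recursive_update_tag_aware(d, u, must_match=False):
    # Hypothetical variant: a tagged override replaces the existing entry
    # wholesale instead of being merged into it.
    for k, v in u.items():
        if getattr(v, "tag", None) is not None and k in d:
            d[k] = v
        elif isinstance(v, collections.abc.Mapping) and k in d:
            recursive_update_tag_aware(d.get(k, {}), v, must_match)
        elif must_match and k not in d:
            raise KeyError(f"Override '{k}' not found in: {list(d.keys())}")
        else:
            d[k] = v


def load_pretrained():
    # Illustrative pretrained hparams; tag and values are placeholders.
    return {
        "embedding_model": TaggedMap(
            "!new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN",
            input_size=80,
            lin_neurons=192,
        )
    }


override = {
    "embedding_model": TaggedMap(
        "!new:ecapa.ECAPA_TDNN", input_size=80, attention_channels=128
    )
}

merged = load_pretrained()
recursive_update(merged, override)
print(merged["embedding_model"].tag)
# -> !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN (override tag dropped)

replaced = load_pretrained()
recursive_update_tag_aware(replaced, override)
print(replaced["embedding_model"].tag)
# -> !new:ecapa.ECAPA_TDNN
```

Replacing wholesale means a tagged override must list every argument, as the override in this issue does. Copying the tag and then merging would also work, but could leave behind arguments that the new class does not accept.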
