Calling Silero VAD model from dataset.map

Hello,

I'm attempting to use a PyTorch model, specifically the pre-trained, enterprise-grade Silero VAD (Voice Activity Detector), on my dataset via the map function with multiple workers for parallel processing. Here's the relevant code snippet:

import torch
from datasets import load_dataset

torch.set_num_threads(1)

def process_row(self, row, index, model, get_speech_timestamps, collect_chunks):
    # do some stuff .....
    # read the audio file saved in row["ref_wave"] into wav,
    # then apply VAD
    speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
    # do some stuff .....
    return {'speech_timestamps': speech_timestamps}

USE_ONNX = False  # set to True to use the ONNX version of the model

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=True,
                              onnx=USE_ONNX)

(get_speech_timestamps,
 save_audio,
 read_audio,
 VADIterator,
 collect_chunks) = utils

# load data
dataset = load_dataset("csv", data_files='path_to_csv', split="train")

dataset = dataset.map(
            self.process_row,
            num_proc=8,
            with_indices=True,
            batched=False,
            remove_columns=dataset.column_names,
            fn_kwargs={
                "model_vad": model,
                "get_speech_timestamps": get_speech_timestamps,
                "collect_chunks": collect_chunks,
            },
            desc="Processing Data.....",
        )

However, I'm encountering an issue that generates the following error message:

...transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. 
.....
RuntimeError: Tried to serialize object __torch__.vad.model.vad_annotator.VADRNNJITMerge which does not have a __getstate__ method defined!

I'm not sure how to properly serialize the PyTorch object for this purpose. Any assistance would be greatly appreciated. Thank you for your time. Best regards.

You should be able to avoid this issue by registering a serializer for the model (before running the map) as follows:

import copyreg
import os

import torch

def pickle_model(model):
    # Save the scripted model to disk once, then have pickle rebuild it
    # in each worker process via torch.jit.load.
    if not os.path.exists("model_scripted.pt"):
        model.save("model_scripted.pt")
    return torch.jit.load, ("model_scripted.pt",)

copyreg.pickle(type(model), pickle_model)
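
To make the ordering explicit, here is a minimal sketch of how the registration is meant to fit around the map call. It reuses the names from your original snippet (model, process_row, dataset, and the VAD helpers) and assumes they are all in scope; nothing new is introduced beyond the comments:

# Register the reducer before calling map, so each worker process receives
# the model by re-loading model_scripted.pt instead of pickling the live
# TorchScript object.
copyreg.pickle(type(model), pickle_model)

dataset = dataset.map(
    process_row,
    num_proc=8,
    with_indices=True,
    batched=False,
    remove_columns=dataset.column_names,
    fn_kwargs={
        "model": model,
        "get_speech_timestamps": get_speech_timestamps,
        "collect_chunks": collect_chunks,
    },
    desc="Processing Data.....",
)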

Thank you very much for your response.
I used the pickle_model function but encountered an error:

...\lib\site-packages\dill\_dill.py", line 432, in find_class
    return StockUnpickler.find_class(self, module, name)
ModuleNotFoundError: No module named 'utils'

This is odd, as I had already defined 'utils' before calling copyreg.pickle(type(model), pickle_model).
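
In case it helps, one possible way to sidestep the pickling problem entirely is to load the model and its helpers inside each worker process instead of shipping them through fn_kwargs. The sketch below is untested; the names _VAD_CACHE, _get_vad, and process_row_local are illustrative rather than part of the original code, and it assumes the standard silero-vad hub entry point and helper signatures:

import torch
from datasets import load_dataset

torch.set_num_threads(1)

_VAD_CACHE = None  # per-process cache, filled on first use in each worker

def _get_vad():
    # Load the model and utils lazily inside the worker process, so nothing
    # from torch.hub has to be pickled by datasets/dill.
    global _VAD_CACHE
    if _VAD_CACHE is None:
        model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False)
        (get_speech_timestamps, save_audio, read_audio,
         VADIterator, collect_chunks) = utils
        _VAD_CACHE = (model, get_speech_timestamps, read_audio)
    return _VAD_CACHE

def process_row_local(row, index):
    model, get_speech_timestamps, read_audio = _get_vad()
    wav = read_audio(row["ref_wave"], sampling_rate=16000)
    speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
    return {'speech_timestamps': speech_timestamps}

dataset = load_dataset("csv", data_files='path_to_csv', split="train")
dataset = dataset.map(
    process_row_local,
    num_proc=8,
    with_indices=True,
    remove_columns=dataset.column_names,
    desc="Processing Data.....",
)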