Trying to fine tune a wav2vec2 model for classification task on 3 labels, I am running out of memory in colab.
All my initializations are lazy and i cant understand why its taking so much space of RAM. My model is hardly 500 MBs and model training is being done on a batch size of 16.
hindi_ds = load_dataset("SPRINGLab/IndicVoices-R_Hindi",split="train",streaming=True)
tamil_ds=load_dataset("SPRINGLab/IndicVoices-R_Tamil",split="train",streaming=True)
bengali_ds=load_dataset("SPRINGLab/IndicVoices-R_Bengali",split="train",streaming=True)
small_hindi = hindi_ds.take(500)
small_tamil = tamil_ds.take(500)
small_bengali = bengali_ds.take(500)
small_bengali=small_bengali.map(add_lang) # function simply changed lang column dtype to int
small_hindi=small_hindi.map(add_lang) # function simply changed lang column dtype to int
small_tamil=small_tamil.map(add_lang) # function simply changed lang column dtype to int
ds=interleave_datasets([small_tamil,small_hindi,small_bengali])
ds=ds.cast_column("audio",Audio(sampling_rate=16000))
*All Lazy Initialisation Till Here. Shouldnt take much memory*
model_id="facebook/wav2vec2-base-960h"
extractor=Wav2Vec2FeatureExtractor.from_pretrained(model_id)
def feature_extractor(batch):
  audio_arrays=[example['array'] for example in batch["audio"]]
  inputs=extractor(audio_arrays,
                     sampling_rate=extractor.sampling_rate,
                     return_attention_mask=True,
                   return_tensors="pt",
                   padding=True,
                  max_length=160000,
                 truncation=True
                     )
  return inputs
ds_encoded=ds.map(feature_extractor,
                  remove_columns=cols_to_remove,
                  batched=True,
                  batch_size=50,
                  )
*For a batch size of 50, memory requirement would be approx (50X160000X2)bytes=16 MB
model=AutoModelForAudioClassification.from_pretrained(model_id,
                                                      num_labels=3
                                                      )
ds_encoded = ds_encoded.rename_column("lang", "label")
model_name="wave2vec2-base-960h"
training_args=TrainingArguments(
    f"{model_name}-finetuned-springlab-tamil-hindi-bengali",
    save_strategy='epoch',
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=10,
    logging_strategy="epoch",
    fp16="True",
    max_steps=500
)
trainer = Trainer(
    model,
    training_args,
    train_dataset=ds_encoded,
    processing_class=feature_extractor,
)
trainer.train()
Can anyone please comment..