Whisper is not learning a new tokenizer, even when I make the test and train datasets the same

System Info

  • transformers version: 4.35.2
  • Platform: Linux-5.15.120+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.19.3
  • Safetensors version: 0.4.0
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu118 (False)
  • Tensorflow version (GPU?): 2.14.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
  • Jax version: 0.4.20
  • JaxLib version: 0.4.20
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Hello, I want to transcribe the audio at my workplace; however, base Whisper does not seem to handle it well. So I have been trying to create my own tokenizer that understands our jargon (things like acronyms) and outputs it correctly. Below I have shown my steps for:

  1. Creating Tokenizer
  2. Preprocessing data pipeline
  3. Model init, and configuration
  4. Model outputs

I run this with the Hugging Face Trainer, with predict_with_generate enabled. Is it my data size? I have scoured the internet trying to find a solution, but everything I find just says it works. I am at my wits' end and would appreciate any help getting this tokenizer to learn my jargon.

Thank you in advance :slight_smile:

Creating the tokenizer

import json

from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers
from transformers import WhisperTokenizer

# Initialize a BPE tokenizer
tokenizer = Tokenizer(models.BPE())

# Pre-tokenizer responsible for splitting the text into word-level pieces
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()  # ByteLevel(add_prefix_space=False)

# Decoder responsible for converting the tokens back to a string
tokenizer.decoder = decoders.ByteLevel()

# Trainer responsible for training the BPE model
# (spec_tok is my list of special-token strings, defined elsewhere)
trainer = trainers.BpeTrainer(vocab_size=1000, min_frequency=2, special_tokens=spec_tok)

# Training the tokenizer
tokenizer.train(["file.txt"], trainer)

# Save the tokenizer
tokenizer.save("NewWhisperTokenizer.json")

# Extract the vocab and merges from the saved tokenizer JSON
with open("NewWhisperTokenizer.json") as f:
    data = json.load(f)

with open("vocab.json", "w") as outfile:
    json.dump(data["model"]["vocab"], outfile)

# merges.txt must be plain text with one merge rule ("token_a token_b") per line, not JSON
with open("merges.txt", "w") as outfile:
    for merge in data["model"]["merges"]:
        if isinstance(merge, (list, tuple)):
            merge = " ".join(merge)
        outfile.write(merge + "\n")


tokenizer = WhisperTokenizer("vocab.json", "merges.txt", errors="replace", unk_token="<|endoftext|>", bos_token="<|endoftext|>", eos_token="<|endoftext|>", pad_token="<|endoftext|>")
tokenizer.add_special_tokens(WhisperTokenizer.from_pretrained("openai/whisper-tiny").special_tokens_map_extended)
tokenizer.save_pretrained("new_tok")

len(tokenizer) == 193
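
As a quick sanity check, I round-trip a sample string through the new tokenizer (a minimal sketch; "XYZ" is just a placeholder for one of my jargon terms from file.txt):

# Round-trip a sample string through the new tokenizer
sample = "XYZ test utterance"
ids = tokenizer(sample).input_ids
print(ids)
print(tokenizer.decode(ids, skip_special_tokens=True))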

Preprocessing steps

def prepare_dataset(batch):
    audio = batch["audio"]
    # Log-mel input features from the raw waveform
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # Tokenize each phonetic utterance and flatten into a single label sequence
    temp_labels = tokenizer(batch["phonetic_detail"]["utterance"]).input_ids
    batch["label"] = [label for sentence_labels in temp_labels for label in sentence_labels]
    return batch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    tokenizer: Any
    feature_extractor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel input features to a uniform length
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label sequences and mask padding positions with -100 so they are ignored by the loss
        label_features = [{"input_ids": feature["label"]} for feature in features]
        labels_batch = self.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop a leading BOS token if present; the model re-adds the decoder start token itself
        if (labels[:, 0] == self.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(tokenizer, feature_extractor)

len(train_dataset) == 4000
len(test_dataset) == 1000
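
For debugging, I also inspect one collated batch (a minimal sketch; it assumes the datasets have already been mapped with prepare_dataset):

# Inspect a single collated batch: feature shapes and the decoded label sequence
sample_batch = data_collator([train_dataset[0], train_dataset[1]])
print(sample_batch["input_features"].shape)  # (2, 80, 3000) log-mel features for whisper-tiny
print(sample_batch["labels"].shape)
print(tokenizer.decode([t for t in sample_batch["labels"][0].tolist() if t != -100]))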

Model Config

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

voc = tokenizer.get_vocab()

model_Gen = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model_Gen = model_Gen.to(device)

# Resize the decoder embeddings (and tied output projection) to the new, much smaller vocabulary
model_Gen.resize_token_embeddings(len(tokenizer))

model_Gen.config.pad_token_id = tokenizer.pad_token_id
model_Gen.config.decoder_start_token_id = voc['<|startoftranscript|>']
model_Gen.config.eos_token_id = tokenizer.eos_token_id
model_Gen.config.bos_token_id = tokenizer.bos_token_id
model_Gen.config.suppress_tokens = []
model_Gen.config.forced_decoder_ids = None
model_Gen.config.begin_suppress_tokens = [tokenizer.pad_token_id]

model_Gen.generation_config.pad_token_id = tokenizer.pad_token_id
model_Gen.generation_config.decoder_start_token_id = voc['<|startoftranscript|>']
model_Gen.generation_config.eos_token_id = tokenizer.eos_token_id
model_Gen.generation_config.bos_token_id = tokenizer.bos_token_id
model_Gen.generation_config.suppress_tokens = []
model_Gen.generation_config.forced_decoder_ids = None
model_Gen.generation_config.begin_suppress_tokens = [tokenizer.pad_token_id]

model_Gen.generation_config.no_timestamps_token_id = voc['<|notimestamps|>']
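
To make sure the ids above are consistent with the new vocabulary, I print the tokens they map to (a minimal check):

# Confirm the special-token ids used by the config resolve to the expected tokens
print(model_Gen.config.decoder_start_token_id,
      tokenizer.convert_ids_to_tokens(model_Gen.config.decoder_start_token_id))
print(model_Gen.generation_config.no_timestamps_token_id,
      tokenizer.convert_ids_to_tokens(model_Gen.generation_config.no_timestamps_token_id))
print(tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id)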

Huggingface Trainer

Here I have set both the train and eval datasets to the same 30 examples, to check whether the model would at least completely overfit and memorize them; but even with train and test identical, it does not overfit at all.
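
This is roughly how the 30-example subset is built (a sketch; new_test is my already-preprocessed DatasetDict and the indices are arbitrary):

# Use the same 30 prepared examples for both training and evaluation
new_test["test"] = new_test["test"].select(range(30))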

training_args = Seq2SeqTrainingArguments(
  output_dir='training_output',
  logging_dir='./logs',
  group_by_length=True,
  per_device_train_batch_size=1,
  gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
  per_device_eval_batch_size=1,
  num_train_epochs=8,
  gradient_checkpointing=True,
  lr_scheduler_type = "cosine_with_restarts",
  save_strategy='epoch',
  evaluation_strategy='epoch',
  logging_strategy='epoch',
  learning_rate=1e-2,
  weight_decay=0.005,
  # warmup_steps=36,
  save_total_limit=4,
  push_to_hub=False,
  predict_with_generate=True,
  generation_max_length=225,
  load_best_model_at_end=True,
  greater_is_better=False,
  generation_num_beams = 4,
  # fp16 = True,

  report_to="wandb", # Turn this off for pdb debug

)
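
compute_metrics is not shown above; for reference, a typical WER-based implementation looks roughly like this (a sketch using the evaluate library):

import evaluate

wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # Restore the pad token where the collator inserted -100 so decoding works
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}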

# CustomTrainer and processor are defined elsewhere in my script
trainer = CustomTrainer(
    compute_metrics=compute_metrics,
    args=training_args,
    model=model_Gen,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    train_dataset=new_test['test'],
    eval_dataset=new_test['test'],
)

trainer.evaluate()

Outputs after second epoch

tokenizer.batch_decode(pred.predictions , skip_special_tokens = True)
['', '', 'uwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuw', 'k', '', 'k', 'kkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk', 
'awawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawaw', 'awawawaw', '', '', '', 'jjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjj', '', 'jjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjj', 'uweuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuw', '', 
'axaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxaxax', '', 
'kuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhkuhk', 
'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 
'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
 'awawawaw', 
'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 
'awawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawawaw',
 '', 
'jjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjj']

Expected behavior

More understandable transcriptions; at the very least, the model should overfit and reproduce the training transcripts when the train and eval sets are identical.