[HELP] BART summarization output exactly the same as labels

Hi, I’m trying to fine-tune BART on a summarization task using TensorFlow on TPU. I first tokenized the data and stored it in *.tfrecord files using the :hugs: datasets export() function, then created TF Datasets from them. The preprocessing and fine-tuning code is given below.
Problem: I am getting an exact copy of the labels as output. It’s as if BART is not learning anything.
For example, here is a batch of predictions followed by the corresponding labels:

['The ICSI Meeting Recorder Dialog Act (MRDA) Corpus\nWe describe a new corpus of over 180,000 handannotated dialog act tags and accompanying adjacency pair annotations for roughly 72 hours of speech from 75 naturally-occurring meetings.\nWe provide a brief summary of the annotation system and labeling procedure, inter-annotator reliability statistics, overall distributional statistics, a description of auxiliary files distributed with the corpus, and information on how to obtain the data.', 'Templates-Based Information Extraction without the Templates\nStandard algorithms for template-based information extraction (IE) require predefined template schemas, and often labeled data, to learn to extract slot fillers (e.g., a template).\nThis paper describes an approach to template- based IE that removes this requirement and performs extraction without knowing the template structure in advance.\nOur algorithm instead learns the template structures automatically from raw text, inducing template Schema schemas as sets of linked events associated with semantic roles.\nWe also solve', 'Get Out The Vote: Determining Support Or Opposition From Congressional Floor-Debate Transcripts\nWe investigate whether one can determine from the transcripts of U.S. Congressional floor debates whether the speeches represent support of or opposition to proposed legislation.\nTo address this problem, we exploit the fact that these speeches occur as part of a discussion; this allows us to use sources of information regarding relationships between discourse segments, such as whether a given utterance indicates agreement with the opinion expressed by another.\nWe find that the incorporation of such information yields substantial improvements over classifying speeches']
['The ICSI Meeting Recorder Dialog Act (MRDA) Corpus\nWe describe a new corpus of over 180,000 hand-annotated dialog act tags and accompanying adjacency pair annotations for roughly 72 hours of speech from 75 naturally-occurring meetings.\nWe provide a brief summary of the annotation system and labeling procedure, inter-annotator reliability statistics, overall distributional statistics, a description of auxiliary files distributed with the corpus, and information on how to obtain the data.', 'Template-Based Information Extraction without the Templates\nStandard algorithms for template-based information extraction (IE) require predefined template schemas, and often labeled data, to learn to extract their slot fillers (e.g., an embassy is the Target of a Bombing template).\nThis paper describes an approach to template-based IE that removes this requirement and performs extraction without knowing the template structure in advance.\nOur algorithm instead learns the template structure automatically from raw text, inducing template schemas as sets of linked events (e.g., bombings include detonate, set off, and destroy events) associated with semantic', 'Get Out The Vote: Determining Support Or Opposition From Congressional Floor-Debate Transcripts\nWe investigate whether one can determine from the transcripts of U.S. Congressional floor debates whether the speeches represent support of or opposition to proposed legislation.\nTo address this problem, we exploit the fact that these speeches occur as part of a discussion; this allows us to use sources of information regarding relationships between discourse segments, such as whether a given utterance indicates agreement with the opinion expressed by another.\nWe find that the incorporation of such information yields substantial improvements over classifying speeches in isolation.\nWe present a method based on support']
Validation results:
 {'rouge1': 78.3223, 'rouge2': 72.4416, 'rougeL': 76.0222, 'rougeLsum': 77.904}

I am using facebook/bart-large-cnn as my checkpoint. At first, I was getting gibberish output, so I passed BartConfig.from_pretrained("facebook/bart-large-cnn") to the model, and it started copying the labels. I searched the forum and found this thread, where @valhalla suggested using prepare_seq2seq_batch() at the time. Since that method is deprecated and will be removed in transformers 5.x, and the suggested replacement is tokenizer.as_target_tokenizer(), I did that instead.
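For reference, the old and new tokenization patterns look like this (a sketch; articles and summaries are placeholder lists of strings):

# old, soon-to-be-deprecated pattern:
# batch = tokenizer.prepare_seq2seq_batch(src_texts=articles, tgt_texts=summaries)

# replacement pattern, used in preprocess_function below:
model_inputs = tokenizer(articles, max_length=1024, padding="max_length", truncation=True)
with tokenizer.as_target_tokenizer():
    targets = tokenizer(summaries, max_length=128, padding="max_length", truncation=True)
model_inputs["labels"] = targets["input_ids"]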

The code:

import nltk
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
from datasets import load_metric
from transformers import AutoConfig, AutoTokenizer, TFAutoModelForSeq2SeqLM

class Config:
    num_epochs=3
    train_batch_size=2
    val_batch_size=4
    test_batch_size=4
    learning_rate=2e-5
    num_warmup_steps=0
    num_beams=4
    max_input_length=1024
    max_target_length=128
    val_max_target_length=None
    ignore_pad_token_for_loss=True
    padding="max_length"
    train_data_len=None
    valid_data_len=None
    test_data_len=None
    num_val_take=6
    num_test_take=6
    # REPLICAS: replica count from the distribution strategy (defined outside this snippet)
    num_val_examples=num_val_take * val_batch_size * REPLICAS
    num_test_examples=num_test_take * test_batch_size * REPLICAS

def read_tfrecord(example, max_input_length, max_target_length):
    feature_description = {
        'input_ids': tf.io.FixedLenFeature([max_input_length], tf.int64, default_value=[0]*max_input_length),
        'attention_mask': tf.io.FixedLenFeature([max_input_length], tf.int64, default_value=[0]*max_input_length),
        'decoder_input_ids': tf.io.FixedLenFeature([max_target_length], tf.int64, default_value=[0]*max_target_length),
        'decoder_attention_mask': tf.io.FixedLenFeature([max_target_length], tf.int64, default_value=[0]*max_target_length),
        'labels': tf.io.FixedLenFeature([max_target_length], tf.int64, default_value=[0]*max_target_length),
    }
    
    # parse one serialized example and return a (features, labels) tuple for Keras
    example = tf.io.parse_single_example(example, feature_description)
    return example, example["labels"]
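The TFRecord files are then turned into tf.data datasets roughly as follows (a sketch: the file patterns are placeholders, and the global batch size is assumed to be the per-replica batch size times REPLICAS):

import functools

def get_dataset(file_pattern, batch_size, shuffle=False):
    # read the TFRecord shards, parse each example with read_tfrecord,
    # and batch with drop_remainder=True as required on TPU
    filenames = tf.data.Dataset.list_files(file_pattern)
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
    parse_fn = functools.partial(read_tfrecord,
                                 max_input_length=Config.max_input_length,
                                 max_target_length=Config.max_target_length)
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(2048)
    return ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

train_dataset = get_dataset("train-*.tfrecord", Config.train_batch_size * REPLICAS, shuffle=True)
valid_dataset = get_dataset("valid-*.tfrecord", Config.val_batch_size * REPLICAS)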

def preprocess_function(examples, article_column="article", summary_column="target",
                        max_input_length=1024, max_output_length=128, prefix="summarize: "):

    # NOTE: the "summarize: " prefix is a T5 convention; BART checkpoints do not require it
    inputs = [prefix + article for article in examples[article_column]]

    tokenized_inputs = tokenizer(inputs, max_length=max_input_length, padding="max_length", truncation=True)

    # tokenize the summaries in the target-tokenizer context
    with tokenizer.as_target_tokenizer():
        tokenized_outputs = tokenizer(examples[summary_column], max_length=max_output_length, padding="max_length", truncation=True)

    # decoder_input_ids and labels are both set to the (unshifted) tokenized summaries
    return {"input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "decoder_input_ids": tokenized_outputs["input_ids"],
            "decoder_attention_mask": tokenized_outputs["attention_mask"],
            "labels": tokenized_outputs["input_ids"]}

model_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)  # tokenizer_name is also "facebook/bart-large-cnn"

with strategy.scope():
    model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, config=model_config)
    model.compile(optimizer=optimizer,  # optimizer is created elsewhere under the strategy scope
                  loss={"logits": masked_sparse_categorical_crossentropy})

metric = load_metric("rouge")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def eval_fn(model, tokenizer, tokenized_tf_dataset, tf_dataset_take_size,
            pre_train_val_check=False):

    if Config.val_max_target_length is None:
        Config.val_max_target_length = Config.max_target_length

    gen_kwargs = {
        "max_length": Config.val_max_target_length,
        "num_beams": Config.num_beams,
    }

    if pre_train_val_check:  # check the validation loop before starting fine-tuning
        tokenized_tf_dataset = tokenized_tf_dataset.take(2)
        total = 2
    else:
        total = tf_dataset_take_size

    for batch, labels in tqdm(tokenized_tf_dataset, total=total, unit="batches"):
        temp_batch = {
            "input_ids": batch["input_ids"],
            # "attention_mask": batch["attention_mask"],
        }
        temp_batch.update(gen_kwargs)
        generated_tokens = model.generate(**temp_batch)
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # replace -100 with the pad token id so the labels can be decoded
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        metric.add_batch(predictions=decoded_preds, references=decoded_labels)

    result = metric.compute(use_stemmer=True)
    # extract the mid f-measure from each ROUGE score
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    result = {k: round(v, 4) for k, v in result.items()}

    if pre_train_val_check:
        if any(v == 0 for v in result.values()):
            print("Result: ", result)
            print("Result integrity failed")
            return False
        print(result)
        print("Validation epoch working correctly....")
        return True
    return result
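eval_fn is wired up roughly like this (a sketch; train_dataset and valid_dataset come from the pipeline above, and the fit() call stands in for my training setup):

# sanity-check the validation loop on two batches before fine-tuning
assert eval_fn(model, tokenizer, valid_dataset, Config.num_val_take,
               pre_train_val_check=True)

model.fit(train_dataset, epochs=Config.num_epochs)

results = eval_fn(model, tokenizer, valid_dataset.take(Config.num_val_take),
                  Config.num_val_take)
print("Validation results:", results)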

Thanks

Try increasing the number of epochs; maybe that will help. I don’t have any experience with this model, but whenever a model doesn’t predict well, increasing the epochs almost always helps.

Thanks for the advice.
I increased the epochs, but it’s still not learning anything.

@sshleifer I would be grateful if you could help with this. I found a similar post on SO. Is there a problem with the facebook/bart-* checkpoints? Can someone confirm this?