Overwrite attention heads in BartForConditionalGeneration


I am looking to overwrite the attention heads in the BART model, following the process below:

  1. Run the model on an article with the keyword parameter: “Covid”
  2. Save the encoder/decoder heads for this article
  3. Run the model on another article, also with the keyword parameter: “Covid”
  4. As a proxy for making the model ‘topic-aware’, insert the “Covid” attention heads saved in step 2 in place of the attention heads produced by the run in step 3
  5. The model will then generate a new ‘topic-aware’ summary for the article, since the inserted attention heads are ‘trained’ on the topic keyword ‘Covid’

Note: The above is extremely preliminary; we will be looking to train the attention heads and model on more data for each keyword in the future.

from transformers import AutoTokenizer, BartConfig, BartForConditionalGeneration

article = "Covid-19 is a global pandemic"
model_name = "facebook/bart-large-cnn"
config = BartConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(article, padding=True, truncation=True, return_tensors="pt")
model = BartForConditionalGeneration.from_pretrained(model_name, config=config)

# Forward pass to expose the attention weights
outputs = model(**inputs)

# Generation produces the summary itself
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

covid_encoder_attention = outputs.encoder_attentions
covid_decoder_attention = outputs.decoder_attentions
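For reference, what gets saved here is a tuple with one tensor per layer, each of shape (batch, num_heads, tgt_len, src_len). A tiny randomly initialized BART (hypothetical sizes I made up for illustration; no weight download needed) shows this, if I understand the outputs correctly:

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny random BART so this runs without downloading weights (hypothetical sizes)
cfg = BartConfig(d_model=16, encoder_layers=2, decoder_layers=2,
                 encoder_attention_heads=4, decoder_attention_heads=4,
                 encoder_ffn_dim=32, decoder_ffn_dim=32, vocab_size=50,
                 max_position_embeddings=32)
model = BartForConditionalGeneration(cfg).eval()

input_ids = torch.tensor([[0, 5, 6, 7, 2]])   # batch of 1, 5 tokens
with torch.no_grad():
    out = model(input_ids=input_ids, output_attentions=True)

# One tensor per encoder layer: (batch, num_heads, seq_len, seq_len)
print(len(out.encoder_attentions), out.encoder_attentions[0].shape)
```

The seq_len × seq_len shape is what worries me: attentions saved from one article will only line up with another article of the same token length.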

# Repeat model run with new article and insert covid_encoder_attention and/or covid_decoder_attention for new run
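To make the idea concrete, here is a minimal sketch in plain PyTorch of the mechanics I have in mind. This is not the transformers API; it is a single hand-rolled attention head with made-up weights (Wq, Wk, Wv and the attention function are stand-ins), just to show "save the softmax probabilities from run 1, substitute them in run 2":

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Minimal single-head attention standing in for one BART attention head
# (hypothetical size; the real model has 16 heads of dim 64 per layer).
d = 8                              # head dimension
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

def attention(x, saved_probs=None):
    """Return (output, probs); if saved_probs is given, use it instead of
    computing fresh attention probabilities (the 'insertion' step)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    if saved_probs is None:
        probs = F.softmax(q @ k.T / d ** 0.5, dim=-1)
    else:
        probs = saved_probs        # overwrite: reuse probabilities from run 1
    return probs @ v, probs

# Run 1: "Covid" article -> save the attention probabilities (steps 1-2)
x1 = torch.randn(5, d)             # 5 tokens
_, covid_probs = attention(x1)

# Run 2: new article of the same length -> insert saved probabilities (steps 3-4)
x2 = torch.randn(5, d)
out, probs = attention(x2, saved_probs=covid_probs)

# Caveat: this only works when both articles have the same token length,
# since probs has shape (tokens, tokens).
```

Doing the equivalent inside BartForConditionalGeneration would presumably mean patching or hooking each layer's self-attention so it mixes its value vectors with the saved probabilities, since the returned encoder_attentions/decoder_attentions are read-only outputs.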

I’m curious to know whether and how this is possible. I’ve found no methods in transformers that allow it.