I am looking to fine-tune a BART model to essentially do a numerical-to-text task.
When you look at the model:
BartForConditionalGeneration(
(model): BartModel(
(shared): Embedding(50265, 768, padding_idx=1)
(encoder): BartEncoder(
(embed_tokens): Embedding(50265, 768, padding_idx=1)
(embed_positions): BartLearnedPositionalEmbedding(1026, 768)
(layers): ModuleList(
(0-5): 6 x BartEncoderLayer(
(self_attn): BartAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(activation_fn): GELUActivation()
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
(layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(decoder): BartDecoder(
(embed_tokens): Embedding(50265, 768, padding_idx=1)
(embed_positions): BartLearnedPositionalEmbedding(1026, 768)
(layers): ModuleList(
(0-5): 6 x BartDecoderLayer(
(self_attn): BartAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(activation_fn): GELUActivation()
(self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(encoder_attn): BartAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
(layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
(lm_head): Linear(in_features=768, out_features=50265, bias=False)
)
it seems like I could just swap out BART's embedding layers for linear layers. So far, perhaps naively, I have tried:
import torch
import torch.nn as nn

# model is the BartForConditionalGeneration printed above
model.model.encoder.embed_tokens = nn.Linear(6, 768)
model.model.encoder.embed_positions = nn.Linear(6, 768)
model.model.decoder.embed_tokens = nn.Linear(6, 768)
model.model.decoder.embed_positions = nn.Linear(6, 768)
model.model.shared = nn.Linear(6, 768)
(where 6 is just a number I picked to test it out). And then when I stuff a tensor into it:
out = model.generate(torch.rand((1, 1, 6)))
This doesn't work. At some point in the process one of the linear layers ends up trying to multiply integers with floats, and the whole thing grinds to a halt.
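As far as I can tell, the problem is that generate() still builds integer token-ID tensors internally (for example the decoder start token) and feeds them straight into whatever module now sits in embed_tokens / embed_positions. An nn.Embedding happily looks up integer IDs, but an nn.Linear expects a float vector with 6 features, so the dtypes and shapes no longer line up. A minimal illustration of the mismatch, using made-up token IDs:

import torch
import torch.nn as nn

embedding = nn.Embedding(50265, 768, padding_idx=1)   # what BART normally has
linear = nn.Linear(6, 768)                            # what I swapped in

token_ids = torch.tensor([[0, 1234, 5678, 2]])        # (batch, seq) of arbitrary int64 IDs
print(embedding(token_ids).shape)                     # torch.Size([1, 4, 768]) -- the lookup works
linear(token_ids)                                     # RuntimeError: integer IDs are neither float nor 6-dimensional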
Is what I’m trying to do even possible? And if so, is there more to it than simply swapping out the embedding layers?
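For what it's worth, one alternative I've been considering (I'm not sure it's the right approach, hence the question) is to leave BART's own embeddings alone and instead add my own projection from the 6 numeric features up to the 768-dimensional hidden size, then pass the result in via inputs_embeds, which the forward pass accepts in place of input_ids. A rough sketch, assuming facebook/bart-base (which matches the printout above) and a projection layer of my own naming:

import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

proj = nn.Linear(6, model.config.d_model)   # my own 6 -> 768 projection, trained with the model

numeric = torch.rand(1, 10, 6)              # batch of 1, sequence of 10 "numeric tokens", 6 features each
encoder_embeds = proj(numeric)              # (1, 10, 768), the shape embed_tokens would normally produce

# the decoder side still works on ordinary text tokens
labels = tokenizer("the target text", return_tensors="pt").input_ids

out = model(inputs_embeds=encoder_embeds, labels=labels)   # encoder skips embed_tokens entirely
out.loss.backward()                                        # trains BART and the projection jointly

I haven't checked whether generate() accepts inputs_embeds in the transformers version I'm on, so this sketch only covers the training forward pass; whether this is saner than replacing the embedding layers outright is really part of what I'm asking.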