I am working on splitting the T5 model into separate encoder and decoder parts. Here is my progress so far, but I am not sure whether I am doing it correctly. Can anyone help me out with this?
I have included three files: one for testing my encoder and decoder, and the other two containing the encoder and the decoder themselves.
import os
import torch
from transformers import T5Tokenizer, T5Config
from nlg.src.goal_bot.goal_bot_decoder import CustomDecoder
from nlg.src.goal_bot.goal_bot_encoder import CustomEncoder
def _load_pretrained_weights(module, checkpoint_path):
    """Load the checkpoint stored at ``checkpoint_path`` into ``module``.

    NOTE(review): the on-disk format of the ``.pt`` files is not visible from
    this script. This helper assumes each file holds either a state dict or a
    pickled ``torch.nn.Module`` for the matching sub-module -- confirm against
    the code that saved them. (On recent PyTorch versions a pickled module may
    additionally require ``weights_only=False`` in ``torch.load``.)
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    # strict=False so that a key mismatch is reported below instead of
    # aborting the whole run; the printed report makes the mismatch visible.
    result = module.load_state_dict(checkpoint, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Warning while loading {checkpoint_path}: "
            f"missing keys={result.missing_keys}, "
            f"unexpected keys={result.unexpected_keys}"
        )


if __name__ == '__main__':
    tokenizer_path = 't5-base'
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
    t5_config = T5Config.from_pretrained(tokenizer_path)

    # Define paths to the model components
    base_dir = os.path.abspath('../../../path_to_saved_model/t5_base')
    encoder_path = os.path.join(base_dir, 't5_base_encoder.pt')
    decoder_path = os.path.join(base_dir, 't5_base_decoder.pt')
    lm_head_path = os.path.join(base_dir, 't5_base_lm_head.pt')
    for label, path in (
        ("Encoder", encoder_path),
        ("Decoder", decoder_path),
        ("LM Head", lm_head_path),
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    # Create instances of the encoder and decoder
    customEncoder = CustomEncoder(t5_config)
    customDecoder = CustomDecoder(t5_config)

    # The constructors only build randomly initialised modules; without this
    # step the checkpoint files checked above would never actually be used.
    _load_pretrained_weights(customEncoder.encoder, encoder_path)
    _load_pretrained_weights(customDecoder.decoder, decoder_path)
    _load_pretrained_weights(customDecoder.lm_head, lm_head_path)

    # Inference only: disable dropout.
    customEncoder.eval()
    customDecoder.eval()

    # Tokenize input text
    # NOTE(review): the public T5 checkpoints were trained with the prefix
    # "translate English to German: " -- consider using that exact wording.
    input_text = "translate to german: what is your name"
    tokenized_input = tokenizer(input_text, return_tensors="pt")

    with torch.no_grad():
        # Forward pass through the encoder
        encoder_output = customEncoder(
            input_ids=tokenized_input.input_ids,
            attention_mask=tokenized_input.attention_mask,
        )
        # Print the keys of the encoder output to inspect its structure
        print("Encoder output keys:", encoder_output.keys())

        # Initialize decoder input ids with the configured start token,
        # falling back to the pad token when the config does not define one.
        start_token_id = t5_config.decoder_start_token_id
        if start_token_id is None:
            start_token_id = tokenizer.pad_token_id
        decoder_input_ids = torch.tensor([[start_token_id]])

        # Greedy decoding, one token per iteration
        for _ in range(50):  # Limit the length of the generated text
            decoder_output = customDecoder(
                decoder_input_ids=decoder_input_ids,
                # The decoder mask must match decoder_input_ids, not the
                # encoder input; None lets the stack attend to every
                # generated position.
                decoder_attention_mask=None,
                decoder_inputs_embeds=None,
                past_key_values=None,
                hidden_states=encoder_output["hidden_states"],
                # Encoder padding mask, used for cross-attention.
                attention_mask=encoder_output["attention_mask"],
                decoder_head_mask=None,
                cross_attn_head_mask=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                labels=None,
                # Pass the encoder's model output itself, not the whole
                # dictionary wrapped in a tuple.
                encoder_outputs=encoder_output["encoder_outputs"],
            )
            # Get the token id of the most likely next token, shape (batch, 1)
            next_token_id = torch.argmax(
                decoder_output.logits[:, -1, :], dim=-1, keepdim=True
            )
            # Append the token id to the decoder input ids
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            # Stop if the end of sequence token is generated
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    # Decode the generated tokens into text
    decoded_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
    print("Decoded text:", decoded_text)
import torch
from torch.nn import CrossEntropyLoss
from transformers import T5Tokenizer, T5Config
import copy
from torch import nn
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from typing import Optional, Tuple, Union
from transformers import T5PreTrainedModel
class CustomDecoder(T5PreTrainedModel):
    """Decoder half of a T5 model: a decoder ``T5Stack`` plus the LM head.

    The encoder is not part of this module. Its output must be supplied to
    ``forward`` through ``hidden_states`` (or ``encoder_outputs``) together
    with the encoder padding mask in ``attention_mask``.
    """

    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Work on a private copy so the caller's config is not mutated, and
        # set is_decoder BEFORE building the stack: the stack is constructed
        # from this config, so flipping the flag afterwards would leave the
        # blocks without cross-attention layers.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.use_cache = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.is_decoder = True

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the decoder blocks over the available CUDA devices.

        Args:
            device_map: optional explicit mapping; when ``None`` a map over
                all visible CUDA devices is generated.
        """
        self.device_map = (
            get_device_map(len(self.decoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.decoder.block))
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Move the decoder and LM head back to the CPU and free CUDA cache."""
        self.decoder.deparallelize()
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        hidden_states=None,
    ):
        """Run the decoder stack and LM head on top of encoder states.

        Args:
            input_ids, head_mask, inputs_embeds: accepted for signature
                compatibility with the full T5 model; not used here.
            attention_mask: encoder padding mask, used for cross-attention.
            decoder_input_ids / decoder_inputs_embeds: decoder inputs. When
                both are ``None`` and ``labels`` is given, the inputs are
                derived by shifting ``labels`` to the right.
            decoder_attention_mask: mask over the decoder inputs.
            encoder_outputs: encoder result (model output object or tuple).
                Used as the source of ``hidden_states`` when that argument is
                ``None`` and echoed back in the returned output.
            hidden_states: last encoder hidden state to cross-attend over.
            labels: optional targets; positions equal to ``-100`` are ignored
                by the loss.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *decoder_extras,
            *encoder_outputs)``.
        """
        # Resolve the flag here so the branch at the end agrees with what the
        # decoder stack actually returns.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fall back to the encoder outputs when the hidden states were not
        # passed explicitly.
        if hidden_states is None and encoder_outputs is not None:
            hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            first_device = self.decoder.first_device
            torch.cuda.set_device(first_device)
            if hidden_states is not None:
                hidden_states = hidden_states.to(first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            self.lm_head = self.lm_head.to(self.decoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab, as the reference T5
            # implementation does when input and output embeddings are tied.
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        # Normalise the encoder outputs so that None, plain tuples and model
        # output objects are all handled without attribute errors.
        if encoder_outputs is None:
            encoder_tuple = ()
            encoder_last = hidden_states
            encoder_all_states = None
            encoder_attentions = None
        elif isinstance(encoder_outputs, (tuple, list)):
            encoder_tuple = tuple(encoder_outputs)
            encoder_last = encoder_tuple[0] if len(encoder_tuple) > 0 else None
            encoder_all_states = encoder_tuple[1] if len(encoder_tuple) > 1 else None
            encoder_attentions = encoder_tuple[2] if len(encoder_tuple) > 2 else None
        else:
            encoder_tuple = encoder_outputs.to_tuple()
            encoder_last = encoder_outputs.last_hidden_state
            encoder_all_states = getattr(encoder_outputs, "hidden_states", None)
            encoder_attentions = getattr(encoder_outputs, "attentions", None)

        if not return_dict:
            output = (lm_logits,) + tuple(decoder_outputs[1:]) + encoder_tuple
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_last,
            encoder_hidden_states=encoder_all_states,
            encoder_attentions=encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the keyword arguments for one decoding step.

        ``input_ids`` holds the decoder tokens generated so far. They are
        returned under ``decoder_input_ids`` because that is the argument
        ``forward`` feeds to the decoder stack.
        """
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return decoder input ids obtained by shifting ``labels`` right."""
        return self._shift_right(labels)

    def get_input_embeddings(self):
        """Return the shared token embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embeddings of this module and its decoder stack."""
        # This module has no encoder sub-module, so only the decoder stack is
        # updated.
        self.shared = new_embeddings
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def get_decoder(self):
        """Return the decoder stack."""
        return self.decoder
import copy
import torch
from torch import nn
from transformers import T5PreTrainedModel
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.t5.tokenization_t5 import T5Tokenizer
from transformers.utils import logging
from typing import Optional, Tuple, Union
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
# Module-level logger obtained from the transformers logging utilities,
# named after this module.
logger = logging.get_logger(__name__)
class CustomEncoder(T5PreTrainedModel):
    """Encoder half of a T5 model: the shared embedding plus an encoder stack.

    ``forward`` runs the encoder and returns a plain dictionary holding the
    encoder result together with the decoder-side arguments it was given, so
    the dictionary can be handed on to a separate decoder module.
    """

    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model_dim = config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Private copy so the caller's config keeps its original flags.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the encoder blocks over the available CUDA devices.

        Args:
            device_map: optional explicit mapping; when ``None`` a map over
                all visible CUDA devices is generated.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    def deparallelize(self):
        """Move the encoder back to the CPU and free the CUDA cache."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the shared token embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embeddings of this module and its encoder stack."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        """Tie the encoder's token embeddings to the shared embedding."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> dict:
        """Encode the input and package everything a separate decoder needs.

        The encoder stack is run unless ``encoder_outputs`` is already
        supplied. Decoder-side arguments are passed through untouched, except
        that ``decoder_input_ids`` is derived from ``labels`` (shifted right)
        when neither decoder ids nor decoder embeddings are given, and
        ``decoder_head_mask`` falls back to ``head_mask`` when the encoder and
        decoder have the same number of layers.

        Returns:
            A ``dict`` with the keys ``decoder_input_ids``,
            ``decoder_attention_mask``, ``decoder_inputs_embeds``,
            ``past_key_values``, ``hidden_states`` (first element of the
            encoder outputs), ``attention_mask``, ``decoder_head_mask``,
            ``cross_attn_head_mask``, ``use_cache``, ``output_attentions``,
            ``output_hidden_states``, ``return_dict``, ``labels`` and
            ``encoder_outputs``.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # This module has no decoder sub-module, so there is no decoder device
        # to move tensors to here. In model-parallel mode the returned tensors
        # stay where they are; the consumer must place them on its own device.

        return dict(
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            encoder_outputs=encoder_outputs,
        )
Thanks in advance.