[Help appreciated] Modifying load_tf_weights_in_albert to convert an ALBERT TensorFlow checkpoint to a PyTorch model

Hello everyone,
I hope you're feeling great today!

I'm opening this conversation following the issue reported here.

To summarize:

  • Using the official google-research ALBERT GitHub repo to pretrain an ALBERT model from scratch, I obtained a TensorFlow checkpoint.
  • I tried transformers-cli convert, convert_albert_original_tf_checkpoint_to_pytorch.py and AlbertForPreTraining.from_pretrained("tf_checkpoint_folder", from_tf=True) to turn the TF checkpoint into a PyTorch model, but an error always occurs in the same function: load_tf_weights_in_albert.
  • In this function, a new AlbertForPreTraining PyTorch model is instantiated and its tensors are progressively filled with the corresponding TensorFlow variables. A pointer is moved across the PyTorch model according to the name of the TensorFlow variable currently being read. The error seems to occur because, when a TensorFlow variable name contains gamma or beta, the script tries to move the pointer to a non-existent part of the PyTorch model with getattr (see the minimal sketch right after this list).
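
To make the gamma/beta problem concrete, here is a minimal, hypothetical illustration (not the actual HuggingFace code) of why a getattr walk over such a TF variable name fails on a PyTorch module: PyTorch LayerNorm parameters are called weight and bias, not gamma and beta.

import torch.nn as nn

layer_norm = nn.LayerNorm(8)
print(hasattr(layer_norm, "weight"), hasattr(layer_norm, "bias"))  # True True
print(hasattr(layer_norm, "gamma"), hasattr(layer_norm, "beta"))   # False False

# A naive pointer walk over a TF-style name therefore breaks on the last step
pointer = layer_norm  # pretend the walk already reached the LayerNorm module
try:
    pointer = getattr(pointer, "gamma")
except AttributeError as e:
    print("pointer walk failed:", e)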

With the previous context in mind, I tried to modify load_tf_weights_in_albert so that it retrieves the correct objects of the PyTorch model (if you also have a similar google-research ALBERT checkpoint, you can try it with the options --tf_checkpoint_path, --albert_config_file and --pytorch_dump_path).
Here's my script:

import os
import re
import argparse
import logging

import torch
import tensorflow as tf

from transformers import AlbertConfig, AlbertForPreTraining

logger = logging.getLogger(__name__)

# A twisted version of HuggingFace's convert_albert_original_tf_checkpoint_to_pytorch.py file
# and load_tf_weights_in_albert function from modeling_albert.py

def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            print(f"Skipping {'/'.join(name)} (optimizer)")
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            layer_norm_names = ["LayerNorm", "full_layer_layer_norm"] # twisted version: names of the LayerNorm attributes used across the AlbertModel architecture
            if scope_names[0] == "kernel":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "gamma": # twisted version : Retrieve gamma (weight) of LayerNorm object of current object
                for layer_norm_name in layer_norm_names:
                    if hasattr(pointer, layer_norm_name):
                        pointer = getattr(getattr(pointer, layer_norm_name), "weight")
                        break
                if hasattr(pointer, "albert_layers"): # twisted version : go retrieve LayerNorm object inside a AlbertLayer object inside a ModuleList
                    sub_pointer = getattr(pointer, "albert_layers")
                    for sub_module in sub_pointer:
                        pointer = getattr(getattr(sub_module, "full_layer_layer_norm"), "weight")
                        break
            elif scope_names[0] == "output_bias":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "beta": # twisted version : Retrieve beta (bias) of LayerNorm object of current object
                for layer_norm_name in layer_norm_names:
                    if hasattr(pointer, layer_norm_name):
                        pointer = getattr(getattr(pointer, layer_norm_name), "bias")
                        break
                if hasattr(pointer, "albert_layers"): # twisted version : go retrieve LayerNorm object inside a AlbertLayer object inside a ModuleList
                    sub_pointer = getattr(pointer, "albert_layers")
                    for sub_module in sub_pointer:
                        pointer = getattr(getattr(sub_module, "full_layer_layer_norm"), "bias")
                        break
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    print(f"Skipping {'/'.join(name)} on {scope_names[0]}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif scope_names[0] == "kernel":
            array = np.transpose(array)
        # Sanity-check shapes before copying the TF array into the PyTorch tensor
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        print(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ALBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
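
Since the script dumps a raw state_dict with torch.save, here is a minimal sketch of how such a file can be loaded back into an AlbertForPreTraining for fine-tuning (the paths are placeholders, not my actual files):

import torch
from transformers import AlbertConfig, AlbertForPreTraining

config = AlbertConfig.from_json_file("/path/to/albert_config.json")        # placeholder path
model = AlbertForPreTraining(config)
state_dict = torch.load("/path/to/pytorch_model.bin", map_location="cpu")  # placeholder path
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)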

With this script, I seem to correctly copy my TF variables into the corresponding PyTorch tensors.
The exceptions are:

  • variables corresponding to the optimizer (no need to copy them)
  • the PyTorch tensors albert.encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.weight and albert.encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.bias, because my TF checkpoint doesn't have equivalent variables
    Those exceptions don't seem to be the problem: even when nothing is copied into the newly instantiated model, it can still learn when fine-tuned on a simple task (a quick way to confirm which tensors stayed untouched is sketched right below).
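
To double-check which parameters are never filled by the conversion, one simple approach (my own sketch, reusing load_tf_weights_in_albert from the script above; paths are placeholders) is to snapshot the randomly initialized weights before conversion and compare them afterwards:

import copy
import torch
from transformers import AlbertConfig, AlbertForPreTraining

config = AlbertConfig.from_json_file("/path/to/albert_config.json")                 # placeholder path
model = AlbertForPreTraining(config)
before = copy.deepcopy(model.state_dict())

load_tf_weights_in_albert(model, config, "/path/to/tf_checkpoint/model.ckpt-best")  # placeholder path

# Any tensor still equal to its random initialization was (almost certainly) never overwritten
untouched = [k for k, v in model.state_dict().items() if torch.equal(v, before[k])]
print("Parameters never filled from the TF checkpoint:", untouched)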

For some reason, though, the PyTorch model obtained after copying the variables doesn't seem to be "right": when fine-tuned on a simple task, it doesn't seem to be able to learn (the loss starts high and stays high).
I suspect some further modifications may be needed, such as the transposition applied to kernel variables (array = np.transpose(array)) near the end of load_tf_weights_in_albert.
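
One sanity check that might help locate the problem (my own sketch; the TF variable name "bert/pooler/dense/kernel" and the PyTorch key "albert.pooler.weight" are assumptions based on the renaming rules above, so adjust them to the names your script actually prints) is to compare a converted dense kernel numerically against the raw TF variable, with and without transposition:

import numpy as np
import tensorflow as tf
import torch

tf_path = "/path/to/tf_checkpoint/model.ckpt-best"                          # placeholder path
state_dict = torch.load("/path/to/pytorch_model.bin", map_location="cpu")   # placeholder path

tf_kernel = tf.train.load_variable(tf_path, "bert/pooler/dense/kernel")     # assumed TF name
pt_weight = state_dict["albert.pooler.weight"].numpy()                      # assumed PyTorch key

print("TF kernel shape:", tf_kernel.shape, "| PyTorch weight shape:", pt_weight.shape)
print("equal as-is:          ", np.allclose(tf_kernel, pt_weight))
print("equal after transpose:", np.allclose(tf_kernel.T, pt_weight))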

If you have some experience with these manipulations, or any idea about what is going wrong, your help would be appreciated.

Thanks for your attention and any ideas you can share :slight_smile:
