Cannot instantiate a transformers model inside dopamine's create_agent

Steps to reproduce the behavior:

  1. Clone the kovkev/dopamine repository from GitHub.
  2. Set up the dopamine dependencies.
  3. Run python3.9 mytest.py

To make the situation easier to reproduce, I created a Colab notebook: dopamine/mynotebook.ipynb on the master branch of kovkev/dopamine. Note that my instantiation of the transformer is at dopamine/dopamine/discrete_domains/run_experiment.py, lines 115-117 (inside create_agent).

If I instantiate the model elsewhere in the script, it works fine (see the "done0" line in the output below). However, if I instantiate it at that location in run_experiment.py, I get the following error:

/usr/lib/python3.9/site-packages/ale_py/roms/__init__.py:94: DeprecationWarning: Automatic importing of atari-py roms won't be supported in future releases of ale-py. Please migrate over to using `ale-import-roms` OR an ALE-supported ROM package. To make this warning disappear you can run `ale-import-roms --import-from-pkg atari_py.atari_roms`.For more information see: https://github.com/mgbellemare/Arcade-Learning-Environment#rom-management
  _RESOLVED_ROMS = _resolve_roms()
2021-12-27 04:44:31.214754: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
>>>done0
INFO:absl:Creating TrainRunner ...
WARNING:tensorflow:From /usr/lib/python3.9/site-packages/tensorflow/python/compat/v2_compat.py:111: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:tensorflow:From /usr/lib/python3.9/site-packages/tensorflow/python/compat/v2_compat.py:111: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
A.L.E: Arcade Learning Environment (version +978d2ce)
[Powered by Stella]
Traceback (most recent call last):
  File "/home/project/dopamine/mytest.py", line 55, in <module>
    dqn_runner = run_experiment.create_runner(DQN_PATH, schedule='continuous_train')
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/project/dopamine/dopamine/discrete_domains/run_experiment.py", line 145, in create_runner
    return TrainRunner(base_dir, create_agent)
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/project/dopamine/dopamine/discrete_domains/run_experiment.py", line 562, in __init__
    super(TrainRunner, self).__init__(base_dir, create_agent_fn,
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/project/dopamine/dopamine/discrete_domains/run_experiment.py", line 230, in __init__
    self._agent = create_agent_fn(self._sess, self._environment,
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/project/dopamine/dopamine/discrete_domains/run_experiment.py", line 117, in create_agent
    another_model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
  File "/usr/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 441, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/usr/lib/python3.9/site-packages/transformers/modeling_tf_utils.py", line 1595, in from_pretrained
    model(model.dummy_inputs)  # build the network with dummy inputs
  File "/usr/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 765, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/usr/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 699, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    File "/usr/lib/python3.9/site-packages/transformers/models/t5/modeling_tf_t5.py", line 1422, in call  *
        inputs["encoder_outputs"] = self.encoder(
    File "/usr/lib/python3.9/site-packages/transformers/models/t5/modeling_tf_t5.py", line 688, in call  *
        inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
    File "/usr/lib/python3.9/site-packages/transformers/modeling_tf_utils.py", line 2003, in __call__  *
        return self._layer(inputs, mode)
    File "/usr/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 745, in __call__  **
        self._maybe_build(inputs)
    File "/usr/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 2074, in _maybe_build
        self.build(input_shapes)
    File "/usr/lib/python3.9/site-packages/transformers/modeling_tf_utils.py", line 1760, in build
        self.weight = self.add_weight(
    File "/usr/lib/python3.9/site-packages/keras/engine/base_layer_v1.py", line 423, in add_weight
        variable = self._add_variable_with_custom_getter(
    File "/usr/lib/python3.9/site-packages/keras/engine/base_layer_utils.py", line 117, in make_variable
        return tf.compat.v1.Variable(
    File "/usr/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 416, in __call__
        dtype = _assert_float_dtype(_get_dtype(dtype))
    File "/usr/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 969, in _assert_float_dtype
        raise ValueError(f'Expected floating point type, got {dtype}.')

    ValueError: Expected floating point type, got <dtype: 'int32'>.

  In call to configurable 'create_agent' (<function create_agent at 0x7fbe9c160430>)
  In call to configurable 'Runner' (<class 'dopamine.discrete_domains.run_experiment.Runner'>)
  In call to configurable 'TrainRunner' (<class 'dopamine.discrete_domains.run_experiment.TrainRunner'>)
  In call to configurable 'create_runner' (<function create_runner at 0x7fbe9c160af0>)
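
To narrow this down outside of dopamine, a small standalone check along the following lines might help. It is only a sketch under an assumption I have not verified: that the difference between the working and failing locations is that TF2 behavior has already been disabled by the time create_agent runs (the v2_compat.py warnings in the log above hint at tf.compat.v1.disable_v2_behavior() being called somewhere). The helper names missing_packages and load_t5 are hypothetical and are not part of dopamine or mytest.py.

```python
# Hypothetical standalone check; not part of dopamine or mytest.py.
import importlib.util


def missing_packages(names=("tensorflow", "transformers")):
    """Return the top-level packages from `names` that are not installed."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def load_t5(disable_v2_first):
    """Try to load t5-small, optionally after disabling TF2 behavior.

    Returns None if the model loads, or the error message if loading
    raises a ValueError like the one in the traceback above.
    """
    # Imported lazily so this file can be imported without TensorFlow.
    import tensorflow as tf
    from transformers import TFAutoModelForSeq2SeqLM

    if disable_v2_first:
        # Assumption: this mirrors the state dopamine is in when
        # create_agent is called. It cannot be cleanly undone, so run
        # each variant in a fresh Python process.
        tf.compat.v1.disable_v2_behavior()
    try:
        TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
    except ValueError as err:
        return str(err)
    return None
```

Calling load_t5(False) in one fresh interpreter and load_t5(True) in another should show whether the "Expected floating point type" error depends on the v1 compatibility mode alone, independent of dopamine and gin.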