Will the Transformer-XL model learn properly if tokens laid out in one dimension are reshaped to (batch, seq_len) and fed into it?

I wrote my code by referring to GitHub - kimiyoung/transformer-xl.

I have a question, so I'm posting it here.

I'd appreciate it if you could answer when you have time.

Prior knowledge: the text data is tokenized, sentences are separated only by an 'eos' token, everything is laid out as a one-dimensional NumPy array, and it is reshaped to (batch_size, length) when it goes into the model.
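Roughly, the data preparation I mean looks like this (a simplified sketch; the token IDs, eos_id, batch_size, and seq_len are only placeholder values):

```python
import numpy as np

eos_id = 0
sentences = [[5, 8, 2], [7, 9, 4, 3], [6, 1]]   # already-tokenized sentences (example values)

# Flatten everything into one 1-D array, separating sentences only with 'eos'.
stream = np.array([tok for sent in sentences for tok in sent + [eos_id]], dtype=np.int64)

# Reshape to (batch_size, seq_len) before feeding the model.
batch_size, seq_len = 2, 4
stream = stream[: batch_size * seq_len]
batch = stream.reshape(batch_size, seq_len)
print(batch.shape)   # (2, 4)
```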

mems = tf.cast(mems, dtype=w.dtype)  # shape (400, length, embedding dimension), initially all zeros
w = the input data mentioned above
cat = tf.concat([mems, w], axis=0)

w_heads = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False)(cat)
w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
w_head_q = w_head_q[-qlen:]
These values are produced at every layer of the Transformer encoder stack, one set per encoder layer (a fuller sketch of what I mean is below).
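To make that step concrete, here is a self-contained sketch of the attention input I described, loosely following the time-major layout in kimiyoung/transformer-xl. The sizes (mem_len, qlen, bsz, d_model, n_head, d_head) are placeholder values, not my real configuration:

```python
import tensorflow as tf

# Placeholder sizes, not my real configuration.
mem_len, qlen, bsz = 400, 36, 3
d_model, n_head, d_head = 128, 4, 32

w = tf.random.normal([qlen, bsz, d_model])     # current segment, time-major
mems = tf.zeros([mem_len, bsz, d_model])       # memory, initially all zeros
mems = tf.cast(mems, dtype=w.dtype)

# Prepend the memory to the current segment along the time axis.
cat = tf.concat([mems, w], axis=0)             # (mem_len + qlen, bsz, d_model)

# One projection produces Q, K, V together, then they are split apart.
qkv_proj = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False)
w_heads = qkv_proj(cat)                        # (mem_len + qlen, bsz, 3 * n_head * d_head)

w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
w_head_q = w_head_q[-qlen:]                    # queries only cover the current segment
print(w_head_q.shape, w_head_k.shape, w_head_v.shape)
```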

At the end of each step, the per-layer outputs are concatenated onto the existing mems along axis 0.
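By that I mean a memory update along these lines (a sketch with stop_gradient, as in the reference code; mem_len and the shapes are placeholders):

```python
import tensorflow as tf

mem_len = 400

def update_mems(prev_mems, layer_output):
    """Append the current layer output to the memory along the time axis,
    keep only the last mem_len steps, and block gradients into the memory."""
    new_mem = tf.concat([prev_mems, layer_output], axis=0)
    return tf.stop_gradient(new_mem[-mem_len:])

# Example: one layer's memory, updated once per training segment.
qlen, bsz, d_model = 36, 3, 128
prev_mems = tf.zeros([mem_len, bsz, d_model])
layer_output = tf.random.normal([qlen, bsz, d_model])
new_mems = update_mems(prev_mems, layer_output)
print(new_mems.shape)   # (400, 3, 128)
```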

Can the Transformer-XL model learn correctly if these mems are fed back in at the next step?

Because I feed in long sequences directly, a single sentence ends up spread across different batch rows when the data is reshaped (for example, a sequence of length 108 becomes (3, 36), as in the small example below).
I wonder whether the Transformer-XL model will still be effective when using mems the way I described above.
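To be concrete about the (3, 36) example, this reshape is what I mean by one sentence ending up in different batch rows (the numbers are only illustrative):

```python
import numpy as np

# A single 108-token sequence reshaped to (batch=3, seq_len=36).
tokens = np.arange(108)
batch = tokens.reshape(3, 36)

# Tokens 0..35 go to row 0, 36..71 to row 1, 72..107 to row 2,
# so consecutive pieces of the same sentence sit in different batch rows.
print(batch[:, 0])   # [ 0 36 72]
```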

My code is in this colab.