Is There a Way to Improve Memory Usage When Using Identical `past_key_values` for All Samples in a Batch?

I think you can use broadcasting: `torch.Tensor.expand` returns a view that repeats a tensor along its size-1 dimensions, so the cache is effectively duplicated across the batch without increasing its memory footprint:

def duplicate_pkv(pkv, num_repeats):
    # expand() returns a view over the size-1 batch dimension, so no extra memory is allocated
    return tuple(tuple(tensor.expand(num_repeats, -1, -1, -1) for tensor in layer) for layer in pkv)

where -1 tells `expand` to leave those dimensions unchanged; only the batch dimension is broadcast, and `expand` requires that dimension to be of size 1 in the cached `past_key_values`.
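
For illustration, here is a minimal sketch of what the duplication looks like end to end. The cache shapes (2 layers, 12 heads, 10 cached tokens, head dim 64) are made up, and it assumes the legacy tuple-of-tuples `past_key_values` layout of `(batch, num_heads, seq_len, head_dim)` tensors rather than the newer `Cache` classes:

import torch

# Hypothetical single-sample cache: 2 layers, each a (key, value) pair
# shaped (1, num_heads, seq_len, head_dim).
pkv = tuple(
    (torch.randn(1, 12, 10, 64), torch.randn(1, 12, 10, 64))
    for _ in range(2)
)

batched_pkv = duplicate_pkv(pkv, num_repeats=8)

print(batched_pkv[0][0].shape)                               # torch.Size([8, 12, 10, 64])
# The expanded tensors are views over the original storage:
print(batched_pkv[0][0].data_ptr() == pkv[0][0].data_ptr())  # True

One caveat: in-place writes to an expanded view can give incorrect results, since all batch entries share the same memory. This approach works when the model extends the cache by concatenation (which allocates new tensors) rather than writing into it in place; if you need independently modifiable per-sample copies, use `repeat` instead, at the cost of actually materializing the memory.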