I’m trying out a method to identify important training samples for a given test-time prediction. It essentially boils down to computing the gradient of the loss for a test-time prediction and ordering the training samples by the cosine similarity of their gradients to that test-time gradient. My interpretation is that it tries to answer the question: which training samples have nudged/influenced the model’s parameters most similarly to how the test sample would have, had it been a training sample? This isn’t all that important for the question, but I hope it makes sense.
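Concretely, the ranking step I have in mind looks something like this (a minimal sketch in NumPy; `rank_training_samples` and the flat gradient vectors are my own naming, not from any library):

```python
import numpy as np

def rank_training_samples(test_grad, train_grads):
    """Rank training samples by the cosine similarity of their loss
    gradient to the test-time gradient, most similar first.

    test_grad: 1-D array (flattened gradient of the test prediction's loss)
    train_grads: list of 1-D arrays, one per training sample
    """
    def cos(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    scores = [cos(test_grad, g) for g in train_grads]
    # Indices of training samples, ordered by gradient similarity.
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```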
The model I’m using is T5, and here’s where I run into trouble: very similar (input, target) pairs produce vastly different gradients in terms of cosine similarity.
Let me provide an example starting with a sanity check on a dummy example which should be easily reproducible (helper functions are found below):
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, T5TokenizerFast

MODEL_PATH = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
tokenizer = T5TokenizerFast.from_pretrained(MODEL_PATH)

sentence_1 = get_grads(model,
                       tokenizer,
                       inputs="I like to eat <extra_id_0>",
                       targets="<extra_id_0> pizza")
sentence_2 = get_grads(model,
                       tokenizer,
                       inputs="I like to eat <extra_id_0>",
                       targets="<extra_id_0> pizza")
cos_sim(sentence_1, sentence_2)
>>> 1.0
which is totally expected, as the same sample would affect the model’s parameters in exactly the same way. Now, changing sentence_2’s target slightly to "<extra_id_0> pizza.", i.e. with a period at the end, I get a cosine similarity of 0.46.
What I don’t quite understand is how the introduction of a seemingly insignificant token can change the gradients that much.
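For what it’s worth, my current intuition: the seq2seq loss is averaged over target tokens, so the pooled gradient is a mean of per-token gradients, and appending one token reweights every term of that mean. A toy NumPy sketch (random high-dimensional vectors standing in for per-token gradients, which in a real model can be far more correlated or anti-correlated):

```python
import numpy as np

rng = np.random.default_rng(0)
# Three fake "per-token" gradients; random high-dimensional vectors
# are nearly orthogonal to one another.
per_token = rng.standard_normal((3, 10_000))

grad_2_tokens = per_token[:2].mean(axis=0)  # loss averaged over 2 tokens
grad_3_tokens = per_token.mean(axis=0)      # same target plus one extra token

cos = grad_2_tokens @ grad_3_tokens / (
    np.linalg.norm(grad_2_tokens) * np.linalg.norm(grad_3_tokens)
)
print(cos)  # roughly 0.82 for near-orthogonal per-token gradients
```

So even in this idealized case the cosine drops well below 1.0, though that alone doesn’t explain a value as low as 0.46.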
Any help, hints and guidance in understanding this is greatly appreciated!
My helper functions:
def get_grads(model, tokenizer, inputs, targets):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    batch = tokenizer(text=inputs,
                      text_target=targets,
                      truncation=True,
                      return_tensors="pt")
    outputs = model(**{k: v.to(device) for k, v in batch.items()})
    grads = torch.autograd.grad(outputs.loss, model.parameters())
    # Flatten all parameter gradients into one vector and convert to
    # NumPy so it can be fed to cos_sim below.
    return torch.cat([grad.flatten() for grad in grads]).detach().cpu().numpy()

def cos_sim(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))