Hello!
I have a custom model that I train and would also like to test within the HF environment. However, even though I pass a custom compute_metrics function to my Trainer, it never seems to call it, not even once.
This is my code:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from transformers import Trainer, TrainingArguments

# `features`, `embeddingModel`, `test_dataset`, `baseline_collator` and
# BATCH_SIZE are defined elsewhere in my script (not shown).

def plot_covariance_matrix(model_output, config):
    print("Hello World!")
    # Calculate correlation matrices (np.corrcoef normalises the covariance;
    # rowvar=True treats each row as one variable)
    cov_matrix_og = np.corrcoef(model_output.target, rowvar=True)
    cov_matrix_reconst = np.corrcoef(model_output.output, rowvar=True)
    # Plot the matrices of the original and reconstructed data side by side
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    sns.heatmap(cov_matrix_og, annot=True, cmap="vlag", ax=axes[0], xticklabels=features, yticklabels=features)
    axes[0].set_title('Covariance of the original data.')
    sns.heatmap(cov_matrix_reconst, annot=True, cmap="vlag", ax=axes[1], xticklabels=features, yticklabels=features)
    axes[1].set_title('Covariance of the reconstructed data.')
    fig.suptitle('Covariance matrices.')
    plt.tight_layout()
    print('Before saving figure path:', os.path.abspath('.'))
    fig.savefig(os.path.join(config['savepath'], config['result_filename']), format='png')

def compute_metrics_baseline(model_output):
    # Note: Trainer passes an EvalPrediction here (fields .predictions and .label_ids)
    plot_covariance_matrix(model_output, config={
        "result_filename": "baseline_result.png",
        "savepath": os.path.dirname(os.path.abspath(__file__))
    })
    return {}

testing_args_baseline = TrainingArguments(output_dir="embeddingmodel_test_checkpoints", logging_dir="./baseline_log",
                                          remove_unused_columns=False, evaluation_strategy="epoch",
                                          per_device_eval_batch_size=BATCH_SIZE)
baseline_tester = Trainer(
    model=embeddingModel,
    args=testing_args_baseline,
    eval_dataset=test_dataset,
    data_collator=baseline_collator,
    compute_metrics=compute_metrics_baseline
)
print("Testing baseline model.")
baseline_tester.evaluate()
As you can see, there are already some print statements in there, because at first I thought my figure was being saved to a different location than the one I wanted...
But not even the first print statement is reached: nothing shows up in the output at all.
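Separately from the Trainer problem, a quick standalone check of the matrix shapes in the plotting helper. It assumes (a guess, not stated above) that `target` and `output` are arrays of shape (n_samples, n_features). With that layout, `rowvar=True` yields a sample-by-sample matrix, while the feature-by-feature matrix that matches the `features` tick labels needs `rowvar=False`. It also shows that `np.corrcoef` returns correlations rather than covariances; the data here is made up for illustration:

```python
import numpy as np

# Hypothetical data: 5 samples (rows) x 3 features (columns)
rng = np.random.default_rng(0)
target = rng.normal(size=(5, 3))

# rowvar=True: each ROW is a variable -> 5 x 5 matrix (sample vs. sample)
per_sample = np.corrcoef(target, rowvar=True)

# rowvar=False: each COLUMN is a variable -> 3 x 3 matrix (feature vs. feature)
per_feature_corr = np.corrcoef(target, rowvar=False)
per_feature_cov = np.cov(target, rowvar=False)

# corrcoef is the covariance divided by the outer product of standard deviations,
# so its diagonal is all ones while the diagonal of cov holds the variances
std = np.sqrt(np.diag(per_feature_cov))
assert np.allclose(per_feature_corr, per_feature_cov / np.outer(std, std))

print(per_sample.shape, per_feature_corr.shape)  # (5, 5) (3, 3)
```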
I checked the topic "Trainer never invokes compute_metrics" and many others, but I still can't find the reason for this.
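For reference, the condition those threads point to, as far as I can tell (a simplified plain-Python reading of the Trainer source, not the library's actual code): compute_metrics is only called when the evaluation loop has collected both predictions and labels, and by default the label keys are guessed from the model's forward() signature as the parameters whose name contains "label". They can also be set explicitly with `label_names` in TrainingArguments. `CustomModel`, `LabelledModel` and the `target` key are hypothetical stand-ins:

```python
import inspect


def find_label_names(model_class):
    # Simplified default: every forward() parameter whose name contains "label"
    params = inspect.signature(model_class.forward).parameters
    return [name for name in params if "label" in name]


def would_call_compute_metrics(batch, label_names, has_predictions=True):
    # Labels count as present only if every label key is in the batch and not None
    has_labels = len(label_names) > 0 and all(batch.get(k) is not None for k in label_names)
    # Without both predictions and labels, compute_metrics is skipped
    return has_predictions and has_labels


class CustomModel:
    # Hypothetical model whose forward takes "target" instead of "labels"
    def forward(self, input_values, target=None):
        ...


class LabelledModel:
    # Hypothetical model with a conventional "labels" argument
    def forward(self, input_values, labels=None):
        ...


batch = {"input_values": [1.0, 2.0], "target": [1.0, 2.0]}

names = find_label_names(CustomModel)
print(names, would_call_compute_metrics(batch, names))  # [] False
print(find_label_names(LabelledModel))  # ['labels']

# Passing the label key explicitly (cf. TrainingArguments(label_names=["target"]))
print(would_call_compute_metrics(batch, ["target"]))  # True
```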
At this point I'm considering ditching evaluate() altogether, because hand-coding this would have taken far less time than trying to debug it, so this post is my last resort.
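If it comes to hand-coding it, the fallback could look roughly like the loop below: run the model over the test batches, collect outputs and targets, and call the same metrics/plotting function once at the end. It is deliberately framework-agnostic; `predict_fn`, the batch keys `"inputs"`/`"target"`, and the `SimpleNamespace` that mimics the `.output`/`.target` attributes my plotting helper reads are all assumptions for illustration (with PyTorch the forward pass would normally sit under `model.eval()` and `torch.no_grad()`):

```python
from types import SimpleNamespace

import numpy as np


def evaluate_by_hand(predict_fn, batches, compute_metrics):
    # Run the (hypothetical) model over every batch, collecting outputs and targets
    outputs, targets = [], []
    for batch in batches:
        outputs.append(np.asarray(predict_fn(batch["inputs"])))
        targets.append(np.asarray(batch["target"]))
    # Concatenate along the sample axis, then compute metrics exactly once
    model_output = SimpleNamespace(
        output=np.concatenate(outputs, axis=0),
        target=np.concatenate(targets, axis=0),
    )
    return compute_metrics(model_output)


# Usage with an identity "model" and two batches of 2 samples x 3 features
batches = [
    {"inputs": np.ones((2, 3)), "target": np.ones((2, 3))},
    {"inputs": np.zeros((2, 3)), "target": np.zeros((2, 3))},
]
metrics = evaluate_by_hand(
    predict_fn=lambda x: x,
    batches=batches,
    compute_metrics=lambda out: {"mse": float(np.mean((out.output - out.target) ** 2))},
)
print(metrics)  # {'mse': 0.0}
```

Because the metrics function receives everything at once, the plotting helper could be passed in directly as `compute_metrics`.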