Fine-tuning the Segment Anything Model: loading a saved model

Hi all,

I have successfully fine-tuned my SAM model by following the guidance in this blog post:

Blog Post: Fine tune Segment Anything (SAM) for images with multiple masks

I saved the weights of the trained model using the following code:

torch.save(model.state_dict(), "model_weights.pth")
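A quick way to see what that call actually writes: the minimal sketch below uses a small, hypothetical stand-in network instead of SAM (any nn.Module behaves the same way). The file holds only a mapping from parameter names to tensors, not the model class, so the architecture has to be rebuilt before the weights can be loaded back into it.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for the fine-tuned SAM; saving works the same for any nn.Module.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
torch.save(model.state_dict(), path)

# Reading the file back gives parameter names mapped to tensors; there is no class
# or architecture information in it (ReLU has no parameters, so it has no entries).
state_dict = torch.load(path, map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```

For the real model the keys are SAM's own parameter names, which is why the weights can only be loaded into a model built with the same architecture.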

However, I don't know how to load this saved model in my code. The original version of SAM is loaded like this:

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="/content/drive/MyDrive/…/SAMFineTune/sam_vit_h_4b8939.pth")
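A minimal sketch of one way to do it, assuming the fine-tuned weights were saved from a segment_anything Sam model built with "vit_h" (the same architecture as the original checkpoint). The idea is to build the architecture with checkpoint=None, so no weights file is read, and then load the saved state_dict into it. The helper names load_finetuned_weights and build_finetuned_mask_generator are hypothetical, not part of any library.

```python
import os
import tempfile

import torch
from torch import nn


def load_finetuned_weights(model: nn.Module, weights_path: str, device: str = "cpu") -> nn.Module:
    """Hypothetical helper: load a state_dict saved with torch.save(model.state_dict(), ...)
    into an already-constructed model with the same architecture."""
    state_dict = torch.load(weights_path, map_location=device)
    # strict=True (the default) raises a RuntimeError on missing or unexpected
    # keys, so a mismatch between the file and the architecture fails loudly.
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # switch off training-only behaviour such as dropout
    return model


def build_finetuned_mask_generator(weights_path: str, model_type: str = "vit_h", device: str = "cpu"):
    """Hypothetical helper; assumes the weights came from a segment_anything Sam model."""
    from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

    # checkpoint=None builds the SAM architecture without loading any weights file.
    sam = sam_model_registry[model_type](checkpoint=None)
    sam = load_finetuned_weights(sam, weights_path, device)
    return SamAutomaticMaskGenerator(sam)


if __name__ == "__main__":
    # Self-check with a tiny stand-in module instead of SAM.
    source = nn.Linear(3, 2)
    path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
    torch.save(source.state_dict(), path)
    restored = load_finetuned_weights(nn.Linear(3, 2), path)
    print(torch.equal(restored.weight, source.weight))
```

With the real model, mask_generator = build_finetuned_mask_generator("model_weights.pth") followed by masks = mask_generator.generate(image), where image is an HxWx3 uint8 RGB NumPy array, should return a list of dicts whose "segmentation" entry is a boolean mask. If the training code from the blog used the Hugging Face transformers SamModel rather than segment_anything, the parameter names will not match and load_state_dict will raise; in that case the weights would have to be loaded into a freshly constructed model of that same class and size instead, and SamAutomaticMaskGenerator would not apply.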

Can anyone tell me how to use the fine-tuned model in my code to create masks? Unfortunately, I haven't found a solution yet.