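Extract the dataset to a local folder with the following script (only the first 10 samples are exported; see `max_samples`):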
import os
import json
from PIL import Image
def save_dataset_to_folder(dataset, folder_path, max_samples=10):
    """Export examples of an InstructPix2Pix-style dataset to a local folder.

    For every exported example ``i`` the input and edited images are written to
    ``{i}_input_image.jpg`` and ``{i}_edited_image.jpg`` inside *folder_path*,
    and a ``metadata.json`` file (a JSON array of objects with the keys
    ``input_image``, ``edited_image`` and ``edit_prompt``) is written next to
    them. The image entries hold file names relative to *folder_path*.

    Args:
        dataset: Iterable of example mappings providing ``input_image`` and
            ``edited_image`` (PIL images) and an ``edit_prompt`` string.
        folder_path: Output directory; it is created, including any missing
            parent directories, if it does not exist yet.
        max_samples: Maximum number of examples to export. ``None`` exports
            every example. Defaults to 10, the previously hard-coded limit.
    """
    # exist_ok avoids the race between "check exists" and "create".
    os.makedirs(folder_path, exist_ok=True)

    metadata = []
    for i, item in enumerate(dataset):
        if max_samples is not None and i >= max_samples:
            break

        entry = {}
        for key in ("input_image", "edited_image"):
            file_name = f"{i}_{key}.jpg"
            image = item[key]
            # JPEG cannot store modes such as RGBA or P, and the trainer converts
            # every image to RGB on load anyway, so normalize before saving.
            if image.mode != "RGB":
                image = image.convert("RGB")
            image.save(os.path.join(folder_path, file_name))
            entry[key] = file_name
        entry["edit_prompt"] = item["edit_prompt"]
        metadata.append(entry)

    with open(os.path.join(folder_path, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f)
if __name__ == "__main__":
    # load_dataset is provided by the Hugging Face `datasets` package.
    from datasets import load_dataset

    dataset = load_dataset("fusing/instructpix2pix-1000-samples", split="train")
    # Write into the "train" sub-folder: the training-script changes below read
    # extracted_dataset/train/metadata.json and open the images from that folder.
    save_dataset_to_folder(dataset, "extracted_dataset/train")
In the training script, change:
# Original dataset-loading call in the training script (to be replaced):
# loads the dataset named by --dataset_name / --dataset_config_name.
dataset = load_dataset(
    args.dataset_name,
    args.dataset_config_name,
    cache_dir=args.cache_dir,
)
to:
# Load the locally extracted metadata instead of the Hub dataset. The split is
# named explicitly, and cache_dir is kept so --cache_dir is still honored.
dataset = load_dataset(
    "json",
    data_files={"train": "extracted_dataset/train/metadata.json"},
    cache_dir=args.cache_dir,
)
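Also, because the metadata stores file names rather than images, `preprocess_images` must prepend the folder path before opening each image (the same change is needed for `examples[edited_image_column]`):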
[
    # Column values are file names relative to the extracted dataset folder.
    convert_to_np(
        Image.open("extracted_dataset/train/" + file_name).convert("RGB"),
        args.resolution,
    )
    for file_name in examples[original_image_column]
]
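If you want to fine-tune from the pretrained `timbrooks/instruct-pix2pix` checkpoint, comment out the block that adds the extra input channels to the UNet (around line 514 of the training script). That checkpoint's UNet already has the extra input channels, so copying its `conv_in` weights into the first 4 channels of a new layer would fail: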
# Comment out this block when starting from timbrooks/instruct-pix2pix.
# It replaces unet.conv_in with a new Conv2d taking `in_channels` inputs and
# copies the existing conv_in weights into the first 4 input channels.
with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    # Zero all weights so input channels beyond the first 4 initially contribute nothing.
    new_conv_in.weight.zero_()
    # Reuse the existing conv_in weights for the first 4 input channels; this
    # assumes the current conv_in has exactly 4 input channels.
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    unet.conv_in = new_conv_in