Commit 5c2da613 authored by VIOLANTE Nicolas's avatar VIOLANTE Nicolas
Browse files

add basic background remover

parent 25e80580
from torchvision.models.detection import maskrcnn_resnet50_fpn
from dataset import PathDataset, EgoNetTransform
from torch.utils.data import DataLoader
import torch
import os
from PIL import Image
from torchvision.ops import box_area
import numpy as np
if __name__ == "__main__":
data="/data/nviolant/data_eg3d/lsun-cars-100k-256x256"
dest="/data/nviolant/data_eg3d/lsun-cars-100k-256x256-white-bg"
os.makedirs(dest, exist_ok=True)
batch_size=32
model = maskrcnn_resnet50_fpn(pretrained=True).eval().cuda()
dataloader = DataLoader(PathDataset(data, EgoNetTransform), batch_size)
imgs = []
imgs2 = []
bgs = []
total_images = 0
for paths, images in dataloader:
images = images.cuda()
with torch.no_grad():
detections = model(images)
for img_id, detection in enumerate(detections):
areas = box_area(detection["boxes"])
idx = 0
bg_mask = 1.0 - detection["masks"][idx].cpu().detach().numpy()[0]
bg_mask = bg_mask[:,:,None].repeat(3, -1)
img = images[img_id].cpu().numpy().transpose(1, 2, 0)
img2 = (img * (1 - bg_mask) + bg_mask)
folder = paths[img_id].split("/")[-2]
os.makedirs(os.path.join(dest, folder), exist_ok=True)
filename = os.path.join(dest, folder, os.path.basename(paths[img_id]))
img2 = Image.fromarray((255 * img2).astype(np.uint8))
img2.save(filename)
total_images += 1
print(f"Processing {total_images} images")
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment