You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

57 lines
1.8 KiB
Python

import torch
import time
import os
import cv2
import numpy as np
from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
model_pth='./models/face-parse-bisent/79999_iter.pth'):
self.net = self.model_init(resnet_path,model_pth)
self.preprocess = self.image_preprocess()
def model_init(self,resnet_path, model_pth):
net = BiSeNet(resnet_path)
if torch.cuda.is_available():
net.cuda()
net.load_state_dict(torch.load(model_pth))
else:
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
net.eval()
return net
def image_preprocess(self):
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512)):
if isinstance(image, str):
image = Image.open(image)
width, height = image.size
with torch.no_grad():
image = image.resize(size, Image.BILINEAR)
img = self.preprocess(image)
if torch.cuda.is_available():
img = torch.unsqueeze(img, 0).cuda()
else:
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing > 13)] = 0
parsing[np.where(parsing >= 1)] = 255
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing
if __name__ == "__main__":
fp = FaceParsing()
segmap = fp('154_small.png')
segmap.save('res.png')