Commit d40fd015 authored by Trái Vú Sữa

fix #1

parent f350b742
import cv2
import numpy as np
import random
# from utils.box_utils import matrix_iof
def matrix_iof(a, b):
"""
return iof of a and b, numpy version for data augenmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
return area_i / np.maximum(area_a[:, np.newaxis], 1)
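# A quick sanity check (values are illustrative, not from this commit):
# a box fully contained in the ROI has IoF 1.0.
# >>> matrix_iof(np.array([[10., 10., 20., 20.]]), np.array([[0., 0., 100., 100.]]))
# array([[1.]])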
def _crop(image, boxes, labels, landm, img_dim):
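# Try up to 250 random square crops (side = scale * short side, with scale
# drawn from [0.3, 1.0], or fixed at 1.0 with probability 0.3). A crop is
# accepted only if it fully contains at least one box (IoF >= 1) whose center
# lies inside it; if all attempts fail, the original image is returned with
# pad_image_flag=True so it gets padded to square later.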
height, width, _ = image.shape
pad_image_flag = True
for _ in range(250):
if random.uniform(0, 1) <= 0.3:
scale = 1.0
else:
scale = random.uniform(0.3, 1.0)
# PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
# scale = random.choice(PRE_SCALES)
short_side = min(width, height)
w = int(scale * short_side)
h = w
if width == w:
l = 0
else:
l = random.randrange(width - w)
if height == h:
t = 0
else:
t = random.randrange(height - h)
roi = np.array((l, t, l + w, t + h))
value = matrix_iof(boxes, roi[np.newaxis])
flag = (value >= 1)
if not flag.any():
continue
centers = (boxes[:, :2] + boxes[:, 2:]) / 2
mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
boxes_t = boxes[mask_a].copy()
labels_t = labels[mask_a].copy()
landms_t = landm[mask_a].copy()
landms_t = landms_t.reshape([-1, 5, 2])
if boxes_t.shape[0] == 0:
continue
image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
boxes_t[:, :2] -= roi[:2]
boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
boxes_t[:, 2:] -= roi[:2]
# landm
landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
landms_t = landms_t.reshape([-1, 10])
# keep only crops that still contain at least one face of positive size at the training image scale (img_dim)
b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
mask_b = np.minimum(b_w_t, b_h_t) > 0.0
boxes_t = boxes_t[mask_b]
labels_t = labels_t[mask_b]
landms_t = landms_t[mask_b]
if boxes_t.shape[0] == 0:
continue
pad_image_flag = False
return image_t, boxes_t, labels_t, landms_t, pad_image_flag
return image, boxes, labels, landm, pad_image_flag
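# _distort applies SSD-style random photometric augmentation: brightness and
# contrast in BGR space, saturation and hue in HSV space. The two branches
# below differ only in whether the contrast step runs before or after the HSV pass.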
def _distort(image):
def _convert(image, alpha=1, beta=0):
tmp = image.astype(float) * alpha + beta
tmp[tmp < 0] = 0
tmp[tmp > 255] = 255
image[:] = tmp
image = image.copy()
if random.randrange(2):
# brightness distortion
if random.randrange(2):
_convert(image, beta=random.uniform(-32, 32))
# contrast distortion
if random.randrange(2):
_convert(image, alpha=random.uniform(0.5, 1.5))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# saturation distortion
if random.randrange(2):
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
# hue distortion
if random.randrange(2):
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
tmp %= 180
image[:, :, 0] = tmp
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
else:
# brightness distortion
if random.randrange(2):
_convert(image, beta=random.uniform(-32, 32))
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# saturation distortion
if random.randrange(2):
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
# hue distortion
if random.randrange(2):
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
tmp %= 180
image[:, :, 0] = tmp
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
# contrast distortion
if random.randrange(2):
_convert(image, alpha=random.uniform(0.5, 1.5))
return image
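# _expand is the SSD "zoom out" augmentation: with probability 0.5 the image is
# pasted at a random offset into a canvas up to p times larger, filled with the
# given color, and boxes are shifted accordingly. Note that landmarks are not
# shifted here, which is presumably why preproc.__call__ below does not use it.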
def _expand(image, boxes, fill, p):
if random.randrange(2):
return image, boxes
height, width, depth = image.shape
scale = random.uniform(1, p)
w = int(scale * width)
h = int(scale * height)
left = random.randint(0, w - width)
top = random.randint(0, h - height)
boxes_t = boxes.copy()
boxes_t[:, :2] += (left, top)
boxes_t[:, 2:] += (left, top)
expand_image = np.empty(
(h, w, depth),
dtype=image.dtype)
expand_image[:, :] = fill
expand_image[top:top + height, left:left + width] = image
image = expand_image
return image, boxes_t
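# _mirror flips the image horizontally with probability 0.5, reflects box
# x-coordinates, and swaps the left/right landmark pairs (points 0/1 are the
# eyes, points 3/4 the mouth corners) so that semantic left/right is preserved.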
def _mirror(image, boxes, landms):
_, width, _ = image.shape
if random.randrange(2):
image = image[:, ::-1]
boxes = boxes.copy()
boxes[:, 0::2] = width - boxes[:, 2::-2]
# landm
landms = landms.copy()
landms = landms.reshape([-1, 5, 2])
landms[:, :, 0] = width - landms[:, :, 0]
tmp = landms[:, 1, :].copy()
landms[:, 1, :] = landms[:, 0, :]
landms[:, 0, :] = tmp
tmp1 = landms[:, 4, :].copy()
landms[:, 4, :] = landms[:, 3, :]
landms[:, 3, :] = tmp1
landms = landms.reshape([-1, 10])
return image, boxes, landms
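# _pad_to_square pads the image on the bottom and right with the dataset mean
# so it becomes square; it only runs when _crop gave up and returned the
# original (possibly non-square) image.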
def _pad_to_square(image, rgb_mean, pad_image_flag):
if not pad_image_flag:
return image
height, width, _ = image.shape
long_side = max(width, height)
image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
image_t[:, :] = rgb_mean
image_t[0:0 + height, 0:0 + width] = image
return image_t
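# _resize_subtract_mean resizes to insize x insize using a randomly chosen
# interpolation method, subtracts the per-channel mean, and transposes the
# result from HWC to CHW float32.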
def _resize_subtract_mean(image, insize, rgb_mean):
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
interp_method = random.choice(interp_methods)
image = cv2.resize(image, (insize, insize), interpolation=interp_method)
image = image.astype(np.float32)
image -= rgb_mean
return image.transpose(2, 0, 1)
class preproc(object):
def __init__(self, img_dim, rgb_means):
self.img_dim = img_dim
self.rgb_means = rgb_means
def __call__(self, image, targets, debug=False):
assert targets.shape[0] > 0, "this image does not have ground-truth boxes"
boxes = targets[:, :4].copy()
labels = targets[:, -1].copy()
landm = targets[:, 4:-1].copy()
image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
image_t = _distort(image_t)
image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
height, width, _ = image_t.shape
image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
scale = image_t.shape[1] / height
boxes_t *= scale
landm_t *= scale
if debug:
# Debug:
img_debug = image_t.copy()
img_debug = img_debug.transpose(1, 2, 0)
img_debug += self.rgb_means
img_debug = np.uint8(img_debug)
cv2.imwrite("test_temp.jpg", img_debug)
img_debug = cv2.imread('test_temp.jpg')
for index, b in enumerate(boxes_t):
b = [int(x) for x in b.tolist()]
b += [1]  # placeholder confidence so landmark coords start at b[5]
b += [int(x) for x in landm_t[index].tolist()]
cv2.rectangle(img_debug, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
cx = b[0]
cy = b[1] + 12
# landms
cv2.circle(img_debug, (b[5], b[6]), 1, (0, 0, 255), 4)
cv2.circle(img_debug, (b[7], b[8]), 1, (0, 255, 255), 4)
cv2.circle(img_debug, (b[9], b[10]), 1, (255, 0, 255), 4)
cv2.circle(img_debug, (b[11], b[12]), 1, (0, 255, 0), 4)
cv2.circle(img_debug, (b[13], b[14]), 1, (255, 0, 0), 4)
name = "test_augmentation_1.jpg"
cv2.imshow("test", img_debug)
cv2.waitKey()
cv2.imwrite(name, img_debug)
_, height, width = image_t.shape
boxes_t[:, 0::2] /= width
boxes_t[:, 1::2] /= height
landm_t[:, 0::2] /= width
landm_t[:, 1::2] /= height
labels_t = np.expand_dims(labels_t, 1)
targets_t = np.hstack((boxes_t, landm_t, labels_t))
return image_t, targets_t
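# A minimal usage sketch (the 640 and (104, 117, 123) values are common
# RetinaFace defaults and are assumptions, not taken from this commit):
#
#   preprocess = preproc(img_dim=640, rgb_means=(104, 117, 123))
#   # targets: (N, 15) float array of [x1, y1, x2, y2, 10 landmark coords, label]
#   image_chw, targets_norm = preprocess(image_bgr, targets)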
@@ -6,6 +6,7 @@ import cv2
import os
import json
class _DataLoader:
def __init__(self, root, transform=None, target_transform=None):
@@ -18,21 +19,21 @@ class _DataLoader:
self.ids = []
self.class_names = ('BACKGROUND', 'person')
self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}
self._annopath = os.path.join('%s', 'json_annotations', '%s.json')
for file in os.listdir(self.anno_path):
with open(os.path.join(self.anno_path, file), 'r') as f:
data = json.load(f)
objects = data["objects"]
for sub_object in objects:
if sub_object["label"]=="person":
if sub_object["label"] == "person":
self.ids.append(file.split(".json")[0])
break
def __getitem__(self, index):
image_id = self.ids[index]
boxes, labels = self._get_annotation(image_id)
image = self._read_image(image_id)
if self.transform:
image, boxes, labels = self.transform(image, boxes, labels)
@@ -45,7 +46,7 @@ class _DataLoader:
return len(self.ids)
def _get_annotation(self, image_id):
annotation_file = os.path.join(self.anno_path, image_id + ".json")
# print(annotation_file)
with open(annotation_file, 'r') as f:
data = json.load(f)
@@ -53,7 +54,7 @@ class _DataLoader:
boxes = []
labels = []
for sub_object in objects:
class_name = sub_object["label"]
if class_name in self.class_dict:
bbox = sub_object["bbox"]
x1 = float(bbox["x_topleft"])
@@ -67,12 +68,12 @@ class _DataLoader:
np.array(labels, dtype=np.int64))
def _read_image(self, image_id):
if os.path.isfile(os.path.join(self.img_path, image_id + ".jpg")):
image_file = os.path.join(self.img_path, image_id + ".jpg")
elif os.path.isfile(os.path.join(self.img_path, image_id + ".jpeg")):
image_file = os.path.join(self.img_path, image_id + ".jpeg")
else:
image_file = os.path.join(self.img_path, image_id + ".png")
image = cv2.imread(str(image_file))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np
import torch.nn.functional as F
import skimage.transform
import torchvision.transforms as transforms
class FaceDataset(data.Dataset):
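# Each target row built below holds 15 numbers: [x1, y1, x2, y2], five (x, y)
# landmark pairs, and a final flag that is 1 when landmarks are annotated and
# -1 when they are missing (signalled by a negative first landmark x).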
def __init__(self, root_path, file_name, preproc, target_transform=None):
super(FaceDataset, self).__init__()
self.path_images, self.labels = self.read_file(root_path, file_name)
self.preproc = preproc
self.target_transform = target_transform
def __len__(self):
return len(self.path_images)
def __getitem__(self, idx):
img = cv2.imread(self.path_images[idx])
labels = self.labels[idx]
annotations = np.zeros((0, 15))
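# NOTE: when an image has no labels, only the empty annotation array is
# returned, not the usual (img, target) pair; preproc below asserts that
# every image has at least one ground-truth box.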
if len(labels) == 0:
return annotations
for label in labels:
annotation = np.zeros((1, 15))
# bbox
annotation[0, 0] = label[0] # x1
annotation[0, 1] = label[1] # y1
annotation[0, 2] = label[0] + label[2] # x2
annotation[0, 3] = label[1] + label[3] # y2
# landmarks
annotation[0, 4] = label[4] # l0_x
annotation[0, 5] = label[5] # l0_y
annotation[0, 6] = label[7] # l1_x
annotation[0, 7] = label[8] # l1_y
annotation[0, 8] = label[10] # l2_x
annotation[0, 9] = label[11] # l2_y
annotation[0, 10] = label[13] # l3_x
annotation[0, 11] = label[14] # l3_y
annotation[0, 12] = label[16] # l4_x
annotation[0, 13] = label[17] # l4_y
if (annotation[0, 4] < 0):
annotation[0, 14] = -1
else:
annotation[0, 14] = 1
annotations = np.append(annotations, annotation, axis=0)
target = np.array(annotations)
debug = False
if debug:
img_debug = img.copy()
for index, b in enumerate(annotations):
b = [int(x) for x in b.tolist()]
cv2.rectangle(img_debug, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
# landms
cv2.circle(img_debug, (b[4], b[5]), 1, (0, 0, 255), 4)
cv2.circle(img_debug, (b[6], b[7]), 1, (0, 255, 255), 4)
cv2.circle(img_debug, (b[8], b[9]), 1, (255, 0, 255), 4)
cv2.circle(img_debug, (b[10], b[11]), 1, (0, 255, 0), 4)
cv2.circle(img_debug, (b[12], b[13]), 1, (255, 0, 0), 4)
name = "test_data.jpg"
cv2.imwrite(name, img_debug)
if self.preproc is not None:
img, target = self.preproc(img, target)
truths = target[:, :4]
labels = target[:, -1]
landms = target[:, 4:14]
# TODO write landms to target_transforms
if self.target_transform:
truths, labels = self.target_transform(truths, labels)
return torch.from_numpy(img), target
@staticmethod
def read_file(root_path, file_name):
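# Parses a WIDER-FACE-style annotation file: a line starting with '#' (or '/')
# names an image; each following line describes one face as space-separated
# floats, x y w h followed by five landmark triplets (x, y, score), which is
# why __getitem__ above reads landmark coordinates at strides of three.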
path_images = []
words = []
file_name = os.path.join('/'.join(root_path.split('/')[:-1]), file_name)
with open(file_name, 'r') as f:
lines = f.readlines()
labels = []
flag = False
for line in lines:
line = line.rstrip()
if line.startswith('#') or line.startswith('/'):
if not flag:
flag = True
else:
words.append(labels)
labels = []
image_name = line[2:]
path_images.append(os.path.join(root_path, image_name))
else:
label = [float(x) for x in line.split(' ')]
labels.append(label)
words.append(labels)
return path_images, words
import sys
sys.path.append('/media/ducanh/DATA/tienln/ai_camera/ai_camera_detector/')
from utils.misc import str2bool, Timer, freeze_net_layers, store_labels
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
@@ -23,27 +24,29 @@ from model.config import mb_ssd_lite_f19_config
from model.rfb_tiny_mb_ssd import create_rfb_tiny_mb_ssd
from model.config import rfb_tiny_mb_ssd_config
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
class Train():
'''
Builds the network, criterion, optimizer, scheduler, and data loaders,
then runs the training loop.
'''
def __init__(self):
self.args = _argument()
self.device = torch.device("cuda:0" if torch.cuda.is_available() and self.args.use_cuda else "cpu")
self.net, self.criterion, self.optimizer, self.scheduler, self.train_loader, self.val_loader = self.get_model()
self.dir_path = os.path.join(self.args.checkpoint_folder, self.args.net)
if not os.path.exists(self.dir_path):
os.makedirs(self.dir_path)
def get_model(self):
timer = Timer()
logging.info(self.args)
if self.args.net == 'mb2-ssd-lite_f19':
create_net = create_mb_ssd_lite_f19
config = mb_ssd_lite_f19_config
elif self.args.net == 'mb2-ssd-lite_f38':
create_net = create_mb_ssd_lite_f38
config = mb_ssd_lite_f38_config
@@ -58,29 +61,34 @@ class Train():
parser.print_help(sys.stderr)
sys.exit(1)
train_loader, val_loader, num_classes = data_loader(config)
net, criterion, optimizer, scheduler = create_network(create_net, num_classes, self.device)
return net, criterion, optimizer, scheduler, train_loader, val_loader
def training(self):
print(self.dir_path)
for epoch in range(0, self.args.num_epochs):
self.scheduler.step()
training_loss = train(self.train_loader, self.net, self.criterion, self.optimizer, device=self.device,
debug_steps=self.args.debug_steps, epoch=epoch)
if epoch % self.args.validation_epochs == 0 or epoch == self.args.num_epochs - 1:
if self.args.valid:
val_running_loss, val_running_regression_loss, val_running_classification_loss = test(
self.val_loader, self.net, self.criterion, device=self.device)
logging.info(
f"Epoch: {epoch}, " +
f"val_avg_loss: {val_running_loss:.4f}, " +
f"val_reg_loss {val_running_regression_loss:.4f}, " +
f"val_cls_loss: {val_running_classification_loss:.4f}")
model_path = os.path.join(self.dir_path, f"{self.args.net}-epoch-{epoch}-train_loss-{round(training_loss,2)}-val_loss-{round(val_running_loss,2)}.pth")
else :
model_path = os.path.join(self.dir_path, f"{self.args.net}-epoch-{epoch}-train_loss-{round(training_loss,2)}.pth")
model_path = os.path.join(self.dir_path,
f"{self.args.net}-epoch-{epoch}-train_loss-{round(training_loss, 2)}-val_loss-{round(val_running_loss, 2)}.pth")
else:
model_path = os.path.join(self.dir_path,
f"{self.args.net}-epoch-{epoch}-train_loss-{round(training_loss, 2)}.pth")
self.net.save(model_path)
logging.info(f"Saved model {self.dir_path}")
if __name__ == '__main__':
train = Train().training()
from utils.argument import _argument
import logging
import sys
@@ -14,12 +13,15 @@ from torchsummary import summary
import torch
from torchscope import scope
import sys
sys.path.append('/media/ducanh/DATA/tienln/ai_camera/detector/')
from utils.misc import str2bool, Timer, freeze_net_layers, store_labels
timer = Timer()
args = _argument()
def train(loader, net, criterion, optimizer, device, debug_steps=100, epoch=-1):
net.train(True)
running_loss = 0.0
@@ -60,6 +62,7 @@ def train(loader, net, criterion, optimizer, device, debug_steps=100, epoch=-1):
return training_loss
def test(loader, net, criterion, device):
net.eval()
running_loss = 0.0
@@ -82,9 +85,10 @@ def test(loader, net, criterion, device):
running_classification_loss += classification_loss.item()
return running_loss / num, running_regression_loss / num, running_classification_loss / num
def data_loader(config):
train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
target_transform = MatchPrior(config.priors, config.center_variance, config.size_variance, config.iou_threshold)
test_transform = TestTransform(config.image_size, config.image_mean, config.image_std)
logging.info("Prepare training datasets.")
@@ -95,38 +99,39 @@ def data_loader(config):
path_dataset = open("/media/ducanh/DATA/tienln/ai_camera/ai_camera_detector/datasets/train_dataset.txt", "r")
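# Each line of train_dataset.txt is assumed to look like "<path_a>+<path_b>\n":
# split on '+' and strip the trailing newline from the second field.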
for line in path_dataset:
data = line.split('+')
Data_Train.append([data[0], data[1][:-1]])
# training datasets
# dataset_paths = [Data_Train[0],Data_Train[1],Data_Train[2],Data_Train[3],Data_Train[4],Data_Train[5]]
dataset_paths = [Data_Train[3]]
for dataset_path in dataset_paths:
print(dataset_path)
dataset = _DataLoader(dataset_path, transform=train_transform, target_transform=target_transform)
print(len(dataset.ids))
datasets.append(dataset)
num_classes = len(dataset.class_names)
train_dataset = ConcatDataset(datasets)
logging.info("Train dataset size: {}".format(len(train_dataset)))
train_loader = DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers, shuffle=True)
if args.valid:
# Validation datasets
path_dataset = open("/media/ducanh/DATA/tienln/ai_camera/ai_camera_detector/datasets/valid_dataset.txt", "r")
for line in path_dataset:
data = line.split('+')
Data_Valid.append([data[0], data[1][:-1]])
# print(Data_Valid)
logging.info("Prepare Validation datasets.")
valid_dataset_paths = [Data_Valid[0]]
for dataset_path in valid_dataset_paths:
val_dataset = _DataLoader(dataset_path, transform=test_transform, target_transform=target_transform)
val_loader = DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers, shuffle=True)
return train_loader, val_loader, num_classes
else:
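# NOTE: without --valid this returns only (train_loader, num_classes), while
# get_model above always unpacks three values, so it effectively assumes
# args.valid is set.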
return train_loader, num_classes
def create_network(create_net, num_classes, DEVICE):
logging.info("Build network.")
net = create_net(num_classes)
# print(net)
@@ -195,7 +200,7 @@ def create_network(create_net, num_classes, DEVICE):
logging.info("Uses MultiStepLR scheduler.")
milestones = [int(v.strip()) for v in args.milestones.split(",")]
scheduler = MultiStepLR(optimizer, milestones=milestones,
gamma=0.1, last_epoch=last_epoch)
elif args.scheduler == 'cosine':
logging.info("Uses CosineAnnealingLR scheduler.")
scheduler = CosineAnnealingLR(optimizer, args.t_max, last_epoch=last_epoch)
@@ -203,5 +208,5 @@ def create_network(create_net, num_classes, DEVICE):
logging.fatal(f"Unsupported Scheduler: {args.scheduler}.")
parser.print_help(sys.stderr)
sys.exit(1)
return net, criterion, optimizer, scheduler
@@ -4,6 +4,7 @@ import torch
import numpy as np
from utils import box_processing as box_utils
class MultiboxLoss(nn.Module):
def __init__(self, priors, iou_threshold, neg_pos_ratio,
center_variance, size_variance, device):
@@ -42,9 +43,11 @@ class MultiboxLoss(nn.Module):
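# Localization loss is computed only on the positive anchors selected by
# pos_mask; both losses are then normalized by the number of positives,
# following the SSD training recipe.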
gt_locations = gt_locations[pos_mask, :].reshape(-1, 4)
smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, reduction='sum')
num_pos = gt_locations.size(0)
return smooth_l1_loss / num_pos, classification_loss / num_pos
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.25):
"""
gamma is the focusing parameter that adjusts the rate at which easy
examples are down-weighted; alpha balances positive against negative examples.
@@ -87,31 +90,32 @@ class FocalLoss(nn.Module):
conf_loss = -self.alpha * ((1 - p_t)**self.gamma * p_t_log)
conf_loss = conf_loss.sum()
"""
# focal loss implementation(2)
pos_cls = conf_targets > -1
mask = pos_cls.unsqueeze(2).expand_as(conf_preds)
conf_p = conf_preds[mask].view(-1, conf_preds.size(2)).clone()
p_t_log = -F.cross_entropy(conf_p, conf_targets[pos_cls], reduction='sum')
p_t = torch.exp(p_t_log)
# This is focal loss presented in the paper eq(5)
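# i.e. FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t); here p_t_log is the
# summed log-likelihood over all non-ignored anchors, so p_t is the product of
# their probabilities rather than a per-anchor value.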
conf_loss = -self.alpha * ((1 - p_t) ** self.gamma * p_t_log)
############# Localization Loss part ##############
pos = conf_targets > 0 # ignore background
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_preds)
loc_p = loc_preds[pos_idx].view(-1, 4)
loc_t = loc_targets[pos_idx].view(-1, 4)
loc_loss = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
num_pos = pos.long().sum(1, keepdim=True)
# Avoid division by zero: random cropping during augmentation can distort the
# boxes and leave an image with no positive anchors.
N = max(num_pos.data.sum(), 1)
conf_loss /= N  # normalized by the number of positives, excluding background
loc_loss /= N
return loc_loss, conf_loss
def one_hot(self, x, n):
y = torch.eye(n)
return y[x]