-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RelTR code reading #43
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- bipartite matching
- model I/O
- architecture
위주로 봄
Params: | ||
outputs: This is a dict that contains at least these entries: | ||
"pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits | ||
"pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates | ||
"sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits | ||
"sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates | ||
"obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits | ||
"obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates | ||
"rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
모델의 output들은 아래와 같음.
entity들에 대한건 logits과 bbox coord
아하! triplet에 대한 건
subject에 대해 [batch_size, num_triplets, num_entity_classes]으로 나오고
[batch_size, num_triplets, 4]으로 나오고 obj에 대해서도 똑같이 나오는군.
n번째 triplet이 주어졌을 때, entity class와 bbox prediction / relation 을 하는구나.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
다시 들어가는 param을 보면
- pred_boxes : OD에서 나오는 boxes. 그래서 1번째 차원이 num_entities
- pred_logits : OD에서 나오는 cls logits
- sub_boxes : triplet decoder에서 나오는 subject에 대한 boxes
- sub_logits : ..
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: | ||
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth | ||
objects in the target) containing the class labels | ||
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates | ||
"image_id": Image index | ||
"orig_size": Tensor of dim [2] with the height and width | ||
"size": Tensor of dim [2] with the height and width after transformation | ||
"rel_annotations": Tensor of dim [num_gt_triplet, 3] with the subject index/object index/predicate class |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target은
- labels : 타겟 box class labels
- boxes : 타겟 box들의 coord
- rel_annotations : num_gt_triplets, 3(subject, object, predicate class.) (index?)
A list of size batch_size, containing tuples of (index_i, index_j) where: | ||
- index_i is the indices of the selected entity predictions (in order) | ||
- index_j is the indices of the corresponding selected entity targets (in order) | ||
A list of size batch_size, containing tuples of (index_i, index_j) where: | ||
- index_i is the indices of the selected triplet predictions (in order) | ||
- index_j is the indices of the corresponding selected triplet targets (in order) | ||
Subject loss weight (Type: bool) to determine if back propagation should be conducted | ||
Object loss weight (Type: bool) to determine if back propagation should be conducted |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(selected entity prediction, selected entity target), (selected triplet prediction, selected triplet target)
self.so_mask_conv = nn.Sequential(torch.nn.Upsample(size=(28, 28)), | ||
nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=3, bias=True), | ||
nn.ReLU(inplace=True), | ||
nn.BatchNorm2d(64), | ||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1), | ||
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=True), | ||
nn.ReLU(inplace=True), | ||
nn.BatchNorm2d(32)) | ||
self.so_mask_fc = nn.Sequential(nn.Linear(2048, 512), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(512, 128)) | ||
|
||
# predicate classification | ||
self.rel_class_embed = MLP(hidden_dim*2+128, hidden_dim, num_rel_classes + 1, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
relation 뽑을 때는 아까 뽑은 obj_maps 사용해서 CNN레이어 통해서 함
|
||
if isinstance(samples, (list, torch.Tensor)): | ||
samples = nested_tensor_from_tensor_list(samples) | ||
features, pos = self.backbone(samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
resnet backbone 통과
hs, hs_t, so_masks, _ = self.transformer(self.input_proj(src), mask, self.entity_embed.weight, | ||
self.triplet_embed.weight, pos[-1], self.so_embed.weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
transformer encoder decoder 통과
so_masks = so_masks.detach() | ||
so_masks = self.so_mask_conv(so_masks.view(-1, 2, src.shape[-2],src.shape[-1])).view(hs_t.shape[0], hs_t.shape[1], hs_t.shape[2],-1) | ||
so_masks = self.so_mask_fc(so_masks) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
뭘 위한 mask인지 잘 모르겠음
outputs_class = self.entity_class_embed(hs) | ||
outputs_coord = self.entity_bbox_embed(hs).sigmoid() | ||
|
||
outputs_class_sub = self.sub_class_embed(hs_sub) | ||
outputs_coord_sub = self.sub_bbox_embed(hs_sub).sigmoid() | ||
|
||
outputs_class_obj = self.obj_class_embed(hs_obj) | ||
outputs_coord_obj = self.obj_bbox_embed(hs_obj).sigmoid() | ||
|
||
outputs_class_rel = self.rel_class_embed(torch.cat((hs_sub, hs_obj, so_masks), dim=-1)) | ||
|
||
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], | ||
'sub_logits': outputs_class_sub[-1], 'sub_boxes': outputs_coord_sub[-1], | ||
'obj_logits': outputs_class_obj[-1], 'obj_boxes': outputs_coord_obj[-1], | ||
'rel_logits': outputs_class_rel[-1]} | ||
if self.aux_loss: | ||
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_class_sub, outputs_coord_sub, | ||
outputs_class_obj, outputs_coord_obj, outputs_class_rel) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prediction heads
outputs = model(img) | ||
|
||
# keep only predictions with 0.+ confidence | ||
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1] | ||
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1] | ||
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1] | ||
keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3, | ||
probas_obj.max(-1).values > 0.3)) | ||
|
||
# convert boxes from [0; 1] to image scales | ||
sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size) | ||
obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size) | ||
|
||
topk = 10 | ||
keep_queries = torch.nonzero(keep, as_tuple=True)[0] | ||
indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk] | ||
keep_queries = keep_queries[indices] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inference는 그냥 prob 기준치로 자르고, sort해서 뽑나보다
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rel cost가 어떻게 정의되는지 다시 읽음
Params: | ||
outputs: This is a dict that contains at least these entries: | ||
"pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits | ||
"pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates | ||
"sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits | ||
"sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates | ||
"obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits | ||
"obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates | ||
"rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Params: | ||
outputs: This is a dict that contains at least these entries: | ||
"pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits | ||
"pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates | ||
"sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits | ||
"sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates | ||
"obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits | ||
"obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates | ||
"rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
다시 들어가는 param을 보면
- pred_boxes : OD에서 나오는 boxes. 그래서 1번째 차원이 num_entities
- pred_logits : OD에서 나오는 cls logits
- sub_boxes : triplet decoder에서 나오는 subject에 대한 boxes
- sub_logits : ..
bs, num_queries = outputs["pred_logits"].shape[:2] | ||
num_queries_rel = outputs["rel_logits"].shape[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
num_queries는 DETR decoder에 들어가는 num_queries
num_queries_rel은 아마 num_sub=num_obj=num_pred
# Concat the subject/object/predicate labels and subject/object boxes | ||
sub_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 0]] for v in targets]) | ||
sub_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 0]] for v in targets]) | ||
obj_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 1]] for v in targets]) | ||
obj_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 1]] for v in targets]) | ||
rel_tgt_ids = torch.cat([v["rel_annotations"][:, 2] for v in targets]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RelTR의 output은 SOP가 따로 나오기 때문에 rel_annotations에서 나온 idx로 인덱싱해서 tensor로 만들어줌
sub_prob = outputs["sub_logits"].flatten(0, 1).sigmoid() | ||
sub_bbox = outputs["sub_boxes"].flatten(0, 1) | ||
obj_prob = outputs["obj_logits"].flatten(0, 1).sigmoid() | ||
obj_bbox = outputs["obj_boxes"].flatten(0, 1) | ||
rel_prob = outputs["rel_logits"].flatten(0, 1).sigmoid() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output들 0~1차원 flatten 시켜줌.
logits 기준 [batch_size, num_entities, num_entity_classes] -> [batch_size * num_entities, num_entity_classes]
# Compute the subject matching cost based on class and box. | ||
neg_cost_class_sub = (1 - alpha) * (sub_prob ** gamma) * (-(1 - sub_prob + 1e-8).log()) | ||
pos_cost_class_sub = alpha * ((1 - sub_prob) ** gamma) * (-(sub_prob + 1e-8).log()) | ||
cost_sub_class = pos_cost_class_sub[:, sub_tgt_ids] - neg_cost_class_sub[:, sub_tgt_ids] | ||
cost_sub_bbox = torch.cdist(sub_bbox, sub_tgt_bbox, p=1) | ||
cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(sub_bbox), box_cxcywh_to_xyxy(sub_tgt_bbox)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
neg_cost_class_sub 뭐하는건지 모르겠넹
# Compute the object matching cost only based on class. | ||
neg_cost_class_rel = (1 - alpha) * (rel_prob ** gamma) * (-(1 - rel_prob + 1e-8).log()) | ||
pos_cost_class_rel = alpha * ((1 - rel_prob) ** gamma) * (-(rel_prob + 1e-8).log()) | ||
cost_rel_class = pos_cost_class_rel[:, rel_tgt_ids] - neg_cost_class_rel[:, rel_tgt_ids] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rel_cost 따로 cdist 안하고 OD class_cost하듯이 함. -> 100 x 100 임~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RelTR code reading
https://github.com/yrcong/RelTR
#40