
RelTR code reading #43

Open · wants to merge 1 commit into main
Conversation

@long8v long8v commented Jul 21, 2022

RelTR code reading

https://github.com/yrcong/RelTR
#40

  • Got how the bipartite matching works.
    • The final output ends up shaped like (batch_size, num_of_triplets, -1).
    • Define the cost function (= classification error against the gt + bbox error) for each output head (bbox and classification come out separately) and sum them; that gives something like (batch_size, num_of_triplets, cost), and running it through the scipy package yields the indices that match every gt while minimizing the cost!
    • So how does inference work then??? -> re-read DETR for this
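The scipy step described above is `linear_sum_assignment` (the Hungarian algorithm). A minimal sketch with made-up costs; sizes and variable names are illustrative, not RelTR's actual ones:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical sizes: 5 predicted triplets, 3 ground-truth triplets.
num_triplets, num_gt = 5, 3
rng = np.random.default_rng(0)

# Per-pair costs: classification error vs. gt plus bbox error, summed.
cls_cost = rng.random((num_triplets, num_gt))
bbox_cost = rng.random((num_triplets, num_gt))
cost = cls_cost + bbox_cost  # [num_triplets, num_gt]

# Hungarian matching: each gt gets a distinct prediction, total cost minimized.
pred_idx, gt_idx = linear_sum_assignment(cost)
```

With more predictions than gt, every gt column is matched to exactly one prediction row, which is why every ground-truth object ends up assigned.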


@long8v long8v left a comment


Mainly looked at:
  • bipartite matching
  • model I/O
  • architecture

RelTR/models/matcher.py (resolved comments)
Comment on lines +35 to +43
Params:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits
"pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates
"sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits
"sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates
"obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits
"obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates
"rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits

The model outputs are as follows.
For entities: logits and bbox coords.
Aha! For triplets:
the subject comes out as [batch_size, num_triplets, num_entity_classes]
and [batch_size, num_triplets, 4], and the same comes out for obj.
So given the n-th triplet, it predicts the entity classes and bboxes / the relation.
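To make the shapes concrete, here is a dummy outputs dict matching the docstring (numpy stand-ins for torch tensors; all sizes are assumptions, not RelTR's actual hyperparameters):

```python
import numpy as np

# Illustrative sizes only (actual RelTR hyperparameters may differ).
batch_size, num_entities, num_triplets = 2, 100, 200
num_entity_classes, num_predicate_classes = 152, 52

outputs = {
    "pred_logits": np.zeros((batch_size, num_entities, num_entity_classes)),
    "pred_boxes":  np.zeros((batch_size, num_entities, 4)),
    "sub_logits":  np.zeros((batch_size, num_triplets, num_entity_classes)),
    "sub_boxes":   np.zeros((batch_size, num_triplets, 4)),
    "obj_logits":  np.zeros((batch_size, num_triplets, num_entity_classes)),
    "obj_boxes":   np.zeros((batch_size, num_triplets, 4)),
    "rel_logits":  np.zeros((batch_size, num_triplets, num_predicate_classes)),
}
```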


(image attachment)


Looking again at the params going in:

  • pred_boxes : boxes from the OD head, hence the 1st dim is num_entities
  • pred_logits : cls logits from the OD head
  • sub_boxes : boxes for the subject from the triplet decoder
  • sub_logits : ..

Comment on lines +45 to +52
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
"image_id": Image index
"orig_size": Tensor of dim [2] with the height and width
"size": Tensor of dim [2] with the height and width after transformation
"rel_annotations": Tensor of dim [num_gt_triplet, 3] with the subject index/object index/predicate class

The target is

  • labels : class labels of the target boxes
  • boxes : coords of the target boxes
  • rel_annotations : [num_gt_triplets, 3] (subject, object, predicate class) (indices?)

Comment on lines +54 to +61
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected entity predictions (in order)
- index_j is the indices of the corresponding selected entity targets (in order)
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected triplet predictions (in order)
- index_j is the indices of the corresponding selected triplet targets (in order)
Subject loss weight (Type: bool) to determine if back propagation should be conducted
Object loss weight (Type: bool) to determine if back propagation should be conducted

(selected entity prediction, selected entity target), (selected triplet prediction, selected triplet target)

Comment on lines +46 to +59
self.so_mask_conv = nn.Sequential(torch.nn.Upsample(size=(28, 28)),
                                  nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=3, bias=True),
                                  nn.ReLU(inplace=True),
                                  nn.BatchNorm2d(64),
                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                  nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=True),
                                  nn.ReLU(inplace=True),
                                  nn.BatchNorm2d(32))
self.so_mask_fc = nn.Sequential(nn.Linear(2048, 512),
                                nn.ReLU(inplace=True),
                                nn.Linear(512, 128))

# predicate classification
self.rel_class_embed = MLP(hidden_dim*2+128, hidden_dim, num_rel_classes + 1, 2)

When extracting the relation, the previously extracted obj_maps are passed through CNN layers.


if isinstance(samples, (list, torch.Tensor)):
    samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)

Pass through the ResNet backbone.

Comment on lines +94 to +95
hs, hs_t, so_masks, _ = self.transformer(self.input_proj(src), mask, self.entity_embed.weight,
                                         self.triplet_embed.weight, pos[-1], self.so_embed.weight)

Pass through the transformer encoder-decoder.

Comment on lines +96 to +98
so_masks = so_masks.detach()
so_masks = self.so_mask_conv(so_masks.view(-1, 2, src.shape[-2], src.shape[-1])).view(hs_t.shape[0], hs_t.shape[1], hs_t.shape[2], -1)
so_masks = self.so_mask_fc(so_masks)

Not sure what this mask is for.

Comment on lines +102 to +119
outputs_class = self.entity_class_embed(hs)
outputs_coord = self.entity_bbox_embed(hs).sigmoid()

outputs_class_sub = self.sub_class_embed(hs_sub)
outputs_coord_sub = self.sub_bbox_embed(hs_sub).sigmoid()

outputs_class_obj = self.obj_class_embed(hs_obj)
outputs_coord_obj = self.obj_bbox_embed(hs_obj).sigmoid()

outputs_class_rel = self.rel_class_embed(torch.cat((hs_sub, hs_obj, so_masks), dim=-1))

out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
       'sub_logits': outputs_class_sub[-1], 'sub_boxes': outputs_coord_sub[-1],
       'obj_logits': outputs_class_obj[-1], 'obj_boxes': outputs_coord_obj[-1],
       'rel_logits': outputs_class_rel[-1]}
if self.aux_loss:
    out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_class_sub, outputs_coord_sub,
                                            outputs_class_obj, outputs_coord_obj, outputs_class_rel)

prediction heads

Comment on lines +131 to +147
outputs = model(img)

# keep only predictions with 0.+ confidence
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3,
                                                                        probas_obj.max(-1).values > 0.3))

# convert boxes from [0; 1] to image scales
sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)

topk = 10
keep_queries = torch.nonzero(keep, as_tuple=True)[0]
indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk]
keep_queries = keep_queries[indices]

So inference just cuts at a probability threshold and sorts to pick the top k, apparently.
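The post-processing above can be sketched in numpy; thresholds and topk are taken from the snippet, but the logits are random fakes and the class counts are made up:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
num_triplets = 8
rel_logits = rng.normal(size=(1, num_triplets, 6))  # fake model outputs
sub_logits = rng.normal(size=(1, num_triplets, 5))
obj_logits = rng.normal(size=(1, num_triplets, 5))

# Drop the last ("no object"/background) class, as in the snippet above.
probas     = softmax(rel_logits)[0, :, :-1]
probas_sub = softmax(sub_logits)[0, :, :-1]
probas_obj = softmax(obj_logits)[0, :, :-1]

# Keep queries where relation, subject and object are all confident enough.
keep = (probas.max(-1) > 0.3) & (probas_sub.max(-1) > 0.3) & (probas_obj.max(-1) > 0.3)
keep_queries = np.nonzero(keep)[0]

# Rank survivors by the product of the three max-probabilities, take top-k.
score = (probas[keep_queries].max(-1)
         * probas_sub[keep_queries].max(-1)
         * probas_obj[keep_queries].max(-1))
topk = 10
keep_queries = keep_queries[np.argsort(-score)[:topk]]
```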


@long8v long8v left a comment


Re-read how the rel cost is defined.

Comment on lines +35 to +43

(image attachment)


Comment on lines +63 to +64
bs, num_queries = outputs["pred_logits"].shape[:2]
num_queries_rel = outputs["rel_logits"].shape[1]

num_queries is the num_queries fed into the DETR decoder;
num_queries_rel is presumably num_sub = num_obj = num_pred.

Comment on lines +94 to +99
# Concat the subject/object/predicate labels and subject/object boxes
sub_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 0]] for v in targets])
sub_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 0]] for v in targets])
obj_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 1]] for v in targets])
obj_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 1]] for v in targets])
rel_tgt_ids = torch.cat([v["rel_annotations"][:, 2] for v in targets])

Since RelTR's outputs give S/O/P separately, the gt is indexed with the idx from rel_annotations and turned into tensors.
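The indexing above can be reproduced on a toy target (all numbers hypothetical; numpy in place of torch):

```python
import numpy as np

# Toy ground truth for one image: 4 entity boxes, 2 relation triplets.
# Each rel_annotations row is (subject index, object index, predicate class).
target = {
    "boxes": np.array([[0.1, 0.1, 0.2, 0.2],
                       [0.3, 0.3, 0.2, 0.2],
                       [0.5, 0.5, 0.2, 0.2],
                       [0.7, 0.7, 0.2, 0.2]]),
    "labels": np.array([3, 7, 7, 1]),
    "rel_annotations": np.array([[0, 1, 5],
                                 [2, 3, 9]]),
}

rel = target["rel_annotations"]
sub_tgt_bbox = target["boxes"][rel[:, 0]]   # subject boxes per triplet
sub_tgt_ids  = target["labels"][rel[:, 0]]  # subject classes per triplet
obj_tgt_bbox = target["boxes"][rel[:, 1]]   # object boxes per triplet
obj_tgt_ids  = target["labels"][rel[:, 1]]  # object classes per triplet
rel_tgt_ids  = rel[:, 2]                    # predicate classes
```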

Comment on lines +101 to +105
sub_prob = outputs["sub_logits"].flatten(0, 1).sigmoid()
sub_bbox = outputs["sub_boxes"].flatten(0, 1)
obj_prob = outputs["obj_logits"].flatten(0, 1).sigmoid()
obj_bbox = outputs["obj_boxes"].flatten(0, 1)
rel_prob = outputs["rel_logits"].flatten(0, 1).sigmoid()

Flatten dims 0~1 of the outputs.
For the logits: [batch_size, num_triplets, num_entity_classes] -> [batch_size * num_triplets, num_entity_classes]
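torch's `.flatten(0, 1)` just merges the first two dims; a numpy equivalent (sizes assumed):

```python
import numpy as np

x = np.zeros((2, 200, 152))        # [batch_size, num_triplets, num_classes]
flat = x.reshape(-1, x.shape[-1])  # -> [batch_size * num_triplets, num_classes]
```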

Comment on lines +107 to +112
# Compute the subject matching cost based on class and box.
neg_cost_class_sub = (1 - alpha) * (sub_prob ** gamma) * (-(1 - sub_prob + 1e-8).log())
pos_cost_class_sub = alpha * ((1 - sub_prob) ** gamma) * (-(sub_prob + 1e-8).log())
cost_sub_class = pos_cost_class_sub[:, sub_tgt_ids] - neg_cost_class_sub[:, sub_tgt_ids]
cost_sub_bbox = torch.cdist(sub_bbox, sub_tgt_bbox, p=1)
cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(sub_bbox), box_cxcywh_to_xyxy(sub_tgt_bbox))

Not sure what neg_cost_class_sub is for.
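This looks like the sigmoid focal cost used in Deformable-DETR-style matchers: pos_cost is the focal penalty if the gt class should fire, neg_cost the penalty if it shouldn't, and their difference drives the cost down when the model is confident and right. A numpy sketch (alpha=0.25, gamma=2 are the usual focal defaults, assumed here rather than read off RelTR's config):

```python
import numpy as np

alpha, gamma, eps = 0.25, 2.0, 1e-8
rng = np.random.default_rng(0)
prob = rng.uniform(0.01, 0.99, size=(6, 5))  # [num_queries, num_classes], post-sigmoid
tgt_ids = np.array([2, 0, 4])                # gt class of each target

# neg_cost: penalty for firing on a class that should be absent;
# pos_cost: penalty for *not* firing on a class that should be present.
neg_cost = (1 - alpha) * prob**gamma * -np.log(1 - prob + eps)
pos_cost = alpha * (1 - prob)**gamma * -np.log(prob + eps)

# Per (query, target) pair: lower (more negative) the more confident the
# query is about that target's class.
cost_class = pos_cost[:, tgt_ids] - neg_cost[:, tgt_ids]  # [num_queries, num_targets]
```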

Comment on lines +121 to +124
# Compute the object matching cost only based on class.
neg_cost_class_rel = (1 - alpha) * (rel_prob ** gamma) * (-(1 - rel_prob + 1e-8).log())
pos_cost_class_rel = alpha * ((1 - rel_prob) ** gamma) * (-(rel_prob + 1e-8).log())
cost_rel_class = pos_cost_class_rel[:, rel_tgt_ids] - neg_cost_class_rel[:, rel_tgt_ids]

The rel cost skips the cdist and is computed like the OD class cost. -> so it's 100 x 100~


(image attachment)
