Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

VisualBERT VQA model inference gives lower validation accuracy (around 40%) with the Hugging Face framework #45

@guanhdrmq

Description

@guanhdrmq
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset that builds model inputs for one question/image pair.

    Each item tokenizes the question text, runs the image through the
    supplied Faster R-CNN to obtain region features, and builds a dense
    target vector from the annotation's ``labels`` / ``scores``.

    Relies on two module-level names defined outside this class:
    ``id_to_filename`` (maps ``annotation["image_id"]`` to an image path) and
    ``config`` (its ``id2label`` determines the length of the target vector).

    Args:
        questions: Sequence of dicts; ``questions[i]["question"]`` is the text.
        annotations: Sequence of dicts with ``"image_id"``, ``"labels"`` and
            ``"scores"`` keys, index-aligned with ``questions``.
        tokenizer: Callable used to tokenize the question text.
        image_preprocess: Callable mapping an image path to
            ``(images, sizes, scales_yx)``.
        frcnn: Callable feature extractor whose output exposes
            ``"roi_features"``.
        frcnn_cfg: Object providing ``max_detections``.
    """

    def __init__(self, questions, annotations, tokenizer, image_preprocess, frcnn, frcnn_cfg):
        self.questions = questions
        self.annotations = annotations
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.frcnn = frcnn
        self.frcnn_cfg = frcnn_cfg

    def __len__(self):
        """Return the number of annotations (one sample per annotation)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Build the input mapping for sample ``idx``.

        Returns:
            The tokenizer output, extended with ``visual_embeds``,
            ``visual_attention_mask``, ``output_attentions``, ``labels`` and
            ``text``. Tensor entries have their leading batch dimension
            removed.

            NOTE: ``text`` (str) and ``output_attentions`` (bool) are not
            tensors; consumers must filter them out before moving a batch to
            a device or unpacking it into a model call.
        """
        # answer
        annotation = self.annotations[idx]
        # question
        question = self.questions[idx]
        image_path = id_to_filename[annotation["image_id"]]
        image_path = image_path.replace("./multimodal_data/vqa2/val2014/.", "", 1)
        text = question['question']

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=25,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        # Feature extraction is inference-only: no autograd graph is needed,
        # and graph-free tensors are cheaper to hold and to hand to a loader.
        with torch.no_grad():
            images, sizes, scales_yx = self.image_preprocess(image_path)
            output_dict = self.frcnn(
                images,
                sizes,
                scales_yx=scales_yx,
                padding="max_detections",
                max_detections=self.frcnn_cfg.max_detections,
                return_tensors="pt",
            )

        # Only the region features are forwarded; the extractor's box output
        # is not part of the inputs built here.
        feature = output_dict.get("roi_features")

        inputs.update(
            {
                "visual_embeds": feature,
                "visual_attention_mask": torch.ones(feature.shape[:-1], dtype=torch.float),
                # "visual_token_type_ids": torch.ones(feature.shape[:-1], dtype=torch.long),
                "output_attentions": False,
            }
        )

        # Remove only the leading batch dimension. A bare squeeze() would also
        # collapse any other size-1 axis (e.g. a single token or a single
        # detection) and silently change the tensor rank.
        for key, value in list(inputs.items()):
            if isinstance(value, torch.Tensor):
                inputs[key] = value.squeeze(0)

        # add labels
        labels = annotation['labels']
        scores = annotation["scores"]

        # Dense target: score at each listed label index, zero elsewhere.
        targets = torch.zeros(len(config.id2label), dtype=torch.float)
        for label, score in zip(labels, scores):
            targets[label] = score
        inputs["labels"] = targets
        inputs["text"] = text

        print(text)
        return inputs

from visualbert.processing_image import Preprocess
from visualbert.visualizing_image import SingleImageViz
from visualbert.modeling_frcnn import GeneralizedRCNN
from visualbert.utils import Config

# Load the Faster R-CNN config and weights from the same checkpoint, then
# build the image preprocessor from that config.
FRCNN_CHECKPOINT = "unc-nlp/frcnn-vg-finetuned"

frcnn_cfg = Config.from_pretrained(FRCNN_CHECKPOINT)
frcnn = GeneralizedRCNN.from_pretrained(FRCNN_CHECKPOINT, config=frcnn_cfg)
image_preprocess = Preprocess(frcnn_cfg)

from transformers import VisualBertForQuestionAnswering, AutoTokenizer, BertTokenizerFast
# Tokenizer for the question text.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# NOTE(review): the label mappings below come from the external `config`
# object and override whatever mapping ships with the checkpoint. If `config`
# was built from a different answer vocabulary than the one this checkpoint's
# classifier head was trained with, predicted indices will decode to the wrong
# answers -- confirm the two vocabularies agree.
model = (
    VisualBertForQuestionAnswering.from_pretrained(
        "uclanlp/visualbert-vqa",
        num_labels=len(config.id2label),
        id2label=config.id2label,
        label2id=config.label2id,
        output_hidden_states=True,
    )
    .to(device)
    .eval()  # inference mode: disables dropout
)

# Evaluate on the first 100 question/annotation pairs only.
dataset = VQADataset(
    questions=questions[:100],
    annotations=annotations[:100],
    tokenizer=tokenizer,
    image_preprocess=image_preprocess,
    frcnn=frcnn,
    frcnn_cfg=frcnn_cfg,
)

# batch_size=1: the per-sample debug prints below rely on .item().
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
correct = 0.0
total = 0

# Dataset entries that must not be forwarded as tensor inputs:
# - "text" is collated into a list of strings (no .to(), not a model kwarg);
# - "output_attentions" is a flag and is passed explicitly below;
# - "labels" is kept aside for scoring (the loss is not needed here).
non_model_keys = {"labels", "text", "output_attentions"}

# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        labels = batch["labels"].to(device)
        model_inputs = {
            k: v.to(device)
            for k, v in batch.items()
            if k not in non_model_keys and isinstance(v, torch.Tensor)
        }
        outputs = model(**model_inputs, output_attentions=False)
        logits = outputs.logits  # [batch_size, 3129]
        _, pre = torch.max(logits, 1)
        _, target = torch.max(labels, 1)
        print("prediction:", pre)
        print("target:", target)
        print("Predicted answer:", model.config.id2label[pre.item()])
        print("Target answer:", model.config.id2label[target.item()])
        # Credit the prediction with the ground-truth score stored at its
        # index, so every accepted answer counts, not only the single
        # highest-scored one, and samples with an all-zero target score 0.
        # Assumes annotation scores are VQA-style soft accuracies in [0, 1]
        # -- TODO confirm.
        correct += labels.gather(1, pre.unsqueeze(1)).sum().item()
        total += labels.size(0)
        print(total)

# Guard against an empty dataset instead of dividing by zero.
final_acc = correct / total if total else 0.0
print('Accuracy of test: %f %%' % (100 * float(final_acc)))

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions