diff --git a/datasets/README.md b/datasets/README.md
new file mode 100644
index 0000000..eeef942
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1,176 @@
+# S2-VLUE
+
+- [S2-VLUE](#s2-vlue)
+  - [Overview](#overview)
+  - [Download & Usage](#download--usage)
+    - [Download the exported JSON (for training language models)](#download-the-exported-json-for-training-language-models)
+    - [Download the source PDFs or screenshots](#download-the-source-pdfs-or-screenshots)
+  - [Datasets Details](#datasets-details)
+    - [The S2-VL dataset](#the-s2-vl-dataset)
+      - [Recreating the dataset from PDFs and annotations](#recreating-the-dataset-from-pdfs-and-annotations)
+      - [Dataset Curation Details](#dataset-curation-details)
+    - [The VILA-enhanced DocBank Dataset](#the-vila-enhanced-docbank-dataset)
+  - [Dataset Details](#dataset-details)
+    - [Statistics of the Datasets](#statistics-of-the-datasets)
+    - [File Structures](#file-structures)
+  - [Reference](#reference)
+  - [Citation](#citation)
+
+## Overview
+
+S2-VLUE, the Semantic Scholar **V**isual **L**ayout-enhanced Scientific Text **U**nderstanding **E**valuation benchmark suite, is created to evaluate scientific document understanding and parsing with visual layout information.
+
+It consists of three datasets: GROTOAP2, DocBank, and S2-VL. We modify the existing GROTOAP2 [1] and DocBank [2] datasets, adding visual layout information and converting them to a format that is compatible with [HuggingFace Datasets](https://huggingface.co/docs/datasets/).
+S2-VL is a newly curated dataset that addresses three major drawbacks of existing work: 1) annotation quality, 2) VILA structure creation, and 3) domain coverage.
+It contains human annotations for papers from 19 scientific disciplines.
+We provide scripts for downloading the source PDF files as well as for converting them to the same HuggingFace Datasets-compatible format.
+
+## Download & Usage
+
+### Download the exported JSON (for training language models)
+
+```bash
+cd /datasets
+bash ./download.sh # grotoap2, docbank, s2-vl, or all
+```
+
+### Download the source PDFs or screenshots
+
+- GROTOAP2 (downloading paper PDFs)
+  - Please follow the instructions from the [GROTOAP2 Project README](http://cermine.ceon.pl/grotoap2/README).
+- DocBank (downloading paper page screenshots)
+  - Please follow the instructions from the [home page of the DocBank Project](https://doc-analysis.github.io/docbank-page/index.html).
+- S2-VL (downloading paper PDFs)
+  - Please check the instructions in [s2-vl-utils/README.md](s2-vl-utils/README.md).
+
+## Datasets Details
+
+### The S2-VL dataset
+
+During the data release process, we unfortunately found that a small portion of the PDFs in our dataset (22 out of 87) carry additional copyright constraints of which we had been unaware, so we cannot directly release the data corresponding to these papers. As such, the downloaded version contains only the data created from the remaining 65 papers.
+
+If you are interested in the version of the dataset used for training and evaluation in our paper, please fill out this [Google Form](https://forms.gle/M1g9tQLrUtKSsDYA7) to request access.
+
+#### Recreating the dataset from PDFs and annotations
+
+We also provide the full pipeline for recreating the dataset, i.e., converting the PDFs and annotation files into the JSON files used for training models. Please check the instructions in [s2-vl-utils/README.md](s2-vl-utils/README.md).
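+
+Once downloaded (or recreated), the exported `*-token.json` files can be inspected with the standard `json` module. The snippet below is a minimal sketch: the path assumes the S2-VL export layout described in [File Structures](#file-structures), and the field names follow the export code in [s2-vl-utils/condense_dataset.py](s2-vl-utils/condense_dataset.py); see [schema-token.json](schema-token.json) for the authoritative schema.
+
+```python
+import json
+
+# Example path; adjust to wherever download.sh (or the recreation pipeline) put the files.
+with open("../data/s2-vl-ver1/0/train-token.json") as fp:
+    dataset = json.load(fp)
+
+# "labels" maps category names to integer ids; invert it for readable printing.
+id2label = {idx: name for name, idx in dataset["labels"].items()}
+
+# Each entry of "data" is one page; "files" records the corresponding sha-pageid.
+page, source = dataset["data"][0], dataset["files"][0]
+print(f"{source}: {len(page['words'])} tokens")
+for word, bbox, label in zip(page["words"][:5], page["bbox"][:5], page["labels"][:5]):
+    print(word, bbox, id2label[label])
+```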
+ +#### Dataset Curation Details + +Please find a detailed description of the labeling schemas and categories in the following documents: +- [Labeling Instruction](https://docs.google.com/document/d/1DsIDKNEi8GBxrqQuYRy86lCKhksgvyRaGhXPCheGgG0/edit?usp=sharing) +- [S2-VL Category Definition](https://docs.google.com/document/d/1frGmzYOHnVRWAwTOuuPfc3KVAwu-XKdkFSbpLfy78RI/edit?usp=sharing) + - We labeled both layout and semantic categories in S2-VL (see the document above), but only the 16 Layout categories will be used in this evaluation benchmark. +- [The 19 Scientific Disciplines](https://docs.google.com/document/d/1ytJkYhswp4Wlx8tT1iRe-jdjx5A-nqisvUikgmqSQKc/edit?usp=sharing) + +### The VILA-enhanced DocBank Dataset + +## Dataset Details + +### Statistics of the Datasets + +| | GROTOAP2 | DocBank | S2-VL-ver1 | +| ----------------- | ------------ | --------------- | ------------------------------ | +| Train Test Split | 83k/18k/18k | 398k/50k/50k | * | +| Annotation Method | Automatic | Automatic | Human Annotation | +| Paper Domain | Life Science | Math/Physics/CS | 19 Disciplines | +| VILA Structure | PDF parsing | Vision model | Gold Label / Detection methods | +| # of Categories | 22 | 12 | 15 | + +| | GROTOAP2 | DocBank | S2-VL-ver1* | +| ------------------------- | -------- | ------- | --------- | +| **Tokens per Page** | +| Average | 1203 | 838 | 790 | +| Std | 591 | 503 | 453 | +| 95th Percentile | 2307 | 1553 | 1591 | +| **Text Lines per Page** | +| Average | 90 | 60 | 64 | +| Std | 51 | 34 | 54 | +| 95th Percentile | 171 | 125 | 154 | +| **Text Blocks per Page** | +| Average | 12 | 15 | 22 | +| Std | 16 | 8 | 36 | +| 95th Percentile | 37 | 30 | 68 | +| **Tokens per Text Line** | +| Average | 17 | 16 | 14 | +| Std | 12 | 43 | 10 | +| 95th Percentile | 38 | 38 | 30 | +| **Tokens per Text Block** | +| Average | 90 | 57 | 48 | +| Std | 184 | 138 | 121 | +| 95th Percentile | 431 | 210 | 249 | + +* This is calculated based on the S2-VL-ver1 with all 87 papers. + +### File Structures + +1. The organization of the dataset files : + ```bash + grotoap2 # Docbank is similar + ├─ labels.json + ├─ train-token.json + ├─ dev-token.json + ├─ test-token.json + └─ train-test-split.json + ``` +2. What's in each file? + 1. `labels.json` + ```json + {"0": "Title", + "1": "Author", + ... + } + ``` + 2. `train-test-split.json` + ```json + { + "train": [ + "pdf-file-name", ... + ], + "test": ["pdf-file-name", ...] + } + ``` + 3. `train-token.json`, `dev-token.json` or `test-token.json` + Please see detailed schema explanation in the [schema-token.json](schema-token.json) file. +3. Special notes on the folder structure for S2-VL: since the dataset size is small, we use 5-fold cross validation in the paper. The released version has a similar structure: + ```bash + s2-vl-ver1 + ├─ 0 # 5-fold Cross validation + │ ├─ labels.json + │ ├─ test-token.json + │ ├─ train-test-split.json + │ └─ train-token.json + ├─ 1 # fold-1, have the same files as other folds + │ ├─ labels.json + │ ├─ test-token.json + │ ├─ train-test-split.json + │ └─ train-token.json + ├─ 2 + ├─ 3 + └─ 4 + ``` + +## Reference + +1. The GROTOAP2 Dataset: + - Paper: https://www.dlib.org/dlib/november14/tkaczyk/11tkaczyk.html + - Original download link: http://cermine.ceon.pl/grotoap2/ + - Licence: Open Access license + +2. 
The Original DocBank Dataset:
+    - Paper: https://arxiv.org/pdf/2006.01038.pdf
+    - Original download link: https://github.com/doc-analysis/DocBank
+    - Licence: Apache-2.0
+
+## Citation
+
+```
+@article{Shen2021IncorporatingVL,
+  title={Incorporating Visual Layout Structures for Scientific Text Classification},
+  author={Zejiang Shen and Kyle Lo and Lucy Lu Wang and Bailey Kuehl and Daniel S. Weld and Doug Downey},
+  journal={ArXiv},
+  year={2021},
+  volume={abs/2106.00676},
+  url={https://arxiv.org/abs/2106.00676}
+}
+```
diff --git a/datasets/download.sh b/datasets/download.sh
new file mode 100644
index 0000000..8a041a7
--- /dev/null
+++ b/datasets/download.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+dataset_name="$1"
+base_save_path="../data"
+mkdir -p "$base_save_path"
+
+S3_BASE_LINK="https://ai2-s2-research.s3.us-west-2.amazonaws.com/s2-vlue"
+GROTOAP2_S3_NAME="grotoap2.zip"
+DOCBANK_S3_NAME="docbank.zip"
+S2_VL_VER1_S3_NAME="s2-vl-ver1-public.zip"
+
+download_compiled_dataset () {
+    target_path="$1"
+    s3_name="$2"
+    wget "$S3_BASE_LINK/$s3_name" -O "$base_save_path/$s3_name"
+    unzip "$base_save_path/$s3_name" -d "$base_save_path/$target_path"
+    rm "$base_save_path/$s3_name"
+}
+
+case "$dataset_name" in
+
+  grotoap2)
+    download_compiled_dataset "grotoap2" "$GROTOAP2_S3_NAME"
+    ;;
+
+  docbank)
+    download_compiled_dataset "docbank" "$DOCBANK_S3_NAME"
+    ;;
+
+  s2-vl)
+    download_compiled_dataset "s2-vl-ver1" "$S2_VL_VER1_S3_NAME"
+    ;;
+
+  all)
+    download_compiled_dataset "grotoap2" "$GROTOAP2_S3_NAME"
+    download_compiled_dataset "docbank" "$DOCBANK_S3_NAME"
+    download_compiled_dataset "s2-vl-ver1" "$S2_VL_VER1_S3_NAME"
+    ;;
+
+  *)
+    echo "Unknown dataset: expected grotoap2, docbank, s2-vl, or all" >&2
+    exit 1
+    ;;
+esac
\ No newline at end of file
diff --git a/datasets/s2-vl-utils/README.md b/datasets/s2-vl-utils/README.md
new file mode 100644
index 0000000..c2e4592
--- /dev/null
+++ b/datasets/s2-vl-utils/README.md
@@ -0,0 +1,69 @@
+# Recreating the S2-VL Dataset
+
+- [Recreating the S2-VL Dataset](#recreating-the-s2-vl-dataset)
+  - [STEP0: Install extra dependencies for creating the dataset](#step0-install-extra-dependencies-for-creating-the-dataset)
+  - [STEP1: Download the papers & annotations](#step1-download-the-papers--annotations)
+  - [STEP2: Parse token data using CERMINE](#step2-parse-token-data-using-cermine)
+  - [STEP3: Run visual layout detectors for getting the text block and line blocks](#step3-run-visual-layout-detectors-for-getting-the-text-block-and-line-blocks)
+  - [STEP4: Assemble the annotations and export the dataset](#step4-assemble-the-annotations-and-export-the-dataset)
+
+## STEP0: Install extra dependencies for creating the dataset
+
+```bash
+cd /datasets/s2-vl-utils
+# activate the corresponding environment
+pip install -r requirements.txt
+```
+
+## STEP1: Download the papers & annotations
+
+```bash
+python download.py --base-path sources/s2-vl-ver1
+```
+This will download the PDF files to `sources/s2-vl-ver1/pdfs` and the annotation files to `sources/s2-vl-ver1/annotations`.
+The script checks and reports PDFs whose SHA1 does not match the expected value or that cannot be downloaded.
+Note: an SHA1 mismatch for a PDF does not necessarily mean its contents are different.
+
+## STEP2: Parse token data using CERMINE
+
+1. Download Java and CERMINE following the instructions in [this repo](https://github.com/CeON/CERMINE#using-cermine) (PS: the easiest approach is to download CERMINE v1.13 directly from [JFrog](http://maven.ceon.pl/artifactory/webapp/#/artifacts/browse/simple/General/kdd-releases/pl/edu/icm/cermine/cermine-impl)).
+
+
+2. 
Run CERMINE on the set of papers and parse the token data, and convert the source CERMINE data to the csv format: + ```bash + python cermine_loader.py \ + --base-path sources/s2-vl-ver1 \ + --cermine-path /path/to/cermine-impl-1.13-jar-with-dependencies.jar + ``` + It will create the token table for each `sha-pid.csv` in the `sources/tokens` folder. + +## STEP3: Run visual layout detectors for getting the text block and line blocks + +```bash +python vision_model_loader.py --base-path sources +``` +It will: +1. run visual layout detection for both text blocks and lines, and save them in the `-.csv` files in the `sources/blocks` and `sources/lines` folder. +2. combine the text block, line, and token information, create a refined version of visual layout detection, and save them in the `-.csv` files in the `sources/condensed` folder. + +## STEP4: Assemble the annotations and export the dataset + +```bash +python condense_dataset.py \ + --annotation-folder 'sources/s2-vl-ver1/annotations' \ + --annotation-table 'sources/s2-vl-ver1/annotation_table.csv' \ + --cermine-pdf-dir 'sources/s2-vl-ver1/pdfs' \ + --cermine-csv-dir 'sources/s2-vl-ver1/tokens' \ + --vision-csv-dir 'sources/s2-vl-ver1/condensed' \ + --export-folder 'export/s2-vl-ver1' \ + --config './config.json' +``` + +It will convert all the source data in the source folder to a format that can be directly used for training the language models. By default, it will split the dataset into 5-folds for cross validation. The save folder will be specified in `--export-folder` configuration. There are several configurable options during the creation of the training dataset, perhaps the most important one is to specify what notion of blocks and lines to be used when constructing the dataset. Here are some available options: + +| Source of blocks | Sources of lines | Option | +| ---------------- | ---------------- | -------------------- | +| CERMINE | CERMINE | - (default behavior) | +| Vision Model | CERMINE | `--use-vision-box` | +| Vision Model | Vision Model | `--use-vision-line` | +| Ground-Truth | Vision Model | `--use-gt-box` | \ No newline at end of file diff --git a/datasets/s2-vl-utils/cermine_loader.py b/datasets/s2-vl-utils/cermine_loader.py new file mode 100644 index 0000000..1f65ab4 --- /dev/null +++ b/datasets/s2-vl-utils/cermine_loader.py @@ -0,0 +1,452 @@ +from typing import List, Union, Dict, Any, Tuple +from dataclasses import dataclass +from glob import glob +import os +import subprocess + + +from tqdm import tqdm +from bs4 import BeautifulSoup +import layoutparser as lp +import pandas as pd +from joblib import Parallel, delayed +import numpy as np +from scipy.spatial.distance import cdist + + +@dataclass +class PageData: + blocks: List[lp.TextBlock] + lines: List[lp.TextBlock] + words: List[lp.TextBlock] + + def to_dataframe( + self, + keep_token_index=True, + export_font=False, + normalize_coordinates=False, + canvas_width=None, + canvas_height=None, + ) -> pd.DataFrame: + + if not export_font: + blocks_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + ele.type, + -1, + -1, + True, + False, + ] + for ele in self.blocks + ] + lines_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + ele.type, + ele.parent, + -1, + False, + True, + ] + for ele in self.lines + ] + parent_block_id_for_line_id = {ele.id: ele.parent for ele in self.lines} + tokens_to_save = [ + [ + ele.id if keep_token_index else idx, + *ele.coordinates, + ele.text, + ele.type, + parent_block_id_for_line_id[ele.parent], # Cvt to block-level id 
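+                        # ele.parent itself is the containing line's id, kept as the line_id column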
+ ele.parent, + False, + False, + ] + for idx, ele in enumerate(self.words, start=len(blocks_to_save)) + ] + df = pd.DataFrame( + blocks_to_save + lines_to_save + tokens_to_save, + columns=[ + "id", + "x_1", + "y_1", + "x_2", + "y_2", + "text", + "category", + "block_id", + "line_id", + "is_block", + "is_line", + ], + ) + else: + blocks_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + None, + ele.type, + -1, + -1, + True, + False, + ] + for ele in self.blocks + ] + lines_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + None, + ele.type, + ele.parent, + -1, + False, + True, + ] + for ele in self.lines + ] + parent_block_id_for_line_id = {ele.id: ele.parent for ele in self.lines} + tokens_to_save = [ + [ + ele.id if keep_token_index else idx, + *ele.coordinates, + ele.text, + ele.font, + ele.type, + parent_block_id_for_line_id[ele.parent], # Cvt to block-level id + ele.parent, + False, + False, + ] + for idx, ele in enumerate(self.words, start=len(blocks_to_save)) + ] + df = pd.DataFrame( + blocks_to_save + lines_to_save + tokens_to_save, + columns=[ + "id", + "x_1", + "y_1", + "x_2", + "y_2", + "text", + "font", + "category", + "block_id", + "line_id", + "is_block", + "is_line", + ], + ) + + if normalize_coordinates: + assert canvas_width is not None + assert canvas_height is not None + df[["x_1", "x_2"]] = (df[["x_1", "x_2"]] / canvas_width * 1000).astype( + "int" + ) + df[["y_1", "y_2"]] = (df[["y_1", "y_2"]] / canvas_height * 1000).astype( + "int" + ) + + return df + + +class GrotoapDataset: + def __init__(self, base_dir: str, dataset_folder_name: str = "dataset"): + + self.base_dir = base_dir + self.dataset_folder_name = dataset_folder_name + self.all_xml_files = glob( + f"{self.base_dir}/{self.dataset_folder_name}/*/*.cxml" + ) + + def load_xml(self, xml_filename: str): + with open(xml_filename, "r") as fp: + soup = BeautifulSoup(fp, "lxml") + + pages = soup.find_all("page") + + parsed_page_data = { + idx: self.parse_page_xml(page) for idx, page in enumerate(pages) + } + + return parsed_page_data + + def parse_page_xml(self, page: "bs4.element.Tag") -> PageData: + + blocks = [] + lines = [] + words = [] + + word_id = 0 + line_id = 0 + all_zones = page.find_all("zone") + if all_zones is None: + return PageData() + + for zone_id, zone in enumerate(all_zones): + + words_in_this_block = [] + # Fetch the zone + v1, v2 = zone.find("zonecorners").find_all("vertex") + block_type = zone.find("classification").find("category")["value"] + block = lp.TextBlock( + lp.Rectangle( + float(v1["x"]), float(v1["y"]), float(v2["x"]), float(v2["y"]) + ), + type=block_type, + id=zone_id, + ) + + # Fetch lines + all_lines = zone.find_all("line") + if all_lines is None: + continue + + for line in all_lines: + + words_in_this_line = [] + + v1, v2 = line.find("linecorners").find_all("vertex") + current_line = lp.TextBlock( + lp.Rectangle( + float(v1["x"]), + float(v1["y"]), + float(v2["x"]), + float(v2["y"]), + ), + type=block_type, + parent=zone_id, + id=line_id, + ) + + # Fetch words + all_words = line.find_all("word") + if all_words is None: + continue + + for word in line.find_all("word"): + v1, v2 = word.find("wordcorners").find_all("vertex") + words_in_this_line.append( + lp.TextBlock( + lp.Rectangle( + float(v1["x"]), + float(v1["y"]), + float(v2["x"]), + float(v2["y"]), + ), + type=block_type, + text="".join( + [ele["value"] for ele in word.find_all("gt_text")] + ), + id=word_id, + parent=line_id, + ) + ) + word_id += 1 + + current_line.text = " ".join(ele.text for ele in 
words_in_this_line) + line_id += 1 + words_in_this_block.extend(words_in_this_line) + lines.append(current_line) + + block.text = " ".join(ele.text for ele in words_in_this_block) + blocks.append(block) + words.extend(words_in_this_block) + + return PageData(blocks, lines, words) + + def convert_xml_to_page_token(self, xml_filename, export_path): + + savename = "-".join(xml_filename.split("/")[-2:]).rstrip(".cxml") + parsed_page_data = self.load_xml(xml_filename) + print(f"Processing {savename}") + for page_id, page_data in parsed_page_data.items(): + + if os.path.exists(f"{export_path}/{savename}-{page_id}.csv"): + continue + + df = page_data.to_dataframe() + df.to_csv(f"{export_path}/{savename}-{page_id}.csv", index=None) + + def convert_to_page_token_table(self, export_path: str, n_jobs=20): + + if not os.path.exists(export_path): + os.makedirs(export_path) + print(f"Creating the export directory {export_path}") + else: + print(f"Overwriting existing exports in {export_path}") + + Parallel(n_jobs=n_jobs)( + delayed(self.convert_xml_to_page_token)(xml_filename, export_path) + for xml_filename in tqdm(self.all_xml_files) + ) + + +class CERMINELoader(GrotoapDataset): + def __init__(self): + pass + + @staticmethod + def corner_to_rectangle(corners): + corners = corners.find_all("vertex") + corners = np.array([(float(ele["x"]), float(ele["y"])) for ele in corners]) + x1, y1 = corners.min(axis=0) + x2, y2 = corners.max(axis=0) + return lp.Rectangle(x1, y1, x2, y2) + + def parse_page_xml(self, page: "bs4.element.Tag") -> PageData: + + blocks = [] + lines = [] + words = [] + + word_id = 0 + line_id = 0 + all_zones = page.find_all("zone") + if all_zones is None: + return PageData() + + for zone_id, zone in enumerate(all_zones): + + words_in_this_block = [] + # Fetch the zone + rect = self.corner_to_rectangle(zone.find("zonecorners")) + block_type = zone.find("classification").find("category")["value"] + block = lp.TextBlock( + rect, + type=block_type, + id=zone_id, + ) + + # Fetch lines + all_lines = zone.find_all("line") + if all_lines is None: + continue + + for line in all_lines: + + words_in_this_line = [] + + rect = self.corner_to_rectangle(line.find("linecorners")) + current_line = lp.TextBlock( + rect, + type=block_type, + parent=zone_id, + id=line_id, + ) + + # Fetch words + all_words = line.find_all("word") + if all_words is None: + continue + + for word in line.find_all("word"): + rect = self.corner_to_rectangle(word.find("wordcorners")) + words_in_this_line.append( + lp.TextBlock( + rect, + type=block_type, + text="".join( + [ele["value"] for ele in word.find_all("gt_text")] + ), + id=word_id, + parent=line_id, + ) + ) + word_id += 1 + + current_line.text = " ".join(ele.text for ele in words_in_this_line) + line_id += 1 + words_in_this_block.extend(words_in_this_line) + lines.append(current_line) + + block.text = " ".join(ele.text for ele in words_in_this_block) + blocks.append(block) + words.extend(words_in_this_block) + + return PageData(blocks, lines, words) + + +CERMINE_LOADER = CERMINELoader() + + +def process_cermine_annotation(sha, pdf_path, token_path): + filename = f"{pdf_path}/{sha}.cxml" + + try: + xml_data = CERMINE_LOADER.load_xml(filename) + except: + print("error CERMINE parsing for ", sha) + return None + + # _, pdf_images = pdf_extractor.load_tokens_and_image(filename.replace('.cxml', '.pdf'), resize_image=True) + + # if len(xml_data) != len(pdf_images): + # print("error CERMINE parsing for ", sha) + # return None + + if ( + len(xml_data) == 1 and len(sha.split("-")) 
== 2 + ): # it is a single page pdf for an individual page + xml_data[0].to_dataframe().to_csv(f"{token_path}/{sha}.csv", index=None) + else: + for page_id in range(len(xml_data)): + xml_data[page_id].to_dataframe().to_csv( + f"{token_path}/{sha}-{page_id:02d}.csv", index=None + ) + +def get_file_sha(filename): + return filename.split("/")[-1].split(".")[0] + +# parse the arguments of base_path and run process_cermine_annotation +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--base-path", type=str, help="The path to the source files of a dataset, e.g., sources/s2-vl-ver1") + parser.add_argument("--cermine-path", type=str) + parser.add_argument("--njobs", type=int, default=2) + args = parser.parse_args() + + # folder structure + base_path = args.base_path + pdf_path = f"{base_path}/pdfs" + token_path = f"{base_path}/tokens" + + # verify the existence of the files + if not os.path.exists(pdf_path) or len(glob(f"{pdf_path}/*.pdf")) == 0: + print(f"The PDF path {pdf_path} does not exist! Please try download the dataset first.") + exit() + + # Run cermine parsing + if len(glob(f"{pdf_path}/*.cxml")) == 0: + CERMINE_IMP_NAME = "cermine-impl-1.13-jar-with-dependencies.jar" + cermine_imp_name = args.cermine_path if os.path.exists(args.cermine_path) else CERMINE_IMP_NAME + cermine_prog_name = "pl.edu.icm.cermine.PdfBxStructureExtractor" + subprocess.call( + ["java", "-cp", cermine_imp_name, cermine_prog_name, "-path", pdf_path] + ) + print("Finish processing all the PDFs using CERMINE") + else: + print("CERMINE XML files already exist") + + # process the cermine files + if not os.path.exists(token_path): + os.makedirs(token_path) + + Parallel(n_jobs=args.njobs)( + delayed(process_cermine_annotation)(get_file_sha(filename), pdf_path, token_path) + for filename in tqdm(glob(f"{pdf_path}/*.pdf")) + ) + print("Finish converting all the CERMINE XMLS to csv") \ No newline at end of file diff --git a/datasets/s2-vl-utils/condense_dataset.py b/datasets/s2-vl-utils/condense_dataset.py new file mode 100644 index 0000000..81ee3aa --- /dev/null +++ b/datasets/s2-vl-utils/condense_dataset.py @@ -0,0 +1,720 @@ +from typing import List, Union, Dict, Any, Tuple +import random +import argparse +import json +import itertools +import os +from glob import glob +from dataclasses import dataclass +from collections import defaultdict + +from sklearn.model_selection import KFold, train_test_split +from tqdm import tqdm +import pandas as pd +import numpy as np +import layoutparser as lp + +PADDING_CONSTANT = 10000 + +np.random.seed(42) +random.seed(42) + + +def load_json(filename): + with open(filename, "r") as fp: + return json.load(fp) + + +def write_json(data, filename): + with open(filename, "w") as fp: + json.dump(data, fp) + + +def cvt_df_to_layout(row): + + return lp.TextBlock( + lp.Rectangle( + row["x_1"], + row["y_1"], + row["x_2"], + row["y_2"], + ), + id=row["id"], + type=row["category"], + text=row["text"], + ) + + +class RawAnnotation: + def __init__(self, annotation_table, annotation_dir): + + self.annotation_table = pd.read_csv(annotation_table).set_index("sha") + self.annotation_dir = annotation_dir + + def load_annotation_for_sha(self, sha): + + all_page_annotations = {} + + if len(glob(f"{self.annotation_dir}/{sha}-*.json"))>0: + # Load annotation for sha-pageid.json like files + for filename in glob(f"{self.annotation_dir}/{sha}-*.json"): + page_id = int(filename.replace(f"{self.annotation_dir}/{sha}-", "").replace(".json","")) + res = 
self.load_page_data_from_json(filename) + if res is not None: + all_page_annotations[page_id] = res + else: + # load annotations for sha.json like files + for filename in glob(f"{self.annotation_dir}/{sha}.json"): + all_page_annotations = self.load_all_page_data_from_json(filename) + + return all_page_annotations + + def load_all_page_data_from_json(self, filename): + raw = load_json(filename) + results_by_page = defaultdict(list) + for ele in raw["annotations"]: + + results_by_page[ele["page"]].append( + lp.TextBlock( + lp.Rectangle( + ele["bounds"]["left"], + ele["bounds"]["top"], + ele["bounds"]["right"], + ele["bounds"]["bottom"], + ), + type=ele["label"]["text"], + ) + ) + return results_by_page + + def load_page_data_from_json(self, filename): + raw = load_json(filename) + page_annotation = [] + for ele in raw["annotations"]: + page_annotation.append( + lp.TextBlock( + lp.Rectangle( + ele["bounds"]["left"], + ele["bounds"]["top"], + ele["bounds"]["right"], + ele["bounds"]["bottom"], + ), + type=ele["label"]["text"], + ) + ) + return page_annotation + + +@dataclass +class PageData: + blocks: List[lp.TextBlock] + lines: List[lp.TextBlock] + words: List[lp.TextBlock] + + +class CERMINEAnnotation: + def __init__( + self, + pdf_directory, + csv_directory, + ): + self.csv_dir = csv_directory + self.pdf_dir = pdf_directory + + @staticmethod + def load_page_data_from_csv(filename): + df = pd.read_csv(filename) + if len(df) == 0: + return None + + df = df[~df.text.isna()] + if len(df) == 0: + return None + + blocks_df = df[df.is_block] + lines_df = df[df.is_line] + tokens_df = df[~df.is_line & ~df.is_block] + + return PageData( + blocks=lp.Layout(blocks_df.apply(cvt_df_to_layout, axis=1).tolist()), + lines=lp.Layout(lines_df.apply(cvt_df_to_layout, axis=1).tolist()), + words=lp.Layout(tokens_df.apply(cvt_df_to_layout, axis=1).tolist()), + ) + + def load_annotations_for_sha(self, sha): + + xml_data = {} + for filename in glob(f"{self.csv_dir}/{sha}-*.csv"): + page_id = int(filename.replace(f"{self.csv_dir}/{sha}-", "").replace(".csv","")) + res = self.load_page_data_from_csv(filename) + if res is not None: + xml_data[page_id] = res + + return xml_data + + +class VISIONAnnotation(CERMINEAnnotation): + @staticmethod + def load_page_data_from_csv(filename): + df = pd.read_csv(filename) + if len(df) == 0: + return None + # Not dropping empty tokens + + blocks_df = df[df.is_block] + lines_df = df[df.is_line] + tokens_df = df[~df.is_line & ~df.is_block] + + return PageData( + blocks=lp.Layout(blocks_df.apply(cvt_df_to_layout, axis=1).tolist()), + lines=lp.Layout(lines_df.apply(cvt_df_to_layout, axis=1).tolist()), + words=lp.Layout(tokens_df.apply(cvt_df_to_layout, axis=1).tolist()), + ) + +class S2VLAnnotationGenerator: + def __init__( + self, + annotation_table, + raw_annotation, + cermine_annotation, + selected_categories, + default_category, + vision_annotation=None, + ): + self.annotation_table = pd.read_csv(annotation_table) + self.raw_annotation = raw_annotation + self.cermine_annotation = cermine_annotation + self.vision_annotation = vision_annotation + + self.selected_categories = selected_categories + self.default_category = default_category + self.cat2id = {cat: idx for idx, cat in enumerate(selected_categories)} + self.id2cat = {idx: cat for idx, cat in enumerate(selected_categories)} + + def get_unique_shas(self): + return self.annotation_table.sha.unique() + + def convert_token_data_to_json(self, tokens): + token_df = pd.DataFrame( + [ + [ + e.id, + str(e.text), + [int(_) for _ in 
e.coordinates], + e.type, + e.block_id, + e.line_id, + ] + for e in tokens + ], + columns=["id", "text", "bbox", "category", "block_id", "line_id"], + ) + token_df = token_df[ + ~token_df.text.isnull() + & ~token_df.text.isna() + & ~token_df.text.str.isspace() + ] + row_item = { + "words": token_df["text"].tolist(), + "bbox": token_df["bbox"].tolist(), + "labels": token_df["category"].map(self.cat2id).tolist(), + "block_ids": token_df["block_id"].astype("int").tolist(), + "line_ids": token_df["line_id"].astype("int").tolist(), + } + + return row_item + + def create_annotation_for_sha(self, sha): + + all_token_data = [] + all_files = [] + + raw_blocks = self.raw_annotation.load_annotation_for_sha(sha) + cermine_data = self.cermine_annotation.load_annotations_for_sha(sha) + + for page_id in cermine_data.keys(): + blocks = [ + b for b in raw_blocks[page_id] if b.type in self.selected_categories + ] + + # Pass 1: O(n) Initialize ids and categories + for word in cermine_data[page_id].words: + word.line_id = -1 + word.block_id = -1 + word.type = self.default_category + + # Pass 2: O(mn) Assign token categories + for word in cermine_data[page_id].words: + for block in blocks: + if word.is_in(block, center=True): + word.type = block.type + + # Pass 3: O(mn) Assign token block-category ids + used_lines_for_assign_line_ids = cermine_data[page_id].lines + used_blocks_for_assign_block_ids = cermine_data[page_id].blocks + for word in cermine_data[page_id].words: + for _l in used_lines_for_assign_line_ids: + if word.is_in(_l, center=True): + word.line_id = _l.id + + for _b in used_blocks_for_assign_block_ids: + if word.is_in(_b, center=True): + word.block_id = _b.id + + # Pass 4: O(n) In case some blocks are not assigned with the + # appropriate block indices, we assign the line ids + for word in cermine_data[page_id].words: + if word.block_id == -1: + word.block_id = word.line_id + PADDING_CONSTANT + + row_item = self.convert_token_data_to_json(cermine_data[page_id].words) + + if len(row_item["words"]) > 0: + + all_token_data.append(row_item) + all_files.append(f"{sha}-{page_id}") + + return all_token_data, all_files + + def create_annotation_for_shas(self, shas): + all_token_data = [] + all_files = [] + pbar = tqdm(shas) + for sha in pbar: + pbar.set_description(sha) + token_data, files = self.create_annotation_for_sha(sha) + all_token_data.extend(token_data) + all_files.extend(files) + return all_token_data, all_files + + def create_annotations(self): + shas = self.get_unique_shas() + all_token_data, all_files = self.create_annotation_for_shas(shas) + all_valid_shas = list(set([ele.split("-")[0] for ele in all_files])) + + self.all_token_data = all_token_data + self.all_files = all_files + self.all_valid_shas = all_valid_shas + self.sha_to_sample_mapping = { + sha: [idx for idx, file in enumerate(all_files) if file[:40] == sha] + for sha in all_valid_shas + } + + def save_annotation_cv(self, export_folder, n_fold=5): + + kf = KFold(n_splits=n_fold, shuffle=True, random_state=42) + + for idx, (train_idx, test_idx) in enumerate( + tqdm(kf.split(self.all_valid_shas), total=n_fold) + ): + annotation_data = {} + train_test_split = {} + + for name, indices in [("train", train_idx), ("test", test_idx)]: + cur_shas = [self.all_valid_shas[i] for i in indices] + selected_data_item_indices = list( + itertools.chain.from_iterable( + [self.sha_to_sample_mapping[sha] for sha in cur_shas] + ) + ) + + annotation_data[name] = ( + [self.all_token_data[i] for i in selected_data_item_indices], + [self.all_files[i] for i 
in selected_data_item_indices], + ) + train_test_split[name] = annotation_data[name][1] + + cur_export_folder = f"{export_folder}/{idx}" + self.save_json(annotation_data, train_test_split, cur_export_folder) + + def save_annotation_few_shot(self, export_folder, sample_sizes=[5, 10, 15]): + + for sample_size in tqdm(sample_sizes): + + train_sha, test_sha = train_test_split( + self.all_valid_shas, train_size=sample_size, random_state=42 + ) + + annotation_data = {} + train_test_files = {} + + for name, cur_shas in [("train", train_sha), ("test", test_sha)]: + selected_data_item_indices = list( + itertools.chain.from_iterable( + [self.sha_to_sample_mapping[sha] for sha in cur_shas] + ) + ) + + annotation_data[name] = ( + [self.all_token_data[i] for i in selected_data_item_indices], + [self.all_files[i] for i in selected_data_item_indices], + ) + train_test_files[name] = annotation_data[name][1] + + cur_export_folder = f"{export_folder}/{sample_size}" + self.save_json(annotation_data, train_test_files, cur_export_folder) + + def save_annotation_few_shot_with_mutual_test_set( + self, export_folder, sample_sizes=[5, 10, 15] + ): + + maximum_training_samples = max(sample_sizes) + maximum_remaining_test_samples = ( + len(self.all_valid_shas) - maximum_training_samples + ) + + all_train_sha, test_sha = train_test_split( + self.all_valid_shas, + test_size=maximum_remaining_test_samples, + random_state=42, + ) + + for sample_size in tqdm(sample_sizes): + + train_sha = random.sample(all_train_sha, sample_size) + + annotation_data = {} + train_test_files = {} + + for name, cur_shas in [("train", train_sha), ("test", test_sha)]: + selected_data_item_indices = list( + itertools.chain.from_iterable( + [self.sha_to_sample_mapping[sha] for sha in cur_shas] + ) + ) + + annotation_data[name] = ( + [self.all_token_data[i] for i in selected_data_item_indices], + [self.all_files[i] for i in selected_data_item_indices], + ) + train_test_files[name] = annotation_data[name][1] + + cur_export_folder = f"{export_folder}/{sample_size}" + self.save_json(annotation_data, train_test_files, cur_export_folder) + + def save_annotation_few_shot_and_cv( + self, export_folder, train_test_shas, sample_sizes=[5, 10, 15, 25, 45, 70] + ): + + for cv_index, _shas in enumerate(tqdm(train_test_shas)): + all_train_sha, test_sha = _shas["train"], _shas["test"] + for sample_size in sample_sizes: + train_sha = all_train_sha[:sample_size] + + annotation_data = {} + train_test_files = {} + + for name, cur_shas in [("train", train_sha), ("test", test_sha)]: + selected_data_item_indices = list( + itertools.chain.from_iterable( + [self.sha_to_sample_mapping[sha] for sha in cur_shas] + ) + ) + + annotation_data[name] = ( + [self.all_token_data[i] for i in selected_data_item_indices], + [self.all_files[i] for i in selected_data_item_indices], + ) + train_test_files[name] = annotation_data[name][1] + + cur_export_folder = f"{export_folder}/{sample_size}/{cv_index}" + self.save_json(annotation_data, train_test_files, cur_export_folder) + + def save_json(self, annotation_data, train_test_split, export_folder): + + os.makedirs(export_folder, exist_ok=True) + + for subset, (all_token_data, all_files) in annotation_data.items(): + + write_json( + {"data": all_token_data, "labels": self.cat2id, "files": all_files}, + f"{export_folder}/{subset}-token.json", + ) + + write_json(train_test_split, f"{export_folder}/train-test-split.json") + write_json(self.id2cat, f"{export_folder}/labels.json") + + +class 
S2VLAnnotationGeneratorWithGTBox(S2VLAnnotationGenerator): + @staticmethod + def order_blocks_based_on_token_ids(blocks, tokens): + + token_ids_in_blocks = [] + + for block in blocks: + + token_ids_in_this_block = [] + + for token in tokens: + if token.is_in(block, center=True): + token_ids_in_this_block.append(token.id) + + if len(token_ids_in_this_block) == 0: + token_ids_in_blocks.append(float("inf")) + else: + token_ids_in_blocks.append(min(token_ids_in_this_block)) + + sorted_blocks = [ + x.set(id=idx) + for idx, (_, x) in enumerate( + sorted(zip(token_ids_in_blocks, blocks), key=lambda pair: pair[0]) + ) + ] + + return sorted_blocks + + def create_annotation_for_sha(self, sha): + + all_token_data = [] + all_files = [] + + raw_blocks = self.raw_annotation.load_annotation_for_sha(sha) + cermine_data = self.cermine_annotation.load_annotations_for_sha(sha) + + for page_id in cermine_data.keys(): + blocks = [ + b for b in raw_blocks[page_id] if b.type in self.selected_categories + ] + blocks = self.order_blocks_based_on_token_ids( + blocks, cermine_data[page_id].words + ) + + # Pass 1: O(n) Initialize ids and categories + for word in cermine_data[page_id].words: + word.line_id = -1 + word.block_id = -1 + word.type = self.default_category + + # Pass 2: O(mn) Assign token categories + for word in cermine_data[page_id].words: + for block in blocks: + if word.is_in(block, center=True): + word.type = block.type + + # Pass 3: O(mn) Assign token block-category ids + used_lines_for_assign_line_ids = cermine_data[page_id].lines + used_blocks_for_assign_block_ids = blocks + + for word in cermine_data[page_id].words: + for _l in used_lines_for_assign_line_ids: + if word.is_in(_l, center=True): + word.line_id = _l.id + + for _b in used_blocks_for_assign_block_ids: + if word.is_in(_b, center=True): + word.block_id = _b.id + + # Pass 4: O(n) In case some blocks are not assigned with the + # appropriate block indices, we assign the line ids + for word in cermine_data[page_id].words: + if word.block_id == -1: + word.block_id = word.line_id + PADDING_CONSTANT + + row_item = self.convert_token_data_to_json(cermine_data[page_id].words) + + if len(row_item["words"]) > 0: + + all_token_data.append(row_item) + all_files.append(f"{sha}-{page_id}") + + return all_token_data, all_files + +class S2VLAnnotationGeneratorWithVisionBox(S2VLAnnotationGenerator): + + def create_annotation_for_sha(self, sha): + + all_token_data = [] + all_files = [] + + raw_blocks = self.raw_annotation.load_annotation_for_sha(sha) + cermine_data = self.cermine_annotation.load_annotations_for_sha(sha) + vision_data = self.vision_annotation.load_annotations_for_sha(sha) + + for page_id in cermine_data.keys(): + blocks = [ + b for b in raw_blocks[page_id] if b.type in self.selected_categories + ] + + # Pass 1: O(n) Initialize ids and categories + for word in cermine_data[page_id].words: + word.line_id = -1 + word.block_id = -1 + word.type = self.default_category + + # Pass 2: O(mn) Assign token categories + for word in cermine_data[page_id].words: + for block in blocks: + if word.is_in(block, center=True): + word.type = block.type + + # Pass 3: O(mn) Assign token block-category ids + used_lines_for_assign_line_ids = cermine_data[page_id].lines + used_blocks_for_assign_block_ids = vision_data[page_id].blocks + for word in cermine_data[page_id].words: + for _l in used_lines_for_assign_line_ids: + if word.is_in(_l, center=True): + word.line_id = _l.id + + for _b in used_blocks_for_assign_block_ids: + if word.is_in(_b, center=True): + 
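+                        # adopt the id of the vision-detected block that contains the token center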
word.block_id = _b.id + + # Pass 4: O(n) In case some blocks are not assigned with the + # appropriate block indices, we assign the line ids + for word in cermine_data[page_id].words: + if word.block_id == -1: + word.block_id = word.line_id + PADDING_CONSTANT + + row_item = self.convert_token_data_to_json(cermine_data[page_id].words) + + if len(row_item["words"]) > 0: + + all_token_data.append(row_item) + all_files.append(f"{sha}-{page_id}") + + return all_token_data, all_files + +class S2VLAnnotationGeneratorWithVisionLine(S2VLAnnotationGeneratorWithGTBox): + + def create_annotation_for_sha(self, sha): + + all_token_data = [] + all_files = [] + + raw_blocks = self.raw_annotation.load_annotation_for_sha(sha) + cermine_data = self.cermine_annotation.load_annotations_for_sha(sha) + vision_data = self.vision_annotation.load_annotations_for_sha(sha) + + for page_id in cermine_data.keys(): + blocks = [ + b for b in raw_blocks[page_id] if b.type in self.selected_categories + ] + blocks = self.order_blocks_based_on_token_ids( + blocks, cermine_data[page_id].words + ) + + # Pass 1: O(n) Initialize ids and categories + for word in cermine_data[page_id].words: + word.line_id = -1 + word.block_id = -1 + word.type = self.default_category + + # Pass 2: O(mn) Assign token categories + for word in cermine_data[page_id].words: + for block in blocks: + if word.is_in(block, center=True): + word.type = block.type + + # Pass 3: O(mn) Assign token block-category ids + used_lines_for_assign_line_ids = vision_data[page_id].lines + used_blocks_for_assign_block_ids = vision_data[page_id].blocks + + for word in cermine_data[page_id].words: + for _l in used_lines_for_assign_line_ids: + if word.is_in(_l, center=True): + word.line_id = _l.id + + for _b in used_blocks_for_assign_block_ids: + if word.is_in(_b, center=True): + word.block_id = _b.id + + # Pass 4: O(n) In case some blocks are not assigned with the + # appropriate block indices, we assign the line ids + for word in cermine_data[page_id].words: + if word.block_id == -1: + word.block_id = word.line_id + PADDING_CONSTANT + + row_item = self.convert_token_data_to_json(cermine_data[page_id].words) + + if len(row_item["words"]) > 0: + + all_token_data.append(row_item) + all_files.append(f"{sha}-{page_id}") + + return all_token_data, all_files + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--annotation-folder", type=str, help="The path to the annotation folder" + ) + parser.add_argument( + "--annotation-table", type=str, help="The table with sha-annotator name" + ) + parser.add_argument( + "--cermine-pdf-dir", + type=str, + help="The path to the folder containing the PDF and CERMINED results", + ) + parser.add_argument( + "--cermine-csv-dir", + type=str, + help="The path to the folder with CERMINED results stored in csv", + ) + parser.add_argument( + "--vision-csv-dir", + type=str, + help="The path to the folder with VISION Model results stored in csv", + ) + parser.add_argument( + "--export-folder", type=str, help="The folder for storing the data" + ) + parser.add_argument("--config", type=str, help="The path to the config file") + + parser.add_argument("--use-gt-block", action="store_true") + parser.add_argument("--use-vision-box", action="store_true") + parser.add_argument("--use-vision-line", action="store_true") + parser.add_argument("--few-shot-mutual-test-set", action="store_true") + parser.add_argument("--few-shot-cv", action="store_true") + + args = parser.parse_args() + + raw_annotation = 
RawAnnotation(args.annotation_table, args.annotation_folder) + cermine_annotation = CERMINEAnnotation(args.cermine_pdf_dir, args.cermine_csv_dir) + vision_annotation = VISIONAnnotation(None, args.vision_csv_dir) + + config = load_json(args.config) + + if args.use_gt_block: + s2vl = S2VLAnnotationGeneratorWithGTBox( + args.annotation_table, + raw_annotation, + cermine_annotation, + config["selected_categories"], + config["default_category"], + ) + save_folder = f"{args.export_folder}-gtbox" + elif args.use_vision_box: + s2vl = S2VLAnnotationGeneratorWithVisionBox( + args.annotation_table, + raw_annotation, + cermine_annotation, + config["selected_categories"], + config["default_category"], + vision_annotation=vision_annotation, + ) + save_folder = f"{args.export_folder}-visionbox" + elif args.use_vision_line: + s2vl = S2VLAnnotationGeneratorWithVisionLine( + args.annotation_table, + raw_annotation, + cermine_annotation, + config["selected_categories"], + config["default_category"], + vision_annotation=vision_annotation, + ) + save_folder = f"{args.export_folder}-visionline-v2" + else: + s2vl = S2VLAnnotationGenerator( + args.annotation_table, + raw_annotation, + cermine_annotation, + config["selected_categories"], + config["default_category"], + ) + save_folder = args.export_folder + + s2vl.create_annotations() + s2vl.save_annotation_cv(f"{save_folder}-cv", 5) + # if args.few_shot_mutual_test_set: diff --git a/datasets/s2-vl-utils/config.json b/datasets/s2-vl-utils/config.json new file mode 100644 index 0000000..eb9d2cc --- /dev/null +++ b/datasets/s2-vl-utils/config.json @@ -0,0 +1,21 @@ +{ + "selected_categories": [ + "Title", + "Author", + "Abstract", + "Keywords", + "Section", + "Paragraph", + "List", + "Bibliography", + "Equation", + "Algorithm", + "Figure", + "Table", + "Caption", + "Header", + "Footer", + "Footnote" + ], + "default_category": "Paragraph" +} \ No newline at end of file diff --git a/datasets/s2-vl-utils/download.py b/datasets/s2-vl-utils/download.py new file mode 100644 index 0000000..f5ff88e --- /dev/null +++ b/datasets/s2-vl-utils/download.py @@ -0,0 +1,257 @@ +from typing import List, Union, Dict, Any, Tuple +import sys +import zipfile +import io +import os +import hashlib +import logging +import tempfile +import shutil +from glob import glob + +import requests +import pandas as pd +import requests +import layoutparser as lp +import pandas as pd +from tqdm import tqdm +from PyPDF2 import PdfFileReader, PdfFileWriter + + +logger = logging.getLogger(__name__) +sha1 = hashlib.sha1() +headers = {"User-Agent": "Mozilla/5.0"} + +ANNOTATION_FILE = { + "s2-vl-ver1": "https://ai2-s2-research.s3.us-west-2.amazonaws.com/s2-vlue/s2-vl-ver1-annotations.zip" +} + +def bulk_fetch_pdf_for_urls( + paper_table: pd.DataFrame, + target_dir: str, +) -> List[List[str]]: + + os.makedirs(target_dir, exist_ok=True) + paper_download_status = [] + + paper_table = paper_table.groupby("sha").first().reset_index() # Remove duplicates + pbar = tqdm(paper_table.iterrows(), total=len(paper_table)) + + for _, row in pbar: + + sha_in_table = row["sha"] + download_link = row["url"] + + pbar.set_description(desc=download_link) + + try: + pdf_path = os.path.join(target_dir, str(sha_in_table) + ".pdf") + + if os.path.exists(pdf_path): + continue + + r = requests.get(download_link, headers=headers) + + if r.status_code == 200: + sha1.update(r.content) + downloaded_sha = sha1.hexdigest() + + + with open(pdf_path, "wb") as fh: + fh.write(r.content) + + paper_download_status.append([sha_in_table, 
downloaded_sha, "success"]) + else: + print(f"Fail to download due to HTTP error {r.status_code} for {download_link}") + paper_download_status.append([sha_in_table, None, "download_error"]) + except: + print(f"Fail to download due to HTTP error {r.status_code} for {download_link}") + paper_download_status.append([sha_in_table, None, "download_error"]) + + return paper_download_status + + +def split_pdf_to_each_page_and_check(pdf_file, target_folder, remove_problematic=False): + """Split a pdf file into separate pages. + + Args: + pdf_file (str): The name of the PDF file to be split. + target_folder (str): The target folder to save the splitted pages. + """ + try: + pdf = PdfFileReader(pdf_file) + # Sometimes the downloaded PDF is corrupted. + total_pages = pdf.getNumPages() + # Sometimes some strange errors would occur if the pdf engine + # thinks the pdf is corrupted. + except: + return False + + is_page_successfully_saved = [] + + # Try to save individual pages + for i in range(total_pages): + pdf_writer = PdfFileWriter() + pdf_writer.addPage(pdf.getPage(i)) + + filename = os.path.splitext(os.path.basename(pdf_file))[0] + save_name = os.path.join(target_folder, f"{filename}-{i:02d}.pdf") + + try: + with open(save_name, "wb") as outputStream: + pdf_writer.write(outputStream) + is_page_successfully_saved.append(i) + except KeyboardInterrupt: + exit() + except: + print(f"Failed to save {save_name}") + + del pdf_writer + del pdf + + # If individual pages + if len(is_page_successfully_saved) != total_pages: + is_pdf_successfully_saved = False + + else: + ok_files = [] + saved_pdf_files = glob(f"{target_folder}/{filename}*.pdf") + for saved_pdf_file in saved_pdf_files: + try: + lp.load_pdf(saved_pdf_file) + ok_files.append(saved_pdf_file) + except KeyboardInterrupt: + exit() + except: + pass + if len(ok_files) != total_pages: + is_pdf_successfully_saved = False + else: + is_pdf_successfully_saved = True + + if not is_pdf_successfully_saved and remove_problematic: + print( + f"The current PDF {pdf_file} cannot be appropriately parsed. 
Removing the saved folders" + ) + shutil.rmtree(target_folder) + + return is_pdf_successfully_saved + + +def _generalized_paper_downloading_and_processing_protocol( + paper_table, + target_folder, + download_func, +): + + with tempfile.TemporaryDirectory() as combined_pdf_save_path: + + if not os.path.exists(target_folder): + os.makedirs(target_folder) + + print("Downloading the Papers") + paper_download_status = download_func(paper_table, combined_pdf_save_path) + paper_download_status = pd.DataFrame( + paper_download_status, columns=["sha_in_table", "downloaded_sha", "status"] + ) + + if "page" not in paper_table.columns: + create_download_report(paper_download_status) + for file in glob(os.path.join(combined_pdf_save_path, "*")): + shutil.move(file, target_folder) + return paper_download_status + + pbar = tqdm(paper_download_status.iterrows(), total=len(paper_download_status)) + updated_paper_download_status = paper_download_status.copy() + + for idx, row in pbar: + if row["status"] == "success": + sha = row["sha_in_table"] + pbar.set_description(f"Processing {sha}") + + if glob(f"{target_folder}/{sha}-*"): + continue # Skip already processed + + with tempfile.TemporaryDirectory() as tempdir: + is_pdf_successfully_saved = split_pdf_to_each_page_and_check( + os.path.join(combined_pdf_save_path, sha + ".pdf"), tempdir + ) + + # In this command, it will save all the processed files in a tmp folder + # As such, when the PDFs are successfully downloaded, we need to move them + # to the actual target folder + if is_pdf_successfully_saved: + # The tempdir contains all the pages, but we only want to move the target pages + all_pages = paper_table.loc[paper_table["sha"] == sha, "page"].tolist() + for page in all_pages: + shutil.move(os.path.join(tempdir, f"{sha}-{page:02d}.pdf"), target_folder) + else: + updated_paper_download_status.iloc[idx, -1] = "pdf_parsing_failure" + + create_download_report(paper_download_status) + return updated_paper_download_status + + +def create_download_report(paper_download_status): + """Create a report of the downloaded papers. + + Args: + paper_download_status (pd.DataFrame): The status of the downloaded papers. 
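+
+    The report lists SHA mismatches (which do not necessarily mean the PDF
+    contents differ) and groups unsuccessful downloads by their error status.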
+ """ + print("PDF Download Report") + incompatible_papers = paper_download_status[paper_download_status["sha_in_table"] != paper_download_status["downloaded_sha"]] + + print(f"Total mismatch: {len(incompatible_papers)}/{len(paper_download_status)}") + print("Note: The mismatch between SHA doesn't necessarily mean\n" + "the PDF files have different contents.") + for _, row in incompatible_papers.iterrows(): + print( + f"Original SHA: {row['sha_in_table']} -> Actual SHA: {row['downloaded_sha']}" + ) + + unsuccessful_papers = paper_download_status[ + paper_download_status["status"] != "success" + ] + for error_name, gp in unsuccessful_papers.groupby("status"): + print(f"Total {error_name}: {len(gp)}/{len(paper_download_status)}") + for _, row in gp.iterrows(): + print(f"Fail to download SHA: {row['sha_in_table']}") + + +def fetch_and_process_papers_based_on_urls( + paper_table, target_folder +): + return _generalized_paper_downloading_and_processing_protocol( + paper_table, target_folder, bulk_fetch_pdf_for_urls + ) + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser(description="Download S2-VL paper data") + parser.add_argument( + "--base-path", + type=str, + help="The path to the source files of a dataset, e.g., sources/s2-vl-ver1", + ) + parser.add_argument("--annotation-table", type=str, default="annotation_table.csv") + + args = parser.parse_args() + + pdf_save_path = f"{args.base_path}/pdfs" + if not os.path.exists(pdf_save_path): + os.makedirs(pdf_save_path) + + paper_table = pd.read_csv(f"{args.base_path}/{args.annotation_table}") + fetch_and_process_papers_based_on_urls(paper_table, pdf_save_path) + + print("Downloading the annotation") + # hacky code to get the dataset name + dataset_name = os.path.basename(args.base_path.strip("/")) + annotation_file_url = ANNOTATION_FILE[dataset_name] + + r = requests.get(annotation_file_url) + + with zipfile.ZipFile(file=io.BytesIO(r.content)) as zip_ref: + zip_ref.extractall(f"{args.base_path}/annotations") \ No newline at end of file diff --git a/datasets/s2-vl-utils/requirements.txt b/datasets/s2-vl-utils/requirements.txt new file mode 100644 index 0000000..b394d4b --- /dev/null +++ b/datasets/s2-vl-utils/requirements.txt @@ -0,0 +1,5 @@ +bs4 +lxml +PyPDF2 +tqdm +requests \ No newline at end of file diff --git a/datasets/s2-vl-utils/vision_model_loader.py b/datasets/s2-vl-utils/vision_model_loader.py new file mode 100644 index 0000000..ac48d00 --- /dev/null +++ b/datasets/s2-vl-utils/vision_model_loader.py @@ -0,0 +1,211 @@ +from genericpath import exists +from glob import glob +import os +import argparse + +from tqdm import tqdm +import pandas as pd +import numpy as np +import layoutparser as lp +from pdf2image import convert_from_path + +from vision_postprocessor import pipeline, create_structure_df + +class S2VLLoader: + def __init__(self, pdf_path): + + self.pdf_path = pdf_path + self.all_pdfs = glob(f"{self.pdf_path}/*.pdf") + + def __getitem__(self, idx): + assert idx < len(self.all_pdfs) + return self.load_sample(self.all_pdfs[idx]) + + def load_sample(self, pdf_path): + return pdf_path, convert_from_path(pdf_path, dpi=72) + + def __len__(self): + return len(self.all_pdfs) + + +def calculate_overlapping_coefficient(box1, box2): + x1, y1, x2, y2 = box1.coordinates + a1, b1, a2, b2 = box2.coordinates + + if x2 < a1 or x1 > a2 or y1 > b2 or y2 < b1: # Bottom or top + return 0 + + else: + intersection = lp.Rectangle( + x_1=max(x1, a1), y_1=max(y1, b1), x_2=min(x2, a2), y_2=min(y2, b2) + ) + return 
intersection.area / min(box1.area, box2.area) + + +THRESHOLD_FOR_OVERLAPPING_BLOCKS = 0.5 + + +def filter_out_non_overlapping_block(blocks): + + boxes_to_remove = np.zeros(len(blocks)) + for box2 in blocks: + for box1 in blocks: + + if box1.id == box2.id: + continue + + if boxes_to_remove[box1.id] or boxes_to_remove[box2.id]: + continue + + if ( + calculate_overlapping_coefficient(box1, box2) + > THRESHOLD_FOR_OVERLAPPING_BLOCKS + ): + if box1.area >= box2.area: + boxes_to_remove[box2.id] = 1 + else: + boxes_to_remove[box1.id] = 1 + + return [b for b in blocks if not boxes_to_remove[b.id]] + + +def convert_blocks_to_df(blocks_line): + blocks_to_save = [[ele.id, *ele.coordinates, ele.type, ele.score] for ele in blocks_line] + + df = pd.DataFrame( + blocks_to_save, + columns=[ + "id", + "x_1", + "y_1", + "x_2", + "y_2", + "category", + "confidence", + ], + ) + + df[["x_1", "y_1", "x_2", "y_2"]] = df[["x_1", "y_1", "x_2", "y_2"]].astype("int") + return df + + +def textline_detection(base_path): + def detect_lines_for_image(pdf_image): + blocks_line = model_line.detect(pdf_image) + blocks_line = [token.set(id=idx) for idx, token in enumerate(blocks_line)] + blocks_line = filter_out_non_overlapping_block(blocks_line) + return blocks_line + + model_line = lp.Detectron2LayoutModel( + config_path=f"https://www.dropbox.com/s/hd21tarnhbj1p1o/config.yaml?dl=1", # This is the line detection model trained using the GROTOAP2 dataset + extra_config=[ + "MODEL.ROI_HEADS.SCORE_THRESH_TEST", + 0.35, + "MODEL.ROI_HEADS.NMS_THRESH_TEST", + 0.8, + ], + label_map={0: "line"}, + ) + + loader = S2VLLoader(f"{base_path}/pdfs") + + for pdf_path in tqdm(loader.all_pdfs): + pdf_name = pdf_path.split("/")[-1].replace(".pdf", "") + pdf_images = convert_from_path(pdf_path, dpi=72) + for pid, pdf_image in enumerate(pdf_images): + + blocks_line = detect_lines_for_image(pdf_image) + + df = convert_blocks_to_df(blocks_line) + + if len(pdf_images) == 1 and len(pdf_name.split("-")) == 2: + df.to_csv(f"{base_path}/lines/{pdf_name}.csv", index=None) + else: + df.to_csv(f"{base_path}/lines/{pdf_name}-{pid:02d}.csv", index=None) + + +def textblock_detection(base_path): + def detect_blocks_for_image(pdf_image): + blocks1 = block_predictorA.detect(pdf_image) + blocks2 = block_predictorB.detect(pdf_image) + + blocks = blocks1 + blocks2 + blocks = sorted(blocks, key=lambda ele: ele.coordinates[1]) + blocks = [token.set(id=idx) for idx, token in enumerate(blocks)] + blocks = filter_out_non_overlapping_block(blocks) + return blocks + + block_predictorA = lp.Detectron2LayoutModel( + config_path="lp://PubLayNet/mask_rcnn_R_50_FPN_3x/config", + extra_config=[ + "MODEL.ROI_HEADS.SCORE_THRESH_TEST", + 0.50, + "MODEL.ROI_HEADS.NMS_THRESH_TEST", + 0.4, + ], + label_map={0: "text", 1: "title", 2: "list", 3: "table", 4: "figure"}, + ) + + block_predictorB = lp.Detectron2LayoutModel( + config_path="lp://MFD/faster_rcnn_R_50_FPN_3x/config", + extra_config=[ + "MODEL.ROI_HEADS.SCORE_THRESH_TEST", + 0.6, + "MODEL.ROI_HEADS.NMS_THRESH_TEST", + 0.2, + ], + label_map={1: "equation"}, + ) + + loader = S2VLLoader(f"{base_path}/pdfs") + + for pdf_path in tqdm(loader.all_pdfs): + pdf_name = pdf_path.split("/")[-1].replace(".pdf", "") + pdf_images = convert_from_path(pdf_path, dpi=72) + + for pid, pdf_image in enumerate(pdf_images): + blocks = detect_blocks_for_image(pdf_image) + df = convert_blocks_to_df(blocks) + + if len(pdf_images) == 1 and len(pdf_name.split("-")) == 2: + df.to_csv(f"{base_path}/blocks/{pdf_name}.csv", index=None) + else: + 
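+                # multi-page PDF: append the zero-padded page index to the csv name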
df.to_csv(f"{base_path}/blocks/{pdf_name}-{pid:02d}.csv", index=None) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-path", type=str, help="The path to the source files of a dataset, e.g., sources/s2-vl-ver1") + args = parser.parse_args() + + os.makedirs(f"{args.base_path}/blocks", exist_ok=True) + + if len(glob(f"{args.base_path}/blocks/*.csv")) > 0: + print("Text Blocks already detected") + else: + print("Running Text Block Detection") + textblock_detection(args.base_path) + + os.makedirs(f"{args.base_path}/lines", exist_ok=True) + if len(glob(f"{args.base_path}/lines/*.csv")) > 0: + print("Text Lines already detected") + else: + print("Running Text Line Detection") + textline_detection(args.base_path) + + target_dir = f"{args.base_path}/condensed" + os.makedirs(target_dir, exist_ok=True) + for filename in tqdm(glob(f"{args.base_path}/pdfs/*.pdf")): + res = os.path.basename(filename).split(".")[0].split("-") + if len(res)==2: + pdf_sha, pid = res + blocks, lines, tokens, additional_blocks, additional_lines = pipeline(args.base_path, pdf_sha, pid) + df = create_structure_df(tokens, blocks, lines) + df.to_csv(f"{target_dir}/{pdf_sha}-{pid}.csv", index=None) + else: + pdf_sha = res[0] + pids = len(glob(f"{args.base_path}/pdfs/{pdf_sha}-*.csv")) + for pid in range(pids): + blocks, lines, tokens, additional_blocks, additional_lines = pipeline(args.base_path, pdf_sha, f"{pid:02d}") + df = create_structure_df(tokens, blocks, lines) + df.to_csv(f"{target_dir}/{pdf_sha}-{pid:02d}.csv", index=None) \ No newline at end of file diff --git a/datasets/s2-vl-utils/vision_postprocessor.py b/datasets/s2-vl-utils/vision_postprocessor.py new file mode 100644 index 0000000..e801125 --- /dev/null +++ b/datasets/s2-vl-utils/vision_postprocessor.py @@ -0,0 +1,749 @@ +import math +from functools import partial +import json +import re +import random +from itertools import groupby +from collections import Counter, defaultdict +from copy import copy + +import os +from PIL import Image +import numpy as np +import layoutparser as lp +from tqdm import tqdm +from glob import glob +import pandas as pd +from scipy.sparse.csgraph import connected_components + +NON_TEXTUAL_TYPES = ["table", "figure", "equation"] + +def argsort(seq): + # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python + return sorted(range(len(seq)), key=seq.__getitem__) + +def get_most_common_element(lst): + return Counter(lst).most_common(1)[0][0] + +def get_most_common_token_type(tokens): + return get_most_common_element([ele.type for ele in tokens]) + +def union_box(blocks): + if len(blocks) == 0: + # print("Warning: The length of blocks is 0!") + rect = lp.Rectangle(0, 0, 0, 0) + return lp.TextBlock(rect) + else: + x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") + for block in blocks: + bbox = block.coordinates + x1 = min(x1, bbox[0]) + y1 = min(y1, bbox[1]) + x2 = max(x2, bbox[2]) + y2 = max(y2, bbox[3]) + rect = lp.Rectangle(int(x1), int(y1), int(x2), int(y2)) + return lp.TextBlock(rect, type=blocks[0].type) + +def is_in(block_a, block_b, metric="center"): + """A rewrite of the lp.LayoutElement.is_in function. + We will use a soft_margin and center function by default. 
+ """ + if metric == "center": + return block_a.is_in( + block_b, + soft_margin={"top": 1, "bottom": 1, "left": 1, "right": 1}, + center=True, + ) + elif metric == "coef": + return ( + calculate_overlapping_coefficient(block_a, block_b) + > MIN_OVERLAPPING_THRESHOLD + ) + elif metric == "any": + return is_in(block_a, block_b, metric="center") or is_in( + block_a, block_b, metric="coef" + ) + +def is_non_textual_type(block): + if isinstance(block.type, str): + return block.type in NON_TEXTUAL_TYPES + else: + raise ValueError(f"strange block type data type {type(block.type)}") + +def cvt_cermine_df_to_layout(row): + + return lp.TextBlock( + lp.Rectangle( + row["x_1"], + row["y_1"], + row["x_2"], + row["y_2"], + ), + id=row["id"], + type=row["category"], + text=row["text"], + ) + +def cvt_line_df_to_layout(row): + + return lp.TextBlock( + lp.Rectangle( + row["x_1"], + row["y_1"], + row["x_2"], + row["y_2"], + ), + id=row["id"], + ) + +def cvt_block_df_to_layout(row): + + return lp.TextBlock( + lp.Rectangle( + row["x_1"], + row["y_1"], + row["x_2"], + row["y_2"], + ), + id=row["id"], + type=row["category"], + ) + +def load_cermine_data_from_csv(filename): + df = pd.read_csv(filename) + if len(df) == 0: + return None + + df = df[~df.text.isna()] + if len(df) == 0: + return None + + tokens_df = df[~df.is_line & ~df.is_block] + + return lp.Layout(tokens_df.apply(cvt_cermine_df_to_layout, axis=1).tolist()) + +def load_line_data_from_csv(filename): + df = pd.read_csv(filename) + return lp.Layout( + df.apply(cvt_line_df_to_layout, axis=1).tolist() + ) + +def load_block_data_from_csv(filename): + df = pd.read_csv(filename) + return lp.Layout( + df.apply(cvt_block_df_to_layout, axis=1).tolist() + ) + +def calculate_overlapping_coefficient(box1, box2): + x1, y1, x2, y2 = box1.coordinates + a1, b1, a2, b2 = box2.coordinates + + if x2 < a1 or x1 > a2 or y1 > b2 or y2 < b1: # Bottom or top + return 0 + + min_area = min(box1.area, box2.area) + if min_area == 0: + return 0 + else: + intersection = lp.Rectangle( + x_1=max(x1, a1), y_1=max(y1, b1), x_2=min(x2, a2), y_2=min(y2, b2) + ) + return intersection.area / min_area + +def calculate_pairwise_overlapping_coefficient(blocks_A, blocks_B=None): + if blocks_B is not None: + return np.array( + [ + [calculate_overlapping_coefficient(box1, box2) for box2 in blocks_B] + for box1 in blocks_A + ] + ) + else: + n = len(blocks_A) + overlapping = np.zeros((n, n)) + for row in range(n): + for col in range(row + 1, n): + overlapping[row, col] = calculate_overlapping_coefficient( + blocks_A[row], blocks_A[col] + ) + + i_lower = np.tril_indices(n, k=-1) + overlapping[i_lower] = overlapping.T[i_lower] + # A trick learned from https://stackoverflow.com/a/42209263 + return overlapping + +MIN_OVERLAPPING_THRESHOLD = 0.65 + +def remove_overlapping_textual_blocks_for_non_textual_blocks(blocks): + # Firstly checking paragraph and non-paragraph blocks + + textual_blocks = [b for b in blocks if not is_non_textual_type(b)] + non_textual_blocks = [b for b in blocks if is_non_textual_type(b)] + + if len(textual_blocks) == 0 or len(non_textual_blocks) == 0: + return textual_blocks + non_textual_blocks + + overlapping = calculate_pairwise_overlapping_coefficient( + non_textual_blocks, textual_blocks + ) + overlapping = overlapping > 0.8 + + if not overlapping.any(): + return textual_blocks + non_textual_blocks + + nids, tids = np.where(overlapping) + + return [ + b for idx, b in enumerate(textual_blocks) if idx not in np.unique(tids) + ] + non_textual_blocks + +def 
find_parent_for_elements(block_layout, token_layout, target_attr="parent"): + + for block in block_layout: + remaining_tokens = [] + for token in token_layout: + if is_in(token, block): + setattr(token, target_attr, int(block.id)) + else: + remaining_tokens.append(token) + + token_layout = remaining_tokens + + for token in token_layout: + setattr(token, target_attr, None) + +def block_snapping(blocks, tokens): + + for block in blocks: + if is_non_textual_type(block): + continue + tokens_in_this_group = [] + for token in tokens: + if token.parent == block.id: + tokens_in_this_group.append(token) + block.block = union_box(tokens_in_this_group).block + +def filter_out_overlapping_block(blocks): + + boxes_to_remove = {b.id: 0 for b in blocks} + for box2 in blocks: + for box1 in blocks: + + if box1.id == box2.id: + continue + + if boxes_to_remove[box1.id] or boxes_to_remove[box2.id]: + continue + + if ( + calculate_overlapping_coefficient(box1, box2) + > MIN_OVERLAPPING_THRESHOLD + ): + + if box1.area >= box2.area: + boxes_to_remove[box2.id] = 1 + else: + boxes_to_remove[box1.id] = 1 + + return [b for b in blocks if not boxes_to_remove[b.id]] + +def filter_out_overlapping_block_and_union(blocks): + + overlapping = calculate_pairwise_overlapping_coefficient(blocks) + n_components, labels = connected_components( + csgraph=overlapping > MIN_OVERLAPPING_THRESHOLD, + directed=False, + return_labels=True, + ) + + new_blocks = [] + prev_len = 0 + for idx, gp in groupby(labels): + cur_len = len(list(gp)) + cur_blocks = blocks[prev_len : prev_len + cur_len] + new_blocks.append( + union_box(sorted(cur_blocks, key=lambda b: b.area, reverse=True)).set( + id=idx + ) + ) + prev_len += cur_len + + return new_blocks + +def is_close(block1, block2, x_tolerance=15, y_tolerance=16): + # horizontal difference + bbox0 = block1.coordinates + bbox1 = block2.coordinates + if bbox1[0] - bbox0[2] > x_tolerance or bbox0[0] - bbox1[2] > x_tolerance: + return False + + # line difference + _, y1 = block1.block.center + _, y2 = block2.block.center + if abs(y1 - y2) > y_tolerance: + return False + return True + +def group_contents(tokens, x_tolerance=15, y_tolerance=16): + + selected_mask = {b.id: 0 for b in tokens} + + grouped_tokens = [] + + cur_tokens = tokens + + while cur_tokens: + + current_group = [] + + # start from a random sample to improve robustness + start_token = random.choice(cur_tokens) + + queue = [start_token] + selected_mask[start_token.id] = 1 + + while queue: + cur_token = queue[0] + for candidate_token in cur_tokens: + if not selected_mask[candidate_token.id] and is_close( + cur_token, candidate_token, x_tolerance, y_tolerance + ): + queue.append(candidate_token) + selected_mask[candidate_token.id] = 1 + + current_group.append(queue.pop(0)) + + grouped_tokens.append(current_group) + cur_tokens = [token for token in cur_tokens if not selected_mask[token.id]] + + return grouped_tokens + +def group_ungrouped_elements( + tokens, attr_name="parent", x_tolerance=15, y_tolerance=16 +): + selected_tokens = [ + b + for b in tokens + if getattr(b, attr_name, None) is None + ] + + results = group_contents(selected_tokens, x_tolerance, y_tolerance) + return [union_box(ele).set(text=" ".join(i.text for i in ele)) for ele in results] + + +group_token_to_blocks = group_ungrouped_elements +group_token_to_lines = partial( + group_ungrouped_elements, attr_name="line_id", y_tolerance=5 +) + +MIN_OVERLAPPING_THRESHOLD = 0.65 + +def absorb_additional_blocks_into_existing_blocks( + blocks, additional_blocks, 
threshold=MIN_OVERLAPPING_THRESHOLD +): + + if len(blocks) == 0 or len(additional_blocks) == 0: + return blocks, additional_blocks + + overlapping = calculate_pairwise_overlapping_coefficient(additional_blocks, blocks) + block_indices, add_block_indices = np.where(overlapping.T >= threshold) + # Ensure the block_indices are appropriately ordered + + if len(add_block_indices) == 0: + return blocks, additional_blocks + else: + block_ids_to_remove = [] + additional_block_ids_to_remove = [] + newly_added_blocks = [] + + prev_len = 0 + for orig_idx, gp in groupby(block_indices): + cur_len = len(list(gp)) + additional_block_indices_in_this_group = add_block_indices[ + prev_len : prev_len + cur_len + ] + block_ids_to_remove.append(orig_idx) + additional_block_ids_to_remove.extend( + additional_block_indices_in_this_group + ) + + newly_added_blocks.append( + union_box( + [blocks[orig_idx]] + + [ + additional_blocks[ad_idx] + for ad_idx in additional_block_indices_in_this_group + ] + ) + # it will keep the category from the original block + ) + prev_len += cur_len + + # for ad_idx, orig_idx in zip(add_block_indices, block_indices): + # if overlapping[ad_idx, orig_idx] < MIN_OVERLAPPING_THRESHOLD: + # continue + + # block_ids_to_remove.append(orig_idx) + # additional_block_ids_to_remove.append(orig_idx) + + # newly_added_blocks.append( + # union_box([blocks[orig_idx], additional_blocks[ad_idx]]) + # # it will keep the category from the original block + # ) + + # for ad_idx, orig_idx in zip(add_block_indices, block_indices): + # if overlapping[ad_idx, orig_idx] < MIN_OVERLAPPING_THRESHOLD: + # continue + + # block_ids_to_remove.append(orig_idx) + # additional_block_ids_to_remove.append(orig_idx) + + # newly_added_blocks.append( + # union_box([blocks[orig_idx], additional_blocks[ad_idx]]) + # # it will keep the category from the original block + # ) + return ( + [b for idx, b in enumerate(blocks) if idx not in block_ids_to_remove], + [ + b + for idx, b in enumerate(additional_blocks) + if idx not in additional_block_ids_to_remove + ] + + newly_added_blocks, + ) + +def find_parent_for_all_elements_and_reassign_block_category( + block_layout, token_layout, target_attr="parent", block_ordering_method="token_id" +): + + assert len(block_layout) > 0 + + block_min_token_ids = [] + + iterating_tokens = token_layout + for idx, block in enumerate(block_layout): + block.id = idx + remaining_tokens = [] + tokens_in_this_block = [] + block_min_token_id = float("inf") + for token in iterating_tokens: + if is_in(token, block): + setattr(token, target_attr, idx) + tokens_in_this_block.append(token) + block_min_token_id = min(token.id, block_min_token_id) + else: + remaining_tokens.append(token) + + iterating_tokens = remaining_tokens + # if is_non_textual_type(block): + # for token in tokens_in_this_block: + # token.type = block.type + # else: + # token_types_in_this_block = [b.type for b in tokens_in_this_block] + # block.type = get_most_common_element(token_types_in_this_block) + + block_min_token_ids.append(block_min_token_id) + + if block_ordering_method == "token_id": + sorted_block_token_indices = { + orig_id: new_id + for new_id, orig_id in enumerate(argsort(block_min_token_ids)) + } + + for token in token_layout: + setattr( + token, + target_attr, + sorted_block_token_indices.get(getattr(token, target_attr, None), None), + ) + for block in block_layout: + block.id = sorted_block_token_indices[block.id] + + # print( + # f"Searching the closet blocks for the remaining {len(remaining_tokens)} tokens" + # ) + 
for token in remaining_tokens: + setattr( + token, target_attr, int(find_closet_block_for_token(token, block_layout).id) + ) + +def find_minimum_gap(block_A, block_B): + # just the manhattan distance + + center_A = block_A.block.center + center_B = block_B.block.center + return sum(abs(a - b) for a, b in zip(center_A, center_B)) + +def find_closet_block_for_token(token, blocks): + gap = float("inf") + target_block = None + for block in blocks: + cur_gap = find_minimum_gap(token, block) + if cur_gap < gap: + gap = cur_gap + target_block = block + assert target_block is not None + return target_block + +def intersect(self, other): + return lp.Rectangle( + max(self.x_1, other.x_1), + max(self.y_1, other.y_1), + min(self.x_2, other.x_2), + min(self.y_2, other.y_2), + ) + +def trim_elements_based_on_parents(block_layout, token_layout): + + block_layout = {b.id: b for b in block_layout} + + for token in token_layout: + block = block_layout.get(token.parent, None) + if block is not None: + token.block = intersect(token.block, block.block) + +def get_tokens_in_block(tokens, block, metric="center"): + + return [tok for tok in tokens if is_in(tok, block, metric=metric)] + +def reorder_lines(lines, blocks, tokens): + # We firstly group lines by blocks, then order + # lines within each group using the token indices + + ordered_blocks = sorted(blocks, key=lambda b: b.id) + + tokens_groupby_blocks = { + block.id: [tok for tok in tokens if tok.parent == block.id] for block in blocks + } + + for token in tokens: + token.line_id = None + + line_id = 0 + iter_lines = lines + + for block in ordered_blocks: + + lines_in_current_block = [] + remaining_lines = [] + for line in iter_lines: + if is_in(line, block): + lines_in_current_block.append(line) + else: + remaining_lines.append(line) + iter_lines = remaining_lines + + tokens_in_current_block = tokens_groupby_blocks[block.id] + + tokens_in_each_line = [ + get_tokens_in_block(tokens_in_current_block, line, metric="any") + for line in lines_in_current_block + ] + + min_token_indices_in_each_line = [ + (min(tok.id for tok in tokens) if len(tokens) > 0 else float("inf")) + for tokens in tokens_in_each_line + ] + + # print(min_token_indices_in_each_line) + + sorted_line_token_indices = { + orig_id: new_id + for new_id, orig_id in enumerate(argsort(min_token_indices_in_each_line)) + } + + used_line_id = 0 + for idx, line in enumerate(lines_in_current_block): + + tokens_in_this_line = tokens_in_each_line[idx] + + if len(tokens_in_this_line) == 0: + line.id = None + continue + + line.id = cur_line_id = sorted_line_token_indices[idx] + line_id + line.parent = block.id + line.type = get_most_common_token_type(tokens_in_this_line) + used_line_id += 1 + + for token in tokens_in_this_line: + token.line_id = cur_line_id + + line_id += used_line_id + + # print( + # "Searching the closet blocks for the remaining", len(remaining_lines), "lines" + # ) + + for line in remaining_lines: + tokens_in_this_line = get_tokens_in_block(tokens, line, metric="coef") + if len(tokens_in_this_line) == 0: + line.id = None + else: + block = find_closet_block_for_token(line, blocks) + line.id = cur_line_id = line_id + line.parent = block.id + line.type = get_most_common_token_type(tokens_in_this_line) + for token in tokens_in_this_line: + token.line_id = cur_line_id + + line_id += 1 + + tokens_without_line_ids = [token for token in tokens if token.line_id is None] + # print( + # "Searching the closet lines for the remaining", len(tokens_without_line_ids), "tokens" + # ) + for token in 
tokens_without_line_ids: + line = find_closet_block_for_token(token, lines) + token.line_id = line.id + +def replace_non_text_lines_with_block(lines, blocks, tokens): + + blocks = {b.id: b for b in blocks} + + new_line_id = 0 + synthesized_lines = [] + line_id_conversion = {} + + for bid, gp in groupby(lines, key=lambda ele: ele.parent): + lines_in_this_block = list(gp) + cur_block = blocks[bid] + if is_non_textual_type(cur_block): + cur_block = copy(cur_block) + + cur_block.parent = cur_block.id + cur_block.id = new_line_id + synthesized_lines.append(cur_block) + + for line in lines_in_this_block: + line_id_conversion[line.id] = new_line_id + + new_line_id += 1 + + else: + for idx, line in enumerate(lines_in_this_block, start=new_line_id): + line_id_conversion[line.id] = idx + line.id = idx + synthesized_lines.extend(lines_in_this_block) + new_line_id = idx + 1 + + for token in tokens: + if token.line_id not in line_id_conversion: + for line in synthesized_lines: + if is_in(token, line, metric="coef"): + token.line_id = line.id + continue + if token.line_id is None: + line = find_closet_block_for_token(token, synthesized_lines) + token.line_id = line.id + else: + token.line_id = line_id_conversion[token.line_id] + + return synthesized_lines + + +def pipeline(base_path, pdf_sha, pid): + + csv_name = f"{pdf_sha}-{pid}.csv" + tokens = load_cermine_data_from_csv(f'{base_path}/tokens/{csv_name}') + + if tokens is None or len(tokens) == 0: + # Nothing for empty tokens + return [[]]*5 + + blocks = load_block_data_from_csv(f'{base_path}/blocks/{csv_name}') + lines = load_line_data_from_csv(f'{base_path}/lines/{csv_name}') + + blocks = remove_overlapping_textual_blocks_for_non_textual_blocks(blocks) + find_parent_for_elements(blocks, tokens) + block_snapping(blocks, tokens) + blocks = filter_out_overlapping_block_and_union(blocks) + + find_parent_for_elements(blocks, tokens) + + additional_blocks = group_token_to_blocks(tokens) + blocks, additional_blocks = absorb_additional_blocks_into_existing_blocks( + blocks, additional_blocks + ) + + find_parent_for_all_elements_and_reassign_block_category( + blocks + additional_blocks, tokens + ) + + lines = filter_out_overlapping_block(lines) + find_parent_for_elements(blocks, lines) + trim_elements_based_on_parents(blocks, lines) + + find_parent_for_elements(lines, tokens, target_attr="line_id") + additional_lines = group_token_to_lines(tokens) + lines, additional_lines = absorb_additional_blocks_into_existing_blocks( + lines, additional_lines, threshold=0.3 + ) + reorder_lines(lines + additional_lines, blocks + additional_blocks, tokens) + + blocks = sorted([ele for ele in blocks + additional_blocks], key=lambda ele: ele.id) + lines = sorted( + [ele for ele in lines + additional_lines if ele.id is not None], + key=lambda ele: ele.id, + ) + lines = replace_non_text_lines_with_block(lines, blocks, tokens) + + return (blocks, lines, tokens, additional_blocks, additional_lines) + +def create_structure_df(tokens, blocks, lines): + blocks_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + ele.type, + -1, + -1, + True, + False, + ] + for ele in blocks + ] + lines_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + ele.type, + ele.parent, + -1, + False, + True, + ] + for ele in lines + ] + tokens_to_save = [ + [ + ele.id, + *ele.coordinates, + ele.text, + ele.type, + ele.parent, + ele.line_id, + False, + False, + ] + for ele in tokens + ] + df = pd.DataFrame( + blocks_to_save + lines_to_save + tokens_to_save, + columns=[ + "id", + "x_1", + "y_1", 
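+            # remaining columns: the second box corner, the element's text and category,
+            # its parent block/line ids (-1 where not applicable), and the is_block / is_line flags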
+            "x_2",
+            "y_2",
+            "text",
+            "category",
+            "block_id",
+            "line_id",
+            "is_block",
+            "is_line",
+        ],
+    )
+    return df
\ No newline at end of file
diff --git a/datasets/schema-token.json b/datasets/schema-token.json
new file mode 100644
index 0000000..9d128e4
--- /dev/null
+++ b/datasets/schema-token.json
@@ -0,0 +1,58 @@
+{
+    "$schema": "",
+    "description": "A condensed format for including all token-level samples in the training dataset",
+    "type": "object",
+    "properties": {
+        "data": {
+            "type": "array",
+            "description": "Each element represents one sample of the model input",
+            "items": {
+                "$ref": "#/$defs/data"
+            }
+        },
+        "labels": {
+            "type": "object",
+            "description": "A mapping from label_id to label_name"
+        },
+        "files": {
+            "type": "array",
+            "description": "A list of strings; the i-th entry is the file name for the i-th object in data",
+            "items": {
+                "type": "string"
+            }
+        }
+    },
+    "$defs": {
+        "data": {
+            "type": "object",
+            "properties": {
+                "words": {
+                    "type": "array",
+                    "description": "A list of individual word strings on this page. Empty tokens (\"\") should be removed."
+                },
+                "bbox": {
+                    "type": "array",
+                    "description": "A list of bounding boxes ([x1, y1, x2, y2]) for each token, in the same order as words"
+                },
+                "labels": {
+                    "type": "array",
+                    "description": "A list of label_id for each token"
+                },
+                "block_ids": {
+                    "type": "array",
+                    "description": "[Optional] A list of ids for the corresponding block for each token. Used for constructing block embeddings, etc.",
+                    "items": {
+                        "type": "number"
+                    }
+                },
+                "line_ids": {
+                    "type": "array",
+                    "description": "[Optional] A list of ids for the corresponding line for each token. Used for constructing line embeddings, etc.",
+                    "items": {
+                        "type": "number"
+                    }
+                }
+            }
+        }
+    }
+}
\ No newline at end of file