Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/push_model_to_hub.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ python push_model_to_hf_hub.py \
--repo-name ivila-block-layoutlm-finetuned-docbank \
--agg_level "block" \
--group_bbox_agg "first" \
--added_special_sepration_token "[BLK]"
--added_special_separation_token "[BLK]"

# DocBank HVILA Block Finetuned
python push_model_to_hf_hub.py \
Expand All @@ -34,7 +34,7 @@ python push_model_to_hf_hub.py \
--repo-name ivila-block-layoutlm-finetuned-grotoap2 \
--agg_level "block" \
--group_bbox_agg "first" \
--added_special_sepration_token "[BLK]"
--added_special_separation_token "[BLK]"

# GROTOAP2 HVILA Block Finetuned
python push_model_to_hf_hub.py \
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_ivila.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ python train-ivila.py \
--per_device_eval_batch_size 40 \
--warmup_steps 2000 \
--load_best_model_at_end \
--added_special_sepration_token $used_token \
--added_special_separation_token $used_token \
--agg_level $agg_level \
--fp16
7 changes: 5 additions & 2 deletions src/vila/dataset/preprocessors/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class VILAPreprocessorConfig:
agg_level: str = "row" #"block", "sentence"
label_all_tokens: bool = False
group_bbox_agg: str = "first"
added_special_sepration_token: str = "[BLK]"
added_special_separation_token: str = "[BLK]"

def to_json(self, path: str):
with open(path, "w") as fp:
Expand All @@ -25,7 +25,10 @@ def from_pretrained(cls, model_path: str, **kwargs):
config = AutoConfig.from_pretrained(model_path)

if hasattr(config, "vila_preprocessor_config"):
data_json = config.vila_preprocessor_config
data_json = config.vila_preprocessor_config.copy()
if "added_special_sepration_token" in data_json:
data_json["added_special_separation_token"] = data_json.pop("added_special_sepration_token")
# Fix an old typo in the config
data_json.update(kwargs)
return cls(**data_json)
# We store the vila-preprocessor configs inside
Expand Down
14 changes: 7 additions & 7 deletions src/vila/dataset/preprocessors/layout_indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def __init__(

super().__init__(tokenizer, config, text_column_name, label_column_name)

self.added_special_sepration_token = config.added_special_sepration_token
if self.added_special_sepration_token == "default":
self.added_special_sepration_token = tokenizer.special_tokens_map[
self.added_special_separation_token = config.added_special_separation_token
if self.added_special_separation_token == "default":
self.added_special_separation_token = tokenizer.special_tokens_map[
"sep_token"
]

Expand Down Expand Up @@ -127,7 +127,7 @@ def preprocess_sample(self, example: Dict, padding="max_length") -> Dict:
self.special_tokens_map[
self.tokenizer.special_tokens_map["sep_token"]
],
self.special_tokens_map[self.added_special_sepration_token],
self.special_tokens_map[self.added_special_separation_token],
]:
# Because we could possibly insert [SEP] or [BLK] tokens in
# this process.
Expand Down Expand Up @@ -180,7 +180,7 @@ def insert_layout_indicator(self, example: Dict) -> Tuple[Dict, Dict]:
)
processed_words.extend(
words[pre_index : pre_index + cur_len]
+ [self.added_special_sepration_token]
+ [self.added_special_separation_token]
)
processed_bbox.extend(
bbox[pre_index : pre_index + cur_len]
Expand Down Expand Up @@ -226,7 +226,7 @@ def insert_layout_indicator(self, example: Dict) -> Tuple[Dict, Dict]:
)
processed_words.extend(
words[pre_index : pre_index + cur_len]
+ [self.added_special_sepration_token]
+ [self.added_special_separation_token]
)
processed_bbox.extend(
bbox[pre_index : pre_index + cur_len]
Expand Down Expand Up @@ -271,7 +271,7 @@ def insert_layout_indicator(self, example: Dict) -> Tuple[Dict, Dict]:
range(new_sequence_len, new_sequence_len + end - start)
)
processed_words.extend(
words[start:end] + [self.added_special_sepration_token]
words[start:end] + [self.added_special_separation_token]
)
processed_bbox.extend(bbox[start:end] + [union_box(bbox[start:end])])
processed_labels.extend(labels[start:end] + [-100])
Expand Down
15 changes: 6 additions & 9 deletions tests/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from vila.constants import *
from vila.dataset.preprocessors.base import SimplePDFDataPreprocessor
from vila.dataset.preprocessors.config import VILAPreprocessorConfig
from vila.dataset.preprocessors.layout_indicator import (
BlockLayoutIndicatorPDFDataPreprocessor,
RowLayoutIndicatorPDFDataPreprocessor,
Expand Down Expand Up @@ -48,15 +49,11 @@
use_auth_token=None,
)


class Config:
pass


config = Config()
config.label_all_tokens = False
config.added_special_sepration_token = "[SEP]"
config.group_bbox_agg = "union"
config = VILAPreprocessorConfig(
label_all_tokens = False,
added_special_separation_token = "[SEP]",
group_bbox_agg = "union"
)


def test_sentence_indicator_processor():
Expand Down
31 changes: 31 additions & 0 deletions tests/test_vila_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import layoutparser as lp # For visualization

from vila.pdftools.pdf_extractor import PDFExtractor
from vila.predictors import HierarchicalPDFPredictor, LayoutIndicatorPDFPredictor

def test_hvila_run():
    """Smoke-test the hierarchical predictor end to end on a fixture PDF.

    Loads tokens and page images from ``tests/fixtures/large.pdf``, detects
    layout blocks on every page image, annotates the page tokens with those
    blocks, and runs ``HierarchicalPDFPredictor.predict`` on each page.

    The test makes no assertion about the predictions themselves; it passes
    as long as the whole pipeline runs without raising.
    """
    pdf_extractor = PDFExtractor("pdfplumber")
    page_tokens, page_images = pdf_extractor.load_tokens_and_image(
        "tests/fixtures/large.pdf"
    )

    vision_model = lp.EfficientDetLayoutModel("lp://PubLayNet")
    pdf_predictor = HierarchicalPDFPredictor.from_pretrained(
        "allenai/hvila-row-layoutlm-finetuned-docbank"
    )

    for idx, page_token in enumerate(page_tokens):
        # The image for a page is looked up by the same index as its tokens.
        blocks = vision_model.detect(page_images[idx])
        page_token.annotate(blocks=blocks)
        pdf_data = page_token.to_pagedata().to_dict()
        # The return value is deliberately discarded: only a clean run matters.
        pdf_predictor.predict(pdf_data, page_token.page_size)

def test_ivila_run():
    """Smoke-test the layout-indicator predictor end to end on a fixture PDF.

    Loads tokens and page images from ``tests/fixtures/large.pdf``, detects
    layout blocks on every page image, annotates the page tokens with those
    blocks, and runs ``LayoutIndicatorPDFPredictor.predict`` on each page.

    The test makes no assertion about the predictions themselves; it passes
    as long as the whole pipeline runs without raising.
    """
    pdf_extractor = PDFExtractor("pdfplumber")
    page_tokens, page_images = pdf_extractor.load_tokens_and_image(
        "tests/fixtures/large.pdf"
    )

    vision_model = lp.EfficientDetLayoutModel("lp://PubLayNet")
    pdf_predictor = LayoutIndicatorPDFPredictor.from_pretrained(
        "allenai/ivila-block-layoutlm-finetuned-docbank"
    )

    for idx, page_token in enumerate(page_tokens):
        # The image for a page is looked up by the same index as its tokens.
        blocks = vision_model.detect(page_images[idx])
        page_token.annotate(blocks=blocks)
        pdf_data = page_token.to_pagedata().to_dict()
        # The return value is deliberately discarded: only a clean run matters.
        pdf_predictor.predict(pdf_data, page_token.page_size)
12 changes: 6 additions & 6 deletions tools/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ModelArguments:
######### VILA Settings #########
#################################

added_special_sepration_token: str = field(
added_special_separation_token: str = field(
default="SEP",
metadata={
"help": "The added special token for I-VILA models for separating the blocks/sentences/rows. Can be one of {SEP, BLK}. Default to `SEP`."
Expand All @@ -72,13 +72,13 @@ class ModelArguments:

def __post_init__(self):

assert self.added_special_sepration_token in ["BLK", "SEP"]
assert self.added_special_separation_token in ["BLK", "SEP"]

if self.added_special_sepration_token == "BLK":
self.added_special_sepration_token = "[BLK]"
if self.added_special_separation_token == "BLK":
self.added_special_separation_token = "[BLK]"

if self.added_special_sepration_token == "SEP":
self.added_special_sepration_token = "[SEP]"
if self.added_special_separation_token == "SEP":
self.added_special_separation_token = "[SEP]"


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions tools/push_model_to_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def write_json(data, filename):
parser.add_argument("--agg_level", type=str, default=None, help="desc")
parser.add_argument("--label_all_tokens", type=str, default=None, help="desc")
parser.add_argument("--group_bbox_agg", type=str, default=None, help="desc")
parser.add_argument("--added_special_sepration_token", type=str, default=None, help="desc")
parser.add_argument("--added_special_separation_token", type=str, default=None, help="desc")
args = parser.parse_args()

print(f"Loading Models from {args.model_path}")
Expand Down Expand Up @@ -58,8 +58,8 @@ def write_json(data, filename):
vila_preprocessor_config['label_all_tokens'] = args.label_all_tokens
if args.group_bbox_agg is not None:
vila_preprocessor_config['group_bbox_agg'] = args.group_bbox_agg
if args.added_special_sepration_token is not None:
vila_preprocessor_config['added_special_sepration_token'] = args.added_special_sepration_token
if args.added_special_separation_token is not None:
vila_preprocessor_config['added_special_separation_token'] = args.added_special_separation_token

model_config.vila_preprocessor_config = vila_preprocessor_config

Expand Down
6 changes: 3 additions & 3 deletions tools/train-ivila.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def get_label_list(labels):
use_auth_token=True if model_args.use_auth_token else None,
)

if model_args.added_special_sepration_token not in tokenizer.special_tokens_map.values():
tokenizer.add_special_tokens({"additional_special_tokens": [model_args.added_special_sepration_token]})
if model_args.added_special_separation_token not in tokenizer.special_tokens_map.values():
tokenizer.add_special_tokens({"additional_special_tokens": [model_args.added_special_separation_token]})
model.resize_token_embeddings(len(tokenizer))
# In a previous version, we try to avoid resizing the token embeddings by directly
# modifying the unused tokens in vocab. However, this is not possible as not all tokenizer
Expand All @@ -233,7 +233,7 @@ def get_label_list(labels):
)

logger.info(f"The used agg level is {data_args.agg_level}")
data_args.added_special_sepration_token = model_args.added_special_sepration_token
data_args.added_special_separation_token = model_args.added_special_separation_token
preprocessor = instantiate_dataset_preprocessor(
"layout_indicator", tokenizer, data_args
)
Expand Down