Description
import json
import os

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import XLMRobertaTokenizerFast, LiltForTokenClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
# Constants
IMAGE_FOLDER = "images"
TRAIN_FILE = "train.txt"
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
MAX_LENGTH = 512
def parse_dataset(train_file):
    data = []
    with open(train_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Each line: <image_name>\t<JSON-encoded list of annotations>
            parts = line.strip().split('\t')
            image_name = parts[0]
            annotations_str = parts[1]
            annotations = json.loads(annotations_str)
            data.append({
                'image_path': os.path.join(IMAGE_FOLDER, image_name),
                'annotations': annotations
            })
    return data
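
# The exact annotation schema is an assumption inferred from how parse_dataset and
# LiLTDataset consume it; a train.txt line would look roughly like:
#   image_001.jpg\t[{"transcription": "Invoice", "label": "HEADER", "bbox": [34, 20, 180, 48]}, ...]
# with 'bbox' holding pixel coordinates [x1, y1, x2, y2] on the page image.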
def split_data(data, test_size=0.2):
    train_data, val_data = train_test_split(data, test_size=test_size, random_state=42)
    return train_data, val_data
def create_label_mappings(data):
    unique_labels = set()
    for sample in data:
        for ann in sample['annotations']:
            unique_labels.add(ann['label'])
    label2id = {label: idx for idx, label in enumerate(sorted(unique_labels))}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label
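
# Assuming labels such as "HEADER", "QUESTION" and "ANSWER", the mappings would be:
#   label2id = {"ANSWER": 0, "HEADER": 1, "QUESTION": 2}
#   id2label = {0: "ANSWER", 1: "HEADER", 2: "QUESTION"}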
class LiLTDataset(Dataset):
    def __init__(self, data, tokenizer, label2id, max_length=MAX_LENGTH):
        self.data = data
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        image_path = sample['image_path']
        annotations = sample['annotations']
        # The image is only opened for its size, so boxes can be normalized below.
        with Image.open(image_path) as img:
            width, height = img.size
        # Sort words into rough reading order (top-to-bottom, left-to-right).
        annotations.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
        full_text = " ".join([ann['transcription'] for ann in annotations])
        # Track the character span of each word inside full_text so tokens can be
        # mapped back to their source word via the tokenizer's offset mapping.
        word_spans = []
        current_pos = 0
        for ann in annotations:
            start = full_text.find(ann['transcription'], current_pos)
            end = start + len(ann['transcription'])
            word_spans.append((start, end))
            current_pos = end
        encoding = self.tokenizer(
            full_text,
            return_offsets_mapping=True,
            truncation=True,
            max_length=self.max_length,
            padding='max_length'
        )
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        offset_mapping = encoding['offset_mapping']
        # -100 is ignored by the loss; special and padding tokens keep it.
        labels = [-100] * len(input_ids)
        bbox = [[0, 0, 0, 0]] * len(input_ids)
        for word_idx, (word_start, word_end) in enumerate(word_spans):
            word_label = self.label2id[annotations[word_idx]['label']]
            word_bbox = annotations[word_idx]['bbox']
            x1, y1, x2, y2 = word_bbox
            # LiLT expects bounding boxes normalized to a 0-1000 coordinate space.
            x1_norm = int((x1 / width) * 1000)
            y1_norm = int((y1 / height) * 1000)
            x2_norm = int((x2 / width) * 1000)
            y2_norm = int((y2 / height) * 1000)
            for token_idx, (token_start, token_end) in enumerate(offset_mapping):
                if token_start >= word_end or token_end <= word_start:
                    continue
                if token_start >= word_start and token_end <= word_end:
                    labels[token_idx] = word_label
                    bbox[token_idx] = [x1_norm, y1_norm, x2_norm, y2_norm]
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'bbox': torch.tensor(bbox),
            'labels': torch.tensor(labels)
        }
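
# A quick, hypothetical sanity check of one encoded sample (paths and labels as above);
# input_ids, attention_mask and labels should have shape (512,), bbox (512, 4):
#   data = parse_dataset(TRAIN_FILE)
#   label2id, _ = create_label_mappings(data)
#   tok = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
#   item = LiLTDataset(data, tok, label2id)[0]
#   print(item['input_ids'].shape, item['bbox'].shape, item['labels'].shape)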
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # Flatten the batch and drop -100 positions (special tokens and padding) before scoring.
    true_labels = [l for label_seq in labels for l in label_seq if l != -100]
    true_preds = [p for pred_seq, label_seq in zip(preds, labels) for p, l in zip(pred_seq, label_seq) if l != -100]
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, true_preds, average='weighted')
    return {'precision': precision, 'recall': recall, 'f1': f1}
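
# Note: these are token-level, weighted-average scores over all labeled tokens;
# entity-level evaluation (e.g. with the seqeval package) is computed differently.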
def train_lilt():
    data = parse_dataset(TRAIN_FILE)
    train_data, val_data = split_data(data)
    label2id, id2label = create_label_mappings(data)
    # A fast tokenizer is required because __getitem__ relies on return_offsets_mapping.
    tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
    train_dataset = LiLTDataset(train_data, tokenizer, label2id)
    val_dataset = LiLTDataset(val_data, tokenizer, label2id)
    model = LiltForTokenClassification.from_pretrained(
        "nielsr/lilt-xlm-roberta-base",
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id
    )
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        fp16=True  # requires a CUDA GPU; set to False when training on CPU
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )
    trainer.train()
    trainer.save_model("./lilt_model")
if __name__ == "__main__":
    train_lilt()
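
For reference, a minimal inference sketch against the directory saved above. The tokenizer choice and the demo words and boxes are assumptions; in practice the words and per-word boxes would come from an OCR step and be normalized to the 0-1000 range exactly as in LiLTDataset.

import torch
from transformers import XLMRobertaTokenizerFast, LiltForTokenClassification

tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
model = LiltForTokenClassification.from_pretrained("./lilt_model")
model.eval()

# Hypothetical OCR output for one text line, with the box already normalized to 0-1000.
words = ["Invoice", "No.", "12345"]
line_box = [40, 30, 360, 60]

enc = tokenizer(" ".join(words), return_tensors="pt", truncation=True)
seq_len = enc["input_ids"].shape[1]
# For brevity every token gets the box of the whole line; per-word boxes would be
# assigned through the offset mapping exactly as in LiLTDataset.__getitem__.
bbox = torch.tensor([[line_box] * seq_len])

with torch.no_grad():
    logits = model(input_ids=enc["input_ids"],
                   attention_mask=enc["attention_mask"],
                   bbox=bbox).logits

pred_ids = logits.argmax(-1)[0].tolist()
print([model.config.id2label[i] for i in pred_ids])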