Inference

The document contains a Python script that performs surgical phase recognition on video files using a two-stage deep learning pipeline: a ResNet-50 backbone extracts per-second frame features, which a multi-stage temporal convolutional network (MS-TCN) and a Transformer head then classify into one of seven phases. The script includes functions for loading video data, preprocessing frames, performing inference, and generating labels from the model's predictions. The results are saved in JSON format, recording a processing status and the labeled segments of the video.
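
For orientation, here is a minimal sketch of programmatic use. It assumes the script below is saved as inference.py (the module name is hypothetical); main and process_labels are defined in the script itself, and the JSON shape mirrors what the script writes to result.json.

import json
from inference import main, process_labels  # "inference" is a hypothetical module name

labels = main("video.mp4")         # per-second phase indices (numpy array)
segments = process_labels(labels)  # merged into timed, named segments
print(json.dumps({"status": 0, "labels": segments}, ensure_ascii=False, indent=4))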


import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn import DataParallel
from torch.utils.data import Sampler
from PIL import Image
import numpy as np
from scipy.ndimage import gaussian_filter1d
import os
import cv2
import mstcn
from transformer2_3_1 import Transformer2_3_1
import multiprocessing
import argparse
import json
import sys
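
# Local dependencies: mstcn.py and transformer2_3_1.py must be importable from
# this directory, and the checkpoint files resnet.pth, tcn.pth and
# transformer.pth (loaded below with torch.load) must exist in the working
# directory.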

sequence_length = 1
val_batch_size = 100
workers = min(4, multiprocessing.cpu_count() - 1)  # leave one core free for other tasks

# Make every detected GPU visible to DataParallel.
num_gpus = torch.cuda.device_count()
gpus = ",".join(map(str, range(num_gpus)))
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def change_size(image):
    # Crop away the black border: threshold near-black pixels, then keep the
    # bounding box of the bright region. Note that edges_x indexes rows and
    # edges_y indexes columns, so "left/right" below are really the top and
    # bottom rows of the crop.
    binary_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary_image2 = cv2.threshold(binary_image, 15, 255, cv2.THRESH_BINARY)
    binary_image2 = cv2.medianBlur(binary_image2, 19)
    x = binary_image2.shape[0]
    y = binary_image2.shape[1]

    edges_x = []
    edges_y = []
    for i in range(x):
        for j in range(10, y - 10):
            if binary_image2.item(i, j) != 0:
                edges_x.append(i)
                edges_y.append(j)

    if not edges_x:  # fully black frame: nothing to crop
        return image

    left = min(edges_x)
    right = max(edges_x)
    width = right - left
    bottom = min(edges_y)
    top = max(edges_y)
    height = top - bottom

    pre1_picture = image[left:left + width, bottom:bottom + height]

    return pre1_picture

def get_data_from_video(video_path):
    # Sample one frame per second of video, crop the black border, and return
    # the frames as 250x250 RGB PIL images.
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = frame_count / fps
    frames = []
    for sec in range(int(duration)):
        cap.set(cv2.CAP_PROP_POS_MSEC, sec * 1000)
        success, frame = cap.read()
        if success:
            dim = (int(frame.shape[1] / frame.shape[0] * 300), 300)
            frame = cv2.resize(frame, dim)
            frame = change_size(frame)
            frame = cv2.resize(frame, (250, 250))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def pil_loader_from_video(video_path):
    frames = get_data_from_video(video_path)
    return frames

class VideoDataset(Dataset):
    def __init__(self, video_path, transform=None):
        self.frames = pil_loader_from_video(video_path)
        self.transform = transform

    def __getitem__(self, index):
        img = self.frames[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, index

    def __len__(self):
        return len(self.frames)

class resnet_lstm(nn.Module):
    # Despite the name, this module contains no LSTM; it is the ResNet-50
    # feature extractor used in stage 1.
    def __init__(self):
        super(resnet_lstm, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.share = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool
        )
        self.fc = nn.Sequential(nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512, 7))

    def forward(self, x):
        # self.fc is never applied here: forward() returns the 2048-d pooled
        # backbone features directly.
        x = x.view(-1, 3, 224, 224)
        x = self.share(x)
        x = x.view(-1, 2048)
        return x

def get_useful_start_idx(sequence_length, list_each_length):
    count = 0
    idx = []
    for i in range(len(list_each_length)):
        for j in range(count, count + (list_each_length[i] + 1 - sequence_length)):
            idx.append(j)
        count += list_each_length[i]
    return idx


def get_useful_start_idx_LFB(sequence_length, list_each_length):
    # Identical to get_useful_start_idx; kept as a separate function,
    # mirroring the original code.
    count = 0
    idx = []
    for i in range(len(list_each_length)):
        for j in range(count, count + (list_each_length[i] + 1 - sequence_length)):
            idx.append(j)
        count += list_each_length[i]
    return idx

def get_data(video_path):
    test_transforms = transforms.Compose([
        transforms.Resize((250, 250)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.41757566, 0.26098573, 0.25888634],
                             [0.21938758, 0.1983, 0.19342837])
    ])

    video_dataset = VideoDataset(video_path, test_transforms)
    return video_dataset

class SeqSampler(Sampler):
    def __init__(self, data_source, idx):
        super().__init__(data_source)
        self.data_source = data_source
        self.idx = idx

    def __iter__(self):
        return iter(self.idx)

    def __len__(self):
        return len(self.idx)


sig_f = nn.Sigmoid()

# Global buffer of per-frame 2048-d features, filled by first_stage_inference.
g_LFB_test = np.zeros(shape=(0, 2048))

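# Stage 1: run every sampled frame through the ResNet-50 backbone and cache
# the per-frame 2048-d features in g_LFB_test ("LFB" presumably stands for
# "long feature bank"; the expansion is an assumption, not stated in the code).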
def first_stage_inference(test_dataset):
    test_num_each = [len(test_dataset)]
    test_useful_start_idx = get_useful_start_idx(sequence_length, test_num_each)
    test_useful_start_idx_LFB = get_useful_start_idx_LFB(sequence_length, test_num_each)
    num_test_we_use = len(test_useful_start_idx)
    num_test_we_use_LFB = len(test_useful_start_idx_LFB)

    test_we_use_start_idx = test_useful_start_idx
    test_we_use_start_idx_LFB = test_useful_start_idx_LFB

    test_idx = []
    for i in range(num_test_we_use):
        for j in range(sequence_length):
            test_idx.append(test_we_use_start_idx[i] + j)

    test_idx_LFB = []
    for i in range(num_test_we_use_LFB):
        for j in range(sequence_length):
            test_idx_LFB.append(test_we_use_start_idx_LFB[i] + j)

    global g_LFB_test
    test_feature_loader = DataLoader(
        test_dataset,
        batch_size=val_batch_size,
        sampler=SeqSampler(test_dataset, test_idx_LFB),
        num_workers=workers,
        pin_memory=False
    )

    # The checkpoint stores the full DataParallel-wrapped model object, so the
    # freshly constructed instance is replaced wholesale by torch.load.
    model_LFB = resnet_lstm()
    model_LFB = DataParallel(model_LFB)
    model_LFB = torch.load("resnet.pth")

    def get_parameter_number(net):
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        return trainable_num

    total_param_num = 0
    total_param_num += get_parameter_number(model_LFB)
    model_LFB.cuda()

    for params in model_LFB.parameters():
        params.requires_grad = False

    model_LFB.eval()

    with torch.no_grad():
        for data in test_feature_loader:
            inputs = data[0].cuda()  # data[1] is the frame index, unused here
            inputs = inputs.view(-1, sequence_length, 3, 224, 224)
            outputs_feature = model_LFB.forward(inputs).data.cpu().numpy()
            g_LFB_test = np.concatenate((g_LFB_test, outputs_feature), axis=0)

class Transformer(nn.Module):
    def __init__(self, mstcn_f_maps, mstcn_f_dim, out_features, len_q):
        super(Transformer, self).__init__()
        self.num_f_maps = mstcn_f_maps
        self.dim = mstcn_f_dim
        self.num_classes = out_features
        self.len_q = len_q
        # Note: len_q=sequence_length below refers to the module-level global,
        # not the len_q constructor argument. This mirrors the original code;
        # the instance built in second_stage_inference is replaced wholesale
        # by the transformer.pth checkpoint, so it has no runtime effect.
        self.transformer = Transformer2_3_1(d_model=out_features,
                                            d_ff=mstcn_f_maps, d_k=mstcn_f_maps,
                                            d_v=mstcn_f_maps, n_layers=1,
                                            n_heads=8, len_q=sequence_length)
        self.fc = nn.Linear(mstcn_f_dim, out_features, bias=False)

    def forward(self, x, long_feature):
        # For each time step, build a query window of the last len_q class
        # score vectors (zero-padded at the start of the video), then attend
        # over the projected long-term frame features.
        out_features = x.transpose(1, 2)
        inputs = []
        for i in range(out_features.size(1)):
            if i < self.len_q - 1:
                input = torch.zeros((1, self.len_q - 1 - i,
                                     self.num_classes)).cuda()
                input = torch.cat([input, out_features[:, 0:i + 1]], dim=1)
            else:
                input = out_features[:, i - self.len_q + 1:i + 1]
            inputs.append(input)
        inputs = torch.stack(inputs, dim=0).squeeze(1)
        feas = torch.tanh(self.fc(long_feature).transpose(0, 1))
        output = self.transformer(inputs, feas)
        return output

def smooth_labels(labels, sigma=2):
    # Smooth the per-second phase indices with a 1-D Gaussian filter. On an
    # integer input array, gaussian_filter1d returns the same integer dtype,
    # so the result can still be used as class indices.
    return gaussian_filter1d(labels, sigma=sigma, mode='nearest')

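# Stage 2: feed the cached frame features through the multi-stage temporal
# convolutional network (mstcn) and the Transformer head to obtain per-second
# phase predictions, then Gaussian-smooth the label sequence.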
def second_stage_inference():
    def get_long_feature(start_index, lfb, LFB_length):
        # Collect the whole feature bank for one video as a (1, T, 2048) list.
        long_feature = []
        long_feature_each = []
        for k in range(LFB_length):
            LFB_index = int(start_index + k)
            long_feature_each.append(lfb[LFB_index])
        long_feature.append(long_feature_each)
        return long_feature

    test_num_each_80, test_start_vidx = [len(g_LFB_test)], [0]

    out_features = 7
    mstcn_causal_conv = True
    mstcn_layers = 8
    mstcn_f_maps = 32
    mstcn_f_dim = 2048
    mstcn_stages = 2
    sequence_length = 30  # local query length; shadows the module-level value of 1

    # As in stage 1, the checkpoints store full DataParallel-wrapped model
    # objects, so the constructed instances are replaced by torch.load.
    model = mstcn.MultiStageModel(mstcn_stages, mstcn_layers, mstcn_f_maps,
                                  mstcn_f_dim, out_features, mstcn_causal_conv)
    model = DataParallel(model)
    model = torch.load('tcn.pth')
    model.cuda()
    model.eval()

    model1 = Transformer(mstcn_f_maps, mstcn_f_dim, out_features, sequence_length)
    model1 = DataParallel(model1)
    model1 = torch.load('transformer.pth')
    model1.cuda()
    model1.eval()
    torch.cuda.empty_cache()

    with torch.no_grad():
        long_feature = get_long_feature(start_index=test_start_vidx[0],
                                        lfb=g_LFB_test, LFB_length=test_num_each_80[0])
        long_feature = torch.Tensor(long_feature).cuda()
        video_fe = long_feature.transpose(2, 1)

        out_features = model.forward(video_fe)[-1]
        out_features = out_features.squeeze(1)
        p_classes1 = model1(out_features, long_feature)

        p_classes = p_classes1.squeeze()
        _, preds_phase = torch.max(p_classes.data, 1)
        preds_phase = preds_phase.cpu().numpy()
        results = smooth_labels(preds_phase, sigma=2)
    return results

def is_mp4_file(file_path):
    return file_path.lower().endswith('.mp4')


def check_mp4_file(file_path):
    if not is_mp4_file(file_path):
        print(f"{file_path} is not an MP4 file.")
        return False

    cap = None
    try:
        cap = cv2.VideoCapture(file_path)
        if not cap.isOpened():
            print(f"{file_path} might be corrupted.")
            return False

        # Try to read the first frame.
        ret, frame = cap.read()
        if not ret:
            print(f"{file_path} might be corrupted.")
            return False

        print(f"{file_path} is a valid MP4 file.")
        return True
    except Exception as e:
        print(f"Error checking {file_path}: {str(e)}")
        return False
    finally:
        if cap is not None:
            cap.release()

def convert_seconds_to_timecode(seconds):
    # Convert a second count to HH:MM:SS, zero-padding each field to two digits.
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60
    return f"{hours:02}:{minutes:02}:{seconds:02}"

def process_labels(labels):
    # Surgical phase names, kept in Chinese because they are written verbatim
    # into the JSON output; English glosses in the trailing comments.
    label_names = {
        0: "术前准备",        # preoperative preparation
        1: "Calot 三角解剖",  # dissection of Calot's triangle
        2: "剪切与夹闭",      # clipping and cutting
        3: "胆囊解剖",        # gallbladder dissection
        4: "胆囊打包",        # gallbladder packaging
        5: "清理与凝血",      # cleaning and coagulation
        6: "胆囊牵引",        # gallbladder retraction
    }

    processed_labels = []
    # Note: the original guard was `if not labels.any()`, which would also bail
    # out on a video classified entirely as phase 0; checking the size avoids that.
    if labels.size == 0:
        return processed_labels

    current_label = labels[0]
    start_time = 0

    for i in range(1, len(labels)):
        if labels[i] != current_label:
            end_time = i
            processed_labels.append({
                "startTime": convert_seconds_to_timecode(start_time),
                "labelName": label_names[current_label],
                "endTime": convert_seconds_to_timecode(end_time)
            })
            current_label = labels[i]
            start_time = i

    # Flush the final phase segment.
    processed_labels.append({
        "startTime": convert_seconds_to_timecode(start_time),
        "labelName": label_names[current_label],
        "endTime": convert_seconds_to_timecode(len(labels))
    })

    return processed_labels

def main(video_path):
    video_dataset = get_data(video_path)
    first_stage_inference(video_dataset)
    results = second_stage_inference()
    return results

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('video_path', type=str, help="Path to the video file")
    args = parser.parse_args()

    # Status codes: 0 = OK, 1 = not an MP4 file, 2 = corrupted/unreadable file.
    status = 0
    video_path = args.video_path

    b_mp4 = is_mp4_file(video_path)
    b_valid = check_mp4_file(video_path)

    if not b_mp4:
        status = 1
    elif not b_valid:
        status = 2

    if status == 0:
        labels = main(video_path)
        processed_labels = process_labels(labels)
    else:
        processed_labels = []

    output = {
        "status": status,
        "labels": processed_labels
    }

    # Write the JSON result.
    with open('result.json', 'w', encoding='utf-8') as f:
        json.dump(output, f, ensure_ascii=False, indent=4)
    print("status:", status)
    # Exit, returning the status code.
    sys.exit(status)
