|
| 1 | +import os |
| 2 | +import logging |
| 3 | +import cv2 |
| 4 | +from ultralytics import YOLO |
| 5 | +import streamlit as st |
| 6 | + |
| 7 | +# Set up logging |
| 8 | +logging.basicConfig(level=logging.DEBUG) |
| 9 | + |
| 10 | +# Load YOLO models |
| 11 | +yolo_models = { |
| 12 | + 'regular_deadlift': YOLO("muscleAi_weights/best.pt"), |
| 13 | + 'sumo_deadlift': YOLO("muscleAi_weights/sumo_best.pt"), |
| 14 | + 'squat': YOLO("muscleAi_weights/squats_best.pt"), |
| 15 | + 'romanian_deadlift': YOLO("muscleAi_weights/best_romanian.pt"), |
| 16 | + "zercher_squat": YOLO("muscleAi_weights/zercher_best.pt"), |
| 17 | + "front_squat": YOLO("muscleAi_weights/front_squats_best.pt") |
| 18 | +} |
| 19 | + |
| 20 | +# Function to check for injury risk |
| 21 | +def check_injury_risk(labels, exercise_type): |
| 22 | + if exercise_type in ['regular_deadlift', 'squat']: |
| 23 | + ibw_value = labels.get('ibw', 1.0) |
| 24 | + down_value = labels.get('down', 1.0) |
| 25 | + else: |
| 26 | + ibw_value = labels.get('up', 1.0) |
| 27 | + down_value = labels.get('down', 1.0) |
| 28 | + |
| 29 | + return "stop right now to prevent injury" if ibw_value < 0.80 or down_value < 0.70 else "No significant risk" |
| 30 | + |
| 31 | +# Function to draw keypoints on the frame |
| 32 | +def draw_keypoints(frame, keypoints): |
| 33 | + for point in keypoints: |
| 34 | + x, y = int(point[0]), int(point[1]) |
| 35 | + cv2.circle(frame, (x, y), 5, (0, 255, 0), -1) |
| 36 | + return frame |
| 37 | + |
| 38 | +# Function to process video with YOLO |
| 39 | +def process_video_with_yolo(video_path, exercise_type): |
| 40 | + processed_frames = [] |
| 41 | + |
| 42 | + yolo_model = yolo_models[exercise_type] |
| 43 | + |
| 44 | + cap = cv2.VideoCapture(video_path) |
| 45 | + |
| 46 | + if not cap.isOpened(): |
| 47 | + st.error("Error opening video file") |
| 48 | + return None |
| 49 | + |
| 50 | + last_ibw_label = None |
| 51 | + rep_count = 0 |
| 52 | + rep_started = False |
| 53 | + |
| 54 | + while True: |
| 55 | + ret, frame = cap.read() |
| 56 | + if not ret: |
| 57 | + break |
| 58 | + |
| 59 | + results = yolo_model(source=frame, stream=True, conf=0.3) |
| 60 | + |
| 61 | + for result in results: |
| 62 | + frame = result.orig_img |
| 63 | + |
| 64 | + labels = {result.names[int(box.cls)]: float(box.conf) for box in result.boxes} if result.boxes is not None else {} |
| 65 | + injury_risk = check_injury_risk(labels, exercise_type) |
| 66 | + |
| 67 | + current_ibw_label = labels.get('ibw') if exercise_type in ['regular_deadlift', 'squat'] else labels.get('up') |
| 68 | + |
| 69 | + if last_ibw_label is not None and current_ibw_label is not None: |
| 70 | + if not rep_started: |
| 71 | + if last_ibw_label > 0.89 and current_ibw_label <= 0.89: |
| 72 | + rep_started = True |
| 73 | + else: |
| 74 | + if last_ibw_label <= 0.89 and current_ibw_label > 0.89: |
| 75 | + rep_count += 1 |
| 76 | + rep_started = False |
| 77 | + |
| 78 | + last_ibw_label = current_ibw_label |
| 79 | + |
| 80 | + # Draw keypoints on the frame if available |
| 81 | + if hasattr(result, 'keypoints') and result.keypoints is not None: |
| 82 | + keypoints = result.keypoints.xy[0] |
| 83 | + frame = draw_keypoints(frame, keypoints) |
| 84 | + |
| 85 | + cv2.putText(frame, f"Injury Risk: {injury_risk}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) |
| 86 | + cv2.putText(frame, f"Repetitions: {rep_count}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) |
| 87 | + |
| 88 | + processed_frames.append(frame) |
| 89 | + |
| 90 | + cap.release() |
| 91 | + |
| 92 | + return processed_frames |
| 93 | + |
| 94 | +# Function to process live video stream |
| 95 | +def process_live_video(exercise_type): |
| 96 | + cap = cv2.VideoCapture(0) |
| 97 | + |
| 98 | + if not cap.isOpened(): |
| 99 | + st.error("Error opening webcam") |
| 100 | + return |
| 101 | + |
| 102 | + stframe = st.empty() |
| 103 | + |
| 104 | + last_ibw_label = None |
| 105 | + rep_count = 0 |
| 106 | + rep_started = False |
| 107 | + |
| 108 | + while True: |
| 109 | + ret, frame = cap.read() |
| 110 | + if not ret: |
| 111 | + break |
| 112 | + |
| 113 | + results = yolo_models[exercise_type](source=frame, stream=True, conf=0.3) |
| 114 | + |
| 115 | + for result in results: |
| 116 | + frame = result.orig_img |
| 117 | + |
| 118 | + labels = {result.names[int(box.cls)]: float(box.conf) for box in result.boxes} if result.boxes is not None else {} |
| 119 | + injury_risk = check_injury_risk(labels, exercise_type) |
| 120 | + |
| 121 | + current_ibw_label = labels.get('ibw') if exercise_type in ['regular_deadlift', 'squat'] else labels.get('up') |
| 122 | + |
| 123 | + if last_ibw_label is not None and current_ibw_label is not None: |
| 124 | + if not rep_started: |
| 125 | + if last_ibw_label > 0.89 and current_ibw_label <= 0.89: |
| 126 | + rep_started = True |
| 127 | + else: |
| 128 | + if last_ibw_label <= 0.89 and current_ibw_label > 0.89: |
| 129 | + rep_count += 1 |
| 130 | + rep_started = False |
| 131 | + |
| 132 | + last_ibw_label = current_ibw_label |
| 133 | + |
| 134 | + # Draw keypoints on the frame if available |
| 135 | + if hasattr(result, 'keypoints') and result.keypoints is not None: |
| 136 | + keypoints = result.keypoints.xy[0] |
| 137 | + frame = draw_keypoints(frame, keypoints) |
| 138 | + |
| 139 | + cv2.putText(frame, f"Injury Risk: {injury_risk}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) |
| 140 | + cv2.putText(frame, f"Repetitions: {rep_count}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) |
| 141 | + |
| 142 | + # Update the Streamlit frame with the processed frame |
| 143 | + stframe.image(frame, channels="BGR") |
| 144 | + |
| 145 | + cap.release() |
| 146 | + |
| 147 | +# Streamlit UI |
| 148 | +st.title("Exercise Video Analysis") |
| 149 | + |
| 150 | +exercise_type = st.selectbox("Select Exercise Type", list(yolo_models.keys())) |
| 151 | + |
| 152 | +uploaded_file = st.file_uploader("Upload a Video", type=["mp4", "mov"]) |
| 153 | + |
| 154 | +if uploaded_file is not None: |
| 155 | + video_path = os.path.join('./videos', uploaded_file.name) |
| 156 | + |
| 157 | + # Save uploaded video |
| 158 | + with open(video_path, "wb") as f: |
| 159 | + f.write(uploaded_file.getbuffer()) |
| 160 | + |
| 161 | + st.video(uploaded_file) |
| 162 | + |
| 163 | + if st.button("Process Video"): |
| 164 | + processed_frames = process_video_with_yolo(video_path, exercise_type) |
| 165 | + |
| 166 | + if processed_frames is not None: |
| 167 | + output_video_path = os.path.join('./processed_videos', f'processed_{uploaded_file.name}') |
| 168 | + out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'XVID'), 30, |
| 169 | + (processed_frames[0].shape[1], processed_frames[0].shape[0])) |
| 170 | + |
| 171 | + for frame in processed_frames: |
| 172 | + out.write(frame) |
| 173 | + out.release() |
| 174 | + |
| 175 | + st.success(f"Video processed successfully! You can download it [here](./processed_videos/processed_{uploaded_file.name})") |
| 176 | + |
| 177 | +if st.button("Start Live Stream"): |
| 178 | + process_live_video(exercise_type) |
| 179 | + |
0 commit comments