diff --git a/StreamServer/src/analytic/action/action_model.py b/StreamServer/src/analytic/action/action_model.py
index 0e8945a..c4928fc 100644
--- a/StreamServer/src/analytic/action/action_model.py
+++ b/StreamServer/src/analytic/action/action_model.py
@@ -50,9 +50,10 @@ def kpt2bbox(kpt, ex=20):
 
 class ActionModel:
-    def __init__(self) -> None:
-        self.ACTION_LIST = []
+    ACTION_LIST = []
+    IS_FALL_DOWN = False
 
+    def __init__(self) -> None:
         # Model initialization
         self.detect_model = TinyYOLOv3_onecls(INP_DETS, device=DEVICE, config_file=CONFIG_FILE,
                                               weight_file=YOLO_WEIGHT_FILE)
 
@@ -64,6 +65,53 @@ class ActionModel:
 
         self.tracker = Tracker(max_age=self.max_age, n_init=3)
 
+    def run_action_model(self, source):
+        cam = CamLoader(int(source) if source.isdigit() else source,
+                        preprocess=preproc).start()
+
+        print("STARTING ACTION MODEL")
+
+        while cam.grabbed():
+            frame = cam.getitem()
+            detected = self.detect_model.detect(frame, need_resize=False, expand_bb=10)
+            self.tracker.predict()
+            # Merge the two sources of predicted bboxes together.
+            for track in self.tracker.tracks:
+                det = torch.tensor([track.to_tlbr().tolist() + [0.5, 1.0, 0.0]],
+                                   dtype=torch.float32)
+                detected = torch.cat([detected, det], dim=0) if detected is not None else det
+
+            detections = []  # List of Detection objects for tracking.
+            if detected is not None:
+                # detected = non_max_suppression(detected[None, :], 0.45, 0.2)[0]
+                # Predict the skeleton pose within each bbox.
+                poses = self.pose_model.predict(frame, detected[:, 0:4], detected[:, 4])
+
+                # Create Detection objects.
+                detections = [Detection(kpt2bbox(ps['keypoints'].numpy()),
+                                        np.concatenate((ps['keypoints'].numpy(),
+                                                        ps['kp_score'].numpy()), axis=1),
+                                        ps['kp_score'].mean().numpy()) for ps in poses]
+
+            self.tracker.update(detections)
+
+            # Predict the action of each track.
+            for i, track in enumerate(self.tracker.tracks):
+                if not track.is_confirmed():
+                    continue
+
+                action = 'pending'
+                # Use a 30-frame window of time steps for prediction.
+                if len(track.keypoints_list) == 30:
+                    pts = np.array(track.keypoints_list, dtype=np.float32)
+                    out = self.action_model.predict(pts, frame.shape[:2])
+                    action_name = self.action_model.class_names[out[0].argmax()]
+                    action = '{}: {:.2f}%'.format(action_name, out[0].max() * 100)
+                    # Add the action to the shared action list.
+                    ActionModel.ACTION_LIST.append(action)
+                    if action_name == 'Fall Down':
+                        ActionModel.IS_FALL_DOWN = True
+
     def generate_action_model_frame(self, source):
         CAM_SOURCE = source
         detect_model = self.detect_model
@@ -129,14 +177,14 @@ class ActionModel:
                     out = action_model.predict(pts, frame.shape[:2])
                     action_name = action_model.class_names[out[0].argmax()]
                     action = '{}: {:.2f}%'.format(action_name, out[0].max() * 100)
+                    # Add the action to the shared action list.
+                    ActionModel.ACTION_LIST.append(action)
                     if action_name == 'Fall Down':
+                        ActionModel.IS_FALL_DOWN = True
                         clr = (255, 0, 0)
                     elif action_name == 'Lying Down':
                         clr = (255, 200, 0)
 
-                # Add action to action list.
-                self.ACTION_LIST.append(action)
-
                 # VISUALIZE.
                 if track.time_since_update == 0:
                     if SHOW_SKELETON:
@@ -160,4 +208,4 @@ class ActionModel:
             # If encoding fails, raise an error to stop the streaming
             raise HTTPException(status_code=500, detail="Frame encoding failed")
         yield (b'--frame\r\n'
-               b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
+               b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
\ No newline at end of file
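
A minimal sketch of how the new class-level state could be consumed, assuming the surrounding service is a FastAPI app (the existing streaming code already raises fastapi's HTTPException) and that run_action_model blocks while the camera is open. The import path, route names, and thread wiring below are illustrative assumptions, not part of this diff.

```python
# Hypothetical consumer of ActionModel.ACTION_LIST / ActionModel.IS_FALL_DOWN.
# Assumes a FastAPI app; run_action_model loops over camera frames until the
# stream ends, so it runs on a daemon thread to keep the server responsive.
import threading

from fastapi import FastAPI

from analytic.action.action_model import ActionModel  # assumed import path

app = FastAPI()
model = ActionModel()


@app.on_event("startup")
def start_action_model() -> None:
    # "0" is parsed by run_action_model as webcam index 0 via source.isdigit().
    threading.Thread(target=model.run_action_model, args=("0",), daemon=True).start()


@app.get("/actions")
def latest_actions(limit: int = 10):
    # ACTION_LIST is class-level state, so every instance appends to one list.
    return {"actions": ActionModel.ACTION_LIST[-limit:]}


@app.get("/fall-status")
def fall_status():
    # IS_FALL_DOWN latches True on the first 'Fall Down' prediction; the diff
    # adds no reset, so a caller would need to clear it explicitly.
    return {"fall_detected": ActionModel.IS_FALL_DOWN}
```

Because both run_action_model and generate_action_model_frame now write to ActionModel.ACTION_LIST and ActionModel.IS_FALL_DOWN, all instances share one growing list and one latch flag; a consumer like the sketch above should account for the list growing without bound and for the flag staying True once set.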