-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontrol_drone.py
158 lines (130 loc) · 4.99 KB
/
control_drone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import argparse
import time
from collections import deque
import torch
from classifier import EMG_Inference
from pymavlink import mavutil
def move_drone(x, y, z):
    """Command a relative position offset in the body NED frame.

    x, y, z are offsets in metres (NED convention: positive z is down,
    so a negative z climbs).  Uses the module-level `master` connection.
    """
    # Type mask 0b100111111000: act on the position fields only; the
    # velocity, acceleration, yaw and yaw-rate fields below are ignored.
    position_only_mask = int(0b100111111000)
    master.mav.set_position_target_local_ned_send(
        0,  # time_boot_ms (unused)
        master.target_system,
        master.target_component,
        mavutil.mavlink.MAV_FRAME_BODY_OFFSET_NED,
        position_only_mask,
        x, y, z,    # position offsets in metres
        0, 0, 0,    # velocity (masked out)
        0, 0, 0,    # acceleration (masked out)
        0, 0,       # yaw, yaw rate (masked out)
    )
def turn_drone(degree):
    """Yaw the drone by `degree` degrees, clockwise, relative to its current heading."""
    yaw_speed_deg_s = 50   # angular speed passed to MAV_CMD_CONDITION_YAW
    direction_cw = 1       # 1: clockwise, -1: counter-clockwise
    relative_offset = 1    # 1: yaw change is relative to current heading
    master.mav.command_long_send(
        master.target_system,
        master.target_component,
        mavutil.mavlink.MAV_CMD_CONDITION_YAW,
        0,  # confirmation
        degree,
        yaw_speed_deg_s,
        direction_cw,
        relative_offset,
        0, 0, 0,
    )
def get_current_altitude():
    """Block for the next GLOBAL_POSITION_INT message and return the
    relative altitude in metres, or None if no message was received."""
    # relative_alt is reported in millimetres; convert to metres.
    msg = master.recv_match(type='GLOBAL_POSITION_INT', blocking=True)
    return msg.relative_alt / 1000.0 if msg else None
def choose_movement_from_gesture(predictions):
    """Map a unanimous gesture window to a drone action.

    Requires every slot in the `predictions` window to agree on a single
    gesture before acting.  Returns True when the caller should reset the
    prediction buffer (arming, arm refusal, or landing); False when no
    unanimous gesture was found or telemetry was unavailable; None after
    an ordinary movement command (falsy, so the sliding window keeps going).
    """
    armed = master.motors_armed() != 0
    counts = [predictions.count(gesture) for gesture in gestures]
    # Only act when all PRED_LEN entries in the window are the same gesture.
    if PRED_LEN not in counts:
        return False
    selected_gesture = gestures[counts.index(PRED_LEN)]
    if selected_gesture == 'fist':
        if not armed:
            print("ARMING")
            master.set_mode('GUIDED')
            time.sleep(0.5)
            master.arducopter_arm()
            time.sleep(0.5)
        return True
    if not armed:
        print("ARM WITH FIST FIRST")
        return True
    if selected_gesture == 'up':
        altitude = get_current_altitude()
        # BUG FIX: get_current_altitude() may return None; the original
        # `None < ALT_THRESHOLD` comparison raised TypeError. Skip the
        # command when no telemetry is available.
        if altitude is None:
            return False
        if altitude < ALT_THRESHOLD:
            # On the ground (below threshold): command a takeoff to 2 m.
            master.mav.command_long_send(
                master.target_system, master.target_component,
                mavutil.mavlink.MAV_CMD_NAV_TAKEOFF, 0,
                0, 0, 0, 0, 0, 0, 2
            )
        else:
            move_drone(0, 0, -0.5)  # NED: negative z climbs
    elif selected_gesture == 'lift':
        move_drone(2, 0, 0)  # 2 m forward in the body frame
        time.sleep(2)
    elif selected_gesture == 'peace':
        turn_drone(50)
        time.sleep(2)
    elif selected_gesture == 'down':
        altitude = get_current_altitude()
        # Same None guard as 'up' (see BUG FIX above).
        if altitude is None:
            return False
        if altitude >= ALT_THRESHOLD:
            move_drone(0, 0, 0.5)  # NED: positive z descends
        else:
            land_drone()
            return True
def land_drone():
    """Disable precision landing, then command a standard MAVLink landing."""
    # Turn the PLND_ENABLED parameter off first so NAV_LAND performs a
    # plain descent rather than a precision landing.
    master.mav.param_set_send(
        master.target_system,
        master.target_component,
        b'PLND_ENABLED',
        0,
        mavutil.mavlink.MAV_PARAM_TYPE_INT8,
    )
    # Give the autopilot a moment to apply the parameter change.
    time.sleep(2)
    master.mav.command_long_send(
        master.target_system,
        master.target_component,
        mavutil.mavlink.MAV_CMD_NAV_LAND,
        0,
        0, 0, 0, 0, 0, 0, 0,
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Drone control via EMG gesture commands")
    parser.add_argument(
        "--model_type",
        choices=["sklearn", "torch", "lstm", "tf"],
        default="sklearn",
        help="Type of model to use for classification",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="../training/resources/custom_classifier_gen2.pkl",
        help="Path to classifier",
    )
    parser.add_argument(
        "--scaler_path",
        type=str,
        default="../training/resources/custom_scaler.pkl",
        help="Path to scaler",
    )
    parser.add_argument("--window_size", type=int, default=200, help="Size of the data window for predictions")
    parser.add_argument("--prediction_delay", type=int, default=1, help="Delay between predictions")
    parser.add_argument("-p", "--port", type=str, default="/dev/ttyUSB0", help="USB receiving station port")
    # BUG FIX: parse arguments *before* opening the MAVLink connection, so
    # `--help` / bad arguments exit immediately instead of first opening a
    # socket and sleeping 5 seconds.
    args = parser.parse_args()

    # Connect to the autopilot and give the link time to come up.
    master = mavutil.mavlink_connection('udp:127.0.0.1:14550')
    time.sleep(5)

    inference = EMG_Inference(port=args.port, model_path=args.model_path, model_type=args.model_type, scaler_path=args.scaler_path)

    PRED_LEN = 3  # number of consecutive identical predictions required to act
    gestures = ("baseline", "fist", "peace", "up", "down", "lift")
    ALT_THRESHOLD = 0.5  # metres; below this the drone is treated as landed
    predictions = deque([], maxlen=PRED_LEN)
    try:
        while True:
            master.wait_heartbeat(timeout=0.1)  # Do heartbeat to keep connection alive (probably)
            prediction = inference.classification()
            if prediction is not None:
                gesture = gestures[prediction]
                predictions.append(gesture)
                print(f"Prediction: {gesture}")
                if len(predictions) == PRED_LEN:
                    if choose_movement_from_gesture(predictions):
                        # A terminal action (arm/refuse/land) was taken:
                        # start the gesture window over.
                        predictions.clear()
            time.sleep(args.prediction_delay)
    except KeyboardInterrupt:
        # Ctrl-C: bring the drone down before exiting.
        land_drone()