Skip to content

Commit 7f047da

Browse files
committed
Initial Commit.
1 parent b8e4b19 commit 7f047da

File tree

5 files changed

+285
-0
lines changed

5 files changed

+285
-0
lines changed

Diff for: evaluate.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
def test_generated_gpx(generated_track, mean_coords, std_coords):
4+
if generated_track is None or len(generated_track) == 0:
5+
print("Generated track is empty.")
6+
return
7+
8+
# 将生成的轨迹数据逆归一化回原始的坐标值
9+
generated_track_unnormalized = generated_track * std_coords + mean_coords
10+
11+
# 计算生成轨迹的均值和标准差
12+
mean_generated = np.mean(generated_track_unnormalized, axis=0)
13+
std_generated = np.std(generated_track_unnormalized, axis=0)
14+
15+
print("Generated GPX track statistics:")
16+
print(f"Mean coordinates: {mean_generated}")
17+
print(f"Standard deviation of coordinates: {std_generated}")
18+
19+
# 这里可以添加更多的评估方法,例如与实际轨迹的对比,或者定量评估生成轨迹的质量

Diff for: main.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# main.py
2+
import argparse
3+
import glob
4+
import os
5+
import sys
6+
import time
7+
import numpy as np
8+
from preprocess import preprocess_gpx
9+
from model import build_model, train_model, generate_gpx
10+
11+
try:
12+
import select
13+
except ImportError:
14+
select = None
15+
16+
try:
17+
import msvcrt
18+
except ImportError:
19+
msvcrt = None
20+
21+
# 反标准化函数,将数据恢复到原始的经纬度范围
22+
def denormalize(lat, lon, mean_coords, std_coords):
23+
denorm_lat = (lat * std_coords[0]) + mean_coords[0]
24+
denorm_lon = (lon * std_coords[1]) + mean_coords[1]
25+
return denorm_lat, denorm_lon
26+
27+
# 保存为 GPX 路线格式的函数
28+
def save_to_gpx_route(filename, data, mean_coords, std_coords):
29+
with open(filename, 'w') as file:
30+
# 写入 GPX 文件头
31+
file.write('<?xml version="1.0" encoding="UTF-8"?>\n')
32+
file.write('<gpx version="1.1" creator="StackAll">\n')
33+
file.write(' <rte>\n')
34+
file.write(' <name>Generated Route</name>\n')
35+
file.write(' <number>1</number>\n')
36+
37+
# 写入路线点
38+
for point in data:
39+
lat, lon = denormalize(point[0], point[1], mean_coords, std_coords)
40+
file.write(f' <rtept lat="{lat}" lon="{lon}"></rtept>\n')
41+
42+
# 写入 GPX 文件尾
43+
file.write(' </rte>\n')
44+
file.write('</gpx>\n')
45+
46+
# 增加了命令行参数解析
47+
def parse_args():
48+
parser = argparse.ArgumentParser(description='GPX Route Generation')
49+
parser.add_argument('-ni', '--no-interaction', action='store_true',
50+
help='Run without interactive input, using default values')
51+
parser.add_argument('-e', '--epochs', type=int, default=100,
52+
help='Number of epochs to train the model')
53+
parser.add_argument('-p', '--num-points', type=int, default=100,
54+
help='Number of points to generate in each GPX route')
55+
parser.add_argument('-n', '--num-files', type=int, default=10,
56+
help='Number of GPX files to generate')
57+
return parser.parse_args()
58+
59+
def cross_platform_input_with_timeout(prompt, timeout, default):
60+
print(prompt, end='', flush=True)
61+
input_str = ''
62+
start_time = time.time()
63+
64+
while True:
65+
if msvcrt:
66+
while (time.time() - start_time) < timeout:
67+
if msvcrt.kbhit():
68+
char = msvcrt.getche()
69+
if ord(char) == 13:
70+
return input_str or default
71+
if char.decode().isdigit():
72+
input_str += char.decode()
73+
else:
74+
print("\n请输入数字。")
75+
return cross_platform_input_with_timeout(prompt, timeout, default)
76+
time.sleep(0.05)
77+
else:
78+
if (time.time() - start_time) < timeout:
79+
rlist, _, _ = select.select([sys.stdin], [], [], timeout)
80+
if rlist:
81+
input_str = sys.stdin.readline().rstrip('\n')
82+
if input_str.isdigit() or input_str == '':
83+
return input_str or default
84+
else:
85+
print("请输入数字。")
86+
return cross_platform_input_with_timeout(prompt, timeout, default)
87+
timeout -= (time.time() - start_time)
88+
else:
89+
print()
90+
return default
91+
92+
def ensure_directories():
93+
if not os.path.exists('./input'):
94+
os.makedirs('./input')
95+
print("创建了 './input' 文件夹。请将您的 GPX 文件导入到这个文件夹。")
96+
if not os.path.exists('./output'):
97+
os.makedirs('./output')
98+
print("创建了 './output' 文件夹。")
99+
100+
def ensure_numeric_args(args):
101+
if not str(args.epochs).isdigit() or not str(args.num_points).isdigit() or not str(args.num_files).isdigit():
102+
print("参数必须为数字。现在进入交互式。")
103+
return False
104+
return True
105+
106+
def main():
107+
ensure_directories()
108+
args = parse_args()
109+
110+
if not ensure_numeric_args(args):
111+
args.no_interaction = False
112+
113+
# 默认使用命令行参数设置变量
114+
epochs = args.epochs
115+
num_points = args.num_points
116+
num_files = args.num_files
117+
118+
# 如果启用了交互模式,则使用用户输入覆盖
119+
if not args.no_interaction:
120+
epochs_str = cross_platform_input_with_timeout(
121+
"请输入训练模型的迭代次数(默认为 100,10s 后采用默认值):",
122+
10,
123+
str(args.epochs)
124+
)
125+
epochs = int(epochs_str) if epochs_str.isdigit() else args.epochs
126+
127+
num_points_str = cross_platform_input_with_timeout(
128+
"请输入每个 GPX 路线要生成的点数(默认为 100,10s 后采用默认值):",
129+
10,
130+
str(args.num_points)
131+
)
132+
num_points = int(num_points_str) if num_points_str.isdigit() else args.num_points
133+
134+
if not args.no_interaction:
135+
num_files_str = cross_platform_input_with_timeout(
136+
"请输入要生成的 GPX 数量(默认为 10,10s 后采用默认值):",
137+
10,
138+
str(args.num_files)
139+
)
140+
num_files = int(num_files_str) if num_files_str.isdigit() else args.num_files
141+
142+
# 在input文件夹中查找所有.gpx文件
143+
gpx_files = glob.glob('./input/*.gpx')
144+
145+
# 数据预处理
146+
sequences, next_points, mean_coords, std_coords = preprocess_gpx(gpx_files)
147+
148+
if sequences is not None:
149+
# 构建模型
150+
input_shape = (sequences.shape[1], sequences.shape[2]) # 序列长度和特征数
151+
model = build_model(input_shape)
152+
153+
# 训练模型
154+
train_model(model, sequences, next_points, epochs=epochs)
155+
156+
# 生成并保存新的GPX路线
157+
seed = sequences[:1] # 使用现有序列作为种子
158+
for i in range(num_files): # 使用 num_files 控制生成的 GPX 文件数量
159+
generated_route = generate_gpx(model, seed, num_points=num_points)
160+
# 反标准化并将生成的路径保存到 GPX 文件
161+
filename = f'./output/generated_route_{i+1}.gpx'
162+
save_to_gpx_route(filename, generated_route, mean_coords, std_coords)
163+
print(f'生成的GPX路线已保存至 {filename}')
164+
165+
if __name__ == '__main__':
166+
main()

Diff for: model.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# model.py
2+
import numpy as np
3+
from tensorflow.keras.models import Sequential
4+
from tensorflow.keras.layers import LSTM, Dense
5+
from tensorflow.keras.callbacks import ModelCheckpoint
6+
7+
# 设计 LSTM 模型
8+
def build_model(input_shape, units=64, dropout=0.3):
9+
model = Sequential()
10+
model.add(LSTM(units, input_shape=input_shape, return_sequences=True))
11+
model.add(LSTM(units, return_sequences=False))
12+
model.add(Dense(units, activation='relu'))
13+
model.add(Dense(2)) # 预测纬度和经度
14+
model.compile(optimizer='adam', loss='mse')
15+
return model
16+
17+
# 训练模型并保存
18+
# epochs 代表训练模型时整个数据集将被遍历迭代的次数。每一次遍历完整个数据集并进行一次前向传播和后向传播的过程被称为一个epoch。
19+
def train_model(model, sequences, next_points, epochs=100, batch_size=64, model_path='model.h5'):
20+
checkpoint = ModelCheckpoint(model_path, save_best_only=True, monitor='val_loss', mode='min')
21+
model.fit(sequences, next_points, batch_size=batch_size, epochs=epochs, callbacks=[checkpoint], validation_split=0.2)
22+
model.save(model_path)
23+
return model
24+
25+
# 使用模型生成轨迹,num_points 代表生成的轨迹的点的数量
26+
def generate_gpx(model, seed, num_points=100):
27+
generated_points = []
28+
current_seq = seed
29+
for _ in range(num_points):
30+
predicted_point = model.predict(current_seq)
31+
generated_points.append(predicted_point[0])
32+
current_seq = np.concatenate((current_seq[0][1:], predicted_point.reshape(1, -1)), axis=0).reshape(1, current_seq.shape[1], current_seq.shape[2])
33+
return np.array(generated_points)

Diff for: preprocess.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import gpxpy
2+
import numpy as np
3+
4+
# Loads GPX data
5+
def load_gpx_data(gpx_files):
6+
tracks = []
7+
for file in gpx_files:
8+
print(f"Loading GPX file: {file}")
9+
try:
10+
with open(file, 'r') as f:
11+
gpx = gpxpy.parse(f)
12+
# Check for tracks
13+
for track in gpx.tracks:
14+
for segment in track.segments:
15+
if segment.points:
16+
tracks.append(np.array([(point.latitude, point.longitude) for point in segment.points]))
17+
else:
18+
print("No points in segment.")
19+
# Check for routes
20+
for route in gpx.routes:
21+
if route.points:
22+
tracks.append(np.array([(point.latitude, point.longitude) for point in route.points]))
23+
else:
24+
print("No points in route.")
25+
except Exception as e:
26+
print(f"Error loading GPX file {file}: {e}")
27+
return tracks
28+
29+
# Normalizes the tracks
30+
def normalize_tracks(tracks):
31+
all_points = np.concatenate(tracks, axis=0)
32+
mean_lat, mean_lon = np.mean(all_points, axis=0)
33+
std_lat, std_lon = np.std(all_points, axis=0)
34+
35+
normalized_tracks = []
36+
for track in tracks:
37+
normalized_track = (track - [mean_lat, mean_lon]) / [std_lat, std_lon]
38+
normalized_tracks.append(normalized_track)
39+
return normalized_tracks, (mean_lat, mean_lon), (std_lat, std_lon)
40+
41+
# Prepares sequences
42+
def prepare_sequences(tracks, sequence_length):
43+
sequences = []
44+
next_points = []
45+
for track in tracks:
46+
for i in range(len(track) - sequence_length):
47+
sequences.append(track[i:i + sequence_length])
48+
next_points.append(track[i + sequence_length])
49+
return np.array(sequences), np.array(next_points)
50+
51+
# Main function to preprocess GPX files
52+
def preprocess_gpx(files, sequence_length=5):
53+
tracks = load_gpx_data(files)
54+
if not tracks:
55+
print("No tracks to normalize. Exiting preprocessing.")
56+
return None, None, None, None
57+
58+
print(f"Normalizing {len(tracks)} tracks.")
59+
normalized_tracks, mean_coords, std_coords = normalize_tracks(tracks)
60+
61+
print(f"Preparing sequences from normalized tracks.")
62+
sequences, next_points = prepare_sequences(normalized_tracks, sequence_length)
63+
return sequences, next_points, mean_coords, std_coords

Diff for: requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
numpy
2+
gpxpy
3+
tensorflow
4+
scikit-learn

0 commit comments

Comments
 (0)