|
| 1 | +# main.py |
| 2 | +import argparse |
| 3 | +import glob |
| 4 | +import os |
| 5 | +import sys |
| 6 | +import time |
| 7 | +import numpy as np |
| 8 | +from preprocess import preprocess_gpx |
| 9 | +from model import build_model, train_model, generate_gpx |
| 10 | + |
| 11 | +try: |
| 12 | + import select |
| 13 | +except ImportError: |
| 14 | + select = None |
| 15 | + |
| 16 | +try: |
| 17 | + import msvcrt |
| 18 | +except ImportError: |
| 19 | + msvcrt = None |
| 20 | + |
| 21 | +# 反标准化函数,将数据恢复到原始的经纬度范围 |
| 22 | +def denormalize(lat, lon, mean_coords, std_coords): |
| 23 | + denorm_lat = (lat * std_coords[0]) + mean_coords[0] |
| 24 | + denorm_lon = (lon * std_coords[1]) + mean_coords[1] |
| 25 | + return denorm_lat, denorm_lon |
| 26 | + |
| 27 | +# 保存为 GPX 路线格式的函数 |
| 28 | +def save_to_gpx_route(filename, data, mean_coords, std_coords): |
| 29 | + with open(filename, 'w') as file: |
| 30 | + # 写入 GPX 文件头 |
| 31 | + file.write('<?xml version="1.0" encoding="UTF-8"?>\n') |
| 32 | + file.write('<gpx version="1.1" creator="StackAll">\n') |
| 33 | + file.write(' <rte>\n') |
| 34 | + file.write(' <name>Generated Route</name>\n') |
| 35 | + file.write(' <number>1</number>\n') |
| 36 | + |
| 37 | + # 写入路线点 |
| 38 | + for point in data: |
| 39 | + lat, lon = denormalize(point[0], point[1], mean_coords, std_coords) |
| 40 | + file.write(f' <rtept lat="{lat}" lon="{lon}"></rtept>\n') |
| 41 | + |
| 42 | + # 写入 GPX 文件尾 |
| 43 | + file.write(' </rte>\n') |
| 44 | + file.write('</gpx>\n') |
| 45 | + |
| 46 | +# 增加了命令行参数解析 |
| 47 | +def parse_args(): |
| 48 | + parser = argparse.ArgumentParser(description='GPX Route Generation') |
| 49 | + parser.add_argument('-ni', '--no-interaction', action='store_true', |
| 50 | + help='Run without interactive input, using default values') |
| 51 | + parser.add_argument('-e', '--epochs', type=int, default=100, |
| 52 | + help='Number of epochs to train the model') |
| 53 | + parser.add_argument('-p', '--num-points', type=int, default=100, |
| 54 | + help='Number of points to generate in each GPX route') |
| 55 | + parser.add_argument('-n', '--num-files', type=int, default=10, |
| 56 | + help='Number of GPX files to generate') |
| 57 | + return parser.parse_args() |
| 58 | + |
| 59 | +def cross_platform_input_with_timeout(prompt, timeout, default): |
| 60 | + print(prompt, end='', flush=True) |
| 61 | + input_str = '' |
| 62 | + start_time = time.time() |
| 63 | + |
| 64 | + while True: |
| 65 | + if msvcrt: |
| 66 | + while (time.time() - start_time) < timeout: |
| 67 | + if msvcrt.kbhit(): |
| 68 | + char = msvcrt.getche() |
| 69 | + if ord(char) == 13: |
| 70 | + return input_str or default |
| 71 | + if char.decode().isdigit(): |
| 72 | + input_str += char.decode() |
| 73 | + else: |
| 74 | + print("\n请输入数字。") |
| 75 | + return cross_platform_input_with_timeout(prompt, timeout, default) |
| 76 | + time.sleep(0.05) |
| 77 | + else: |
| 78 | + if (time.time() - start_time) < timeout: |
| 79 | + rlist, _, _ = select.select([sys.stdin], [], [], timeout) |
| 80 | + if rlist: |
| 81 | + input_str = sys.stdin.readline().rstrip('\n') |
| 82 | + if input_str.isdigit() or input_str == '': |
| 83 | + return input_str or default |
| 84 | + else: |
| 85 | + print("请输入数字。") |
| 86 | + return cross_platform_input_with_timeout(prompt, timeout, default) |
| 87 | + timeout -= (time.time() - start_time) |
| 88 | + else: |
| 89 | + print() |
| 90 | + return default |
| 91 | + |
| 92 | +def ensure_directories(): |
| 93 | + if not os.path.exists('./input'): |
| 94 | + os.makedirs('./input') |
| 95 | + print("创建了 './input' 文件夹。请将您的 GPX 文件导入到这个文件夹。") |
| 96 | + if not os.path.exists('./output'): |
| 97 | + os.makedirs('./output') |
| 98 | + print("创建了 './output' 文件夹。") |
| 99 | + |
| 100 | +def ensure_numeric_args(args): |
| 101 | + if not str(args.epochs).isdigit() or not str(args.num_points).isdigit() or not str(args.num_files).isdigit(): |
| 102 | + print("参数必须为数字。现在进入交互式。") |
| 103 | + return False |
| 104 | + return True |
| 105 | + |
| 106 | +def main(): |
| 107 | + ensure_directories() |
| 108 | + args = parse_args() |
| 109 | + |
| 110 | + if not ensure_numeric_args(args): |
| 111 | + args.no_interaction = False |
| 112 | + |
| 113 | + # 默认使用命令行参数设置变量 |
| 114 | + epochs = args.epochs |
| 115 | + num_points = args.num_points |
| 116 | + num_files = args.num_files |
| 117 | + |
| 118 | + # 如果启用了交互模式,则使用用户输入覆盖 |
| 119 | + if not args.no_interaction: |
| 120 | + epochs_str = cross_platform_input_with_timeout( |
| 121 | + "请输入训练模型的迭代次数(默认为 100,10s 后采用默认值):", |
| 122 | + 10, |
| 123 | + str(args.epochs) |
| 124 | + ) |
| 125 | + epochs = int(epochs_str) if epochs_str.isdigit() else args.epochs |
| 126 | + |
| 127 | + num_points_str = cross_platform_input_with_timeout( |
| 128 | + "请输入每个 GPX 路线要生成的点数(默认为 100,10s 后采用默认值):", |
| 129 | + 10, |
| 130 | + str(args.num_points) |
| 131 | + ) |
| 132 | + num_points = int(num_points_str) if num_points_str.isdigit() else args.num_points |
| 133 | + |
| 134 | + if not args.no_interaction: |
| 135 | + num_files_str = cross_platform_input_with_timeout( |
| 136 | + "请输入要生成的 GPX 数量(默认为 10,10s 后采用默认值):", |
| 137 | + 10, |
| 138 | + str(args.num_files) |
| 139 | + ) |
| 140 | + num_files = int(num_files_str) if num_files_str.isdigit() else args.num_files |
| 141 | + |
| 142 | + # 在input文件夹中查找所有.gpx文件 |
| 143 | + gpx_files = glob.glob('./input/*.gpx') |
| 144 | + |
| 145 | + # 数据预处理 |
| 146 | + sequences, next_points, mean_coords, std_coords = preprocess_gpx(gpx_files) |
| 147 | + |
| 148 | + if sequences is not None: |
| 149 | + # 构建模型 |
| 150 | + input_shape = (sequences.shape[1], sequences.shape[2]) # 序列长度和特征数 |
| 151 | + model = build_model(input_shape) |
| 152 | + |
| 153 | + # 训练模型 |
| 154 | + train_model(model, sequences, next_points, epochs=epochs) |
| 155 | + |
| 156 | + # 生成并保存新的GPX路线 |
| 157 | + seed = sequences[:1] # 使用现有序列作为种子 |
| 158 | + for i in range(num_files): # 使用 num_files 控制生成的 GPX 文件数量 |
| 159 | + generated_route = generate_gpx(model, seed, num_points=num_points) |
| 160 | + # 反标准化并将生成的路径保存到 GPX 文件 |
| 161 | + filename = f'./output/generated_route_{i+1}.gpx' |
| 162 | + save_to_gpx_route(filename, generated_route, mean_coords, std_coords) |
| 163 | + print(f'生成的GPX路线已保存至 {filename}') |
| 164 | + |
| 165 | +if __name__ == '__main__': |
| 166 | + main() |
0 commit comments