Skip to content

Commit 958206e

Browse files
committed
Initial refactoring to use sane OSITrace reader
Signed-off-by: Pierre R. Mai <[email protected]>
1 parent 7fd9311 commit 958206e

File tree

6 files changed

+47
-518
lines changed

6 files changed

+47
-518
lines changed

.github/workflows/ci.yml

+2-24
Original file line number | Diff line number | Diff line change
@@ -48,42 +48,20 @@ jobs:
4848
- name: Checkout
4949
uses: actions/checkout@v4
5050
with:
51+
submodules: recursive
5152
lfs: true
5253

5354
- name: Set up Python ${{ matrix.python-version }}
5455
uses: actions/setup-python@v5
5556
with:
5657
python-version: ${{ matrix.python-version }}
5758

58-
- name: Cache Dependencies
59-
id: cache-depends
60-
uses: actions/cache@v3
61-
with:
62-
path: protobuf-3.20.1
63-
key: ${{ runner.os }}-v2-depends
64-
65-
- name: Download ProtoBuf
66-
if: steps.cache-depends.outputs.cache-hit != 'true'
67-
run: curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.20.1/protobuf-all-3.20.1.tar.gz && tar xzvf protobuf-all-3.20.1.tar.gz
68-
69-
- name: Build ProtoBuf
70-
if: steps.cache-depends.outputs.cache-hit != 'true'
71-
working-directory: protobuf-3.20.1
72-
run: ./configure DIST_LANG=cpp --prefix=/usr && make
73-
74-
- name: Install ProtoBuf
75-
working-directory: protobuf-3.20.1
76-
run: sudo make install && sudo ldconfig
77-
78-
- name: Install Open Simulation Interface
79-
shell: bash
59+
- name: Set up Virtual Environment
8060
run: |
81-
git submodule update --init
8261
python -m venv .venv
8362
source .venv/bin/activate
8463
python -m pip install --upgrade pip
8564
pip install -r requirements_develop.txt
86-
cd open-simulation-interface && pip install . && cd ..
8765
8866
- name: Generate parsed rules
8967
run: |

README.md

+6-6
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ The full documentation on the validator and customization of the rules is availa
1111
## Usage
1212

1313
```bash
14-
usage: osivalidator [-h] [--data DATA] [--rules RULES] [--type {SensorView,GroundTruth,SensorData}] [--output OUTPUT] [--timesteps TIMESTEPS] [--debug] [--verbose] [--parallel] [--format {separated,None}]
14+
usage: osivalidator [-h] --data DATA [--rules RULES] [--type {SensorView,GroundTruth,SensorData}] [--output OUTPUT] [--timesteps TIMESTEPS] [--debug] [--verbose] [--parallel] [--format {None}]
1515
[--blast BLAST] [--buffer BUFFER]
1616

1717
Validate data defined at the input
@@ -29,13 +29,13 @@ optional arguments:
2929
Number of timesteps to analyze. If -1, all.
3030
--debug Set the debug mode to ON.
3131
--verbose, -v Set the verbose mode to ON.
32-
--parallel, -p Set parallel mode to ON.
33-
--format {separated,None}, -f {separated,None}
34-
Set the format type of the trace.
32+
--parallel, -p (Ignored) Set parallel mode to ON.
33+
--format {None}, -f {None}
34+
(Ignored) Set the format type of the trace.
3535
--blast BLAST, -bl BLAST
36-
Set the in-memory storage count of OSI messages during validation.
36+
Set the maximum in-memory storage count of OSI messages during validation.
3737
--buffer BUFFER, -bu BUFFER
38-
Set the buffer size to retrieve OSI messages from trace file. Set it to 0 if you do not want to use buffering at all.
38+
(Ignored) Set the buffer size to retrieve OSI messages from trace file. Set it to 0 if you do not want to use buffering at all.
3939
```
4040
4141
## Installation

osivalidator/osi_general_validator.py

+35-134
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,8 @@
33
"""
44

55
import argparse
6-
from multiprocessing import Pool, Manager
76
from tqdm import tqdm
7+
from osi3trace.osi_trace import OSITrace
88
import os, sys
99

1010
sys.path.append(os.path.join(os.path.dirname(__file__), "."))
@@ -14,7 +14,7 @@
1414
import osi_rules
1515
import osi_validator_logger
1616
import osi_rules_checker
17-
import osi_trace
17+
import linked_proto_field
1818
except Exception as e:
1919
print(
2020
"Make sure you have installed the requirements with 'pip install -r requirements.txt'!"
@@ -39,9 +39,9 @@ def command_line_arguments():
3939
)
4040
parser.add_argument(
4141
"--data",
42-
default="",
4342
help="Path to the file with OSI-serialized data.",
4443
type=str,
44+
required=True,
4545
)
4646
parser.add_argument(
4747
"--rules",
@@ -83,48 +83,43 @@ def command_line_arguments():
8383
parser.add_argument(
8484
"--parallel",
8585
"-p",
86-
help="Set parallel mode to ON.",
86+
help="(Ignored) Set parallel mode to ON.",
8787
default=False,
8888
required=False,
8989
action="store_true",
9090
)
9191
parser.add_argument(
9292
"--format",
9393
"-f",
94-
help="Set the format type of the trace.",
95-
choices=["separated", None],
94+
help="(Ignored) Set the format type of the trace.",
95+
choices=[None],
9696
default=None,
9797
type=str,
9898
required=False,
9999
)
100100
parser.add_argument(
101101
"--blast",
102102
"-bl",
103-
help="Set the in-memory storage count of OSI messages during validation.",
103+
help="Set the maximum in-memory storage count of OSI messages during validation.",
104104
default=500,
105105
type=check_positive_int,
106106
required=False,
107107
)
108108
parser.add_argument(
109109
"--buffer",
110110
"-bu",
111-
help="Set the buffer size to retrieve OSI messages from trace file. Set it to 0 if you do not want to use buffering at all.",
112-
default=1000000,
111+
help="(Ignored) Set the buffer size to retrieve OSI messages from trace file. Set it to 0 if you do not want to use buffering at all.",
112+
default=0,
113113
type=check_positive_int,
114114
required=False,
115115
)
116116

117117
return parser.parse_args()
118118

119119

120-
MANAGER = Manager()
121-
LOGS = MANAGER.list()
122-
TIMESTAMP_ANALYZED = MANAGER.list()
120+
LOGS = []
123121
LOGGER = osi_validator_logger.OSIValidatorLogger()
124122
VALIDATION_RULES = osi_rules.OSIRules()
125-
ID_TO_TS = {}
126-
BAR_SUFFIX = "%(index)d/%(max)d [%(elapsed_td)s]"
127-
MESSAGE_CACHE = {}
128123

129124

130125
def main():
@@ -143,11 +138,7 @@ def main():
143138

144139
# Read data
145140
print("Reading data ...")
146-
DATA = osi_trace.OSITrace(buffer_size=args.buffer)
147-
DATA.from_file(path=args.data, type_name=args.type, max_index=args.timesteps)
148-
149-
if DATA.timestep_count < args.timesteps:
150-
args.timesteps = -1
141+
trace = OSITrace(path=args.data, type_name=args.type)
151142

152143
# Collect Validation Rules
153144
print("Collect validation rules ...")
@@ -159,140 +150,50 @@ def main():
159150
LOGGER.info(None, f"Pass the {max_timestep} first timesteps")
160151
else:
161152
LOGGER.info(None, "Pass all timesteps")
162-
max_timestep = DATA.timestep_count
163-
164-
# Dividing in several blast to not overload the memory
165-
max_timestep_blast = 0
166-
167-
while max_timestep_blast < max_timestep:
168-
# Clear log queue
169-
LOGS = MANAGER.list()
170-
171-
# Increment the max-timestep to analyze
172-
max_timestep_blast += args.blast
173-
first_of_blast = max_timestep_blast - args.blast
174-
last_of_blast = min(max_timestep_blast, max_timestep)
175-
176-
# Cache messages
177-
DATA.cache_messages_in_index_range(first_of_blast, last_of_blast)
178-
MESSAGE_CACHE.update(DATA.message_cache)
179-
180-
if args.parallel:
181-
# Launch parallel computation
182-
# Recreate the pool
153+
max_timestep = None
154+
155+
index = 0
156+
total_length = os.path.getsize(args.data)
157+
current_pos = 0
158+
159+
with tqdm(total=total_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
160+
for message in trace:
161+
if (index % args.blast == 0):
162+
LOGS = []
163+
if max_timestep and index >= max_timestep:
164+
pbar.update(total_length - current_pos)
165+
break
183166
try:
184-
argument_list = [
185-
(i, args.type) for i in tqdm(range(first_of_blast, last_of_blast))
186-
]
187-
with Pool() as pool:
188-
pool.starmap(process_timestep, argument_list)
189-
167+
process_message(message, index, args.type)
190168
except Exception as e:
191169
print(str(e))
170+
new_pos = trace.file.tell()
171+
pbar.update(new_pos - current_pos)
172+
current_pos = new_pos
173+
index += 1
192174

193-
finally:
194-
close_pool(pool)
195-
print("\nClosed pool!")
196-
else:
197-
# Launch sequential computation
198-
try:
199-
for i in tqdm(range(first_of_blast, last_of_blast)):
200-
process_timestep(i, args.type)
201-
202-
except Exception as e:
203-
print(str(e))
204-
205-
MESSAGE_CACHE.clear()
206-
207-
DATA.trace_file.close()
175+
trace.close()
208176
display_results()
209177

210178

211-
def close_pool(pool):
212-
"""Cleanly close a pool to free the memory"""
213-
pool.close()
214-
pool.terminate()
215-
pool.join()
216-
217-
218-
def process_timestep(timestep, data_type):
219-
"""Process one timestep"""
220-
message = MESSAGE_CACHE[timestep]
179+
def process_message(message, timestep, data_type):
180+
"""Process one message"""
221181
rule_checker = osi_rules_checker.OSIRulesChecker(LOGGER)
222-
timestamp = rule_checker.set_timestamp(message.value.timestamp, timestep)
223-
ID_TO_TS[timestep] = timestamp
182+
timestamp = rule_checker.set_timestamp(message.timestamp, timestep)
224183

225184
LOGGER.log_messages[timestep] = []
226185
LOGGER.debug_messages[timestep] = []
227186
LOGGER.info(None, f"Analyze message of timestamp {timestamp}", False)
228187

229-
with MANAGER.Lock():
230-
if timestamp in TIMESTAMP_ANALYZED:
231-
LOGGER.error(timestep, f"Timestamp already exists")
232-
TIMESTAMP_ANALYZED.append(timestamp)
233-
234188
# Check common rules
235189
getattr(rule_checker, "is_valid")(
236-
message, VALIDATION_RULES.get_rules().get_type(data_type)
190+
linked_proto_field.LinkedProtoField(message, name=data_type),
191+
VALIDATION_RULES.get_rules().get_type(data_type),
237192
)
238193

239194
LOGS.extend(LOGGER.log_messages[timestep])
240195

241196

242-
def get_message_count(data, data_type="SensorView", from_message=0, to_message=None):
243-
# Wrapper function for external use in combination with process_timestep
244-
timesteps = None
245-
246-
if from_message != 0:
247-
print("Currently only validation from the first frame (0) is supported!")
248-
249-
if to_message is not None:
250-
timesteps = int(to_message)
251-
252-
# Read data
253-
print("Reading data ...")
254-
DATA = osi_trace.OSITrace(buffer_size=1000000)
255-
DATA.from_file(path=data, type_name=data_type, max_index=timesteps)
256-
257-
if DATA.timestep_count < timesteps:
258-
timesteps = -1
259-
260-
# Collect Validation Rules
261-
print("Collect validation rules ...")
262-
try:
263-
VALIDATION_RULES.from_yaml_directory("osi-validation/rules/")
264-
except Exception as e:
265-
print("Error collecting validation rules:", e)
266-
267-
# Pass all timesteps or the number specified
268-
if timesteps != -1:
269-
max_timestep = timesteps
270-
LOGGER.info(None, f"Pass the {max_timestep} first timesteps")
271-
else:
272-
LOGGER.info(None, "Pass all timesteps")
273-
max_timestep = DATA.timestep_count
274-
275-
# Dividing in several blast to not overload the memory
276-
max_timestep_blast = 0
277-
278-
while max_timestep_blast < max_timestep:
279-
# Clear log queue
280-
LOGS[:] = []
281-
282-
# Increment the max-timestep to analyze
283-
max_timestep_blast += 500
284-
first_of_blast = max_timestep_blast - 500
285-
last_of_blast = min(max_timestep_blast, max_timestep)
286-
287-
# Cache messages
288-
DATA.cache_messages_in_index_range(first_of_blast, last_of_blast)
289-
MESSAGE_CACHE.update(DATA.message_cache)
290-
291-
DATA.trace_file.close()
292-
293-
return len(MESSAGE_CACHE)
294-
295-
296197
# Synthetize Logs
297198
def display_results():
298199
return LOGGER.synthetize_results(LOGS)

0 commit comments

Comments
 (0)