diff --git a/fog_x/trajectory.py b/fog_x/trajectory.py index 6545f99..e30d8ce 100644 --- a/fog_x/trajectory.py +++ b/fog_x/trajectory.py @@ -10,6 +10,7 @@ import h5py import asyncio from concurrent.futures import ThreadPoolExecutor +import sys logger = logging.getLogger(__name__) @@ -361,7 +362,7 @@ def _load_from_cache(self): def _load_from_container(self): """ - Load the container file with the entire VLA trajectory. + Load the container file with the entire VLA trajectory using multi-processing for image streams. args: save_to_cache: save the decoded data to the cache file @@ -372,8 +373,11 @@ def _load_from_container(self): Workflow: - Get schema of the container file. - Preallocate decoded streams. - - Decode frame by frame and store in the preallocated memory. + - Use multi-processing to decode image streams separately. + - Decode non-image streams in the main process. + - Combine results from all processes. """ + import multiprocessing as mp def _get_length_of_stream(container, stream): """ @@ -385,6 +389,25 @@ def _get_length_of_stream(container, stream): length += 1 return length + def process_image_stream(stream, feature_name, feature_type, length, path, result_queue): + container = av.open(path, mode="r", format="matroska") + np_cache = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype) + feature_length = 0 + + for packet in container.demux([stream]): + frames = packet.decode() + for frame in frames: + if feature_type.dtype == "float32": + data = frame.to_ndarray(format="gray").reshape(feature_type.shape) + else: + data = frame.to_ndarray(format="rgb24").reshape(feature_type.shape) + np_cache[feature_length] = data + feature_length += 1 + + container.close() + result_queue.put((feature_name, np_cache[:feature_length])) + os._exit(0) + try: container_to_get_length = av.open(self.path, mode="r", format="matroska") except Exception as e: @@ -398,11 +421,12 @@ def _get_length_of_stream(container, stream): container = av.open(self.path, mode="r", format="matroska") streams = container.streams - # Dictionary to store preallocated numpy arrays np_cache = {} - # Preallocate memory for the streams in numpy arrays + # Prepare for multi-processing + image_streams = [] + other_streams = [] for stream in streams: feature_name = stream.metadata.get("FEATURE_NAME") if feature_name is None: @@ -412,54 +436,50 @@ def _get_length_of_stream(container, stream): self.feature_name_to_stream[feature_name] = stream self.feature_name_to_feature_type[feature_name] = feature_type - logger.debug( - f"Creating a cache for {feature_name} with shape {feature_type.shape}" - ) - - # Allocate numpy array with shape [None, X, Y, Z] where X, Y, Z are feature dimensions - if feature_type.dtype == "string": - np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=object) + if stream.codec_context.codec.name == "h264": + image_streams.append((stream, feature_name, feature_type)) else: - np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype) + other_streams.append((stream, feature_name, feature_type)) + if feature_type.dtype == "string": + np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=object) + else: + np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype) + + # Process image streams with multi-processing + result_queue = mp.Queue() + processes = [] + for stream, feature_name, feature_type in image_streams: + p = mp.Process(target=process_image_stream, args=(stream, feature_name, feature_type, length, self.path, result_queue)) + processes.append(p) + p.start() + - # Decode the frames and store them in the preallocated numpy memory - d_feature_length = {feature: 0 for feature in self.feature_name_to_stream} - for packet in container.demux(list(streams)): + # Process other streams in the main process + d_feature_length = {feature: 0 for feature, _, _ in other_streams} + for packet in container.demux([stream for stream, _, _ in other_streams]): feature_name = packet.stream.metadata.get("FEATURE_NAME") if feature_name is None: logger.debug(f"Skipping stream without FEATURE_NAME: {packet.stream}") continue feature_type = FeatureType.from_str(packet.stream.metadata.get("FEATURE_TYPE")) - logger.debug( - f"Decoding {feature_name} with shape {feature_type.shape} and dtype {feature_type.dtype} with time {packet.dts}" - ) - - feature_codec = packet.stream.codec_context.codec.name - if feature_codec == "h264": - frames = packet.decode() - for frame in frames: - if feature_type.dtype == "float32": - data = frame.to_ndarray(format="gray").reshape(feature_type.shape) - else: - data = frame.to_ndarray(format="rgb24").reshape(feature_type.shape) - - # Append data to the numpy array - np_cache[feature_name][d_feature_length[feature_name]] = data - d_feature_length[feature_name] += 1 + packet_in_bytes = bytes(packet) + if packet_in_bytes: + data = pickle.loads(packet_in_bytes) + np_cache[feature_name][d_feature_length[packet.stream]] = data + d_feature_length[packet.stream] += 1 else: - packet_in_bytes = bytes(packet) - if packet_in_bytes: - # Decode the packet - data = pickle.loads(packet_in_bytes) - - # Append data to the numpy array - np_cache[feature_name][d_feature_length[feature_name]] = data - d_feature_length[feature_name] += 1 - else: - logger.debug(f"Skipping empty packet: {packet} for {feature_name}") - logger.debug(f"Length of the stream {feature_name} is {d_feature_length[feature_name]}") + logger.debug(f"Skipping empty packet: {packet} for {feature_name}") container.close() + # Wait for all image processing to complete + # busy join here + for p in processes: + p.join() + + # Collect results from image processing + while not result_queue.empty(): + feature_name, data = result_queue.get() + np_cache[feature_name] = data return np_cache