17
17
"""
18
18
19
19
import plotly .graph_objects as go
20
+ import plotly .express as px
21
+ from datetime import datetime
20
22
21
23
from .utils import plot_with_error , plot_line , finalize_plot , colors
22
24
28
30
import numpy as np
29
31
from scipy .stats import norm
30
32
33
+ from nyx_space .time import Epoch
31
34
32
35
def plot_estimates (
33
36
dfs ,
@@ -247,11 +250,11 @@ def plot_estimates(
247
250
248
251
if msr_df is not None :
249
252
# Plot the measurements on both plots
250
- pos_fig = plot_measurements (
253
+ pos_fig = overlay_measurements (
251
254
msr_df , title , time_col_name , fig = pos_fig , show = False
252
255
)
253
256
254
- vel_fig = plot_measurements (
257
+ vel_fig = overlay_measurements (
255
258
msr_df , title , time_col_name , fig = vel_fig , show = False
256
259
)
257
260
@@ -454,11 +457,11 @@ def plot_covar(
454
457
455
458
if msr_df is not None :
456
459
# Plot the measurements on both plots
457
- pos_fig = plot_measurements (
460
+ pos_fig = overlay_measurements (
458
461
msr_df , title , time_col_name , fig = pos_fig , show = False
459
462
)
460
463
461
- vel_fig = plot_measurements (
464
+ vel_fig = overlay_measurements (
462
465
msr_df , title , time_col_name , fig = vel_fig , show = False
463
466
)
464
467
@@ -481,15 +484,19 @@ def plot_covar(
481
484
return pos_fig , vel_fig
482
485
483
486
484
- def plot_measurements (
487
+ def overlay_measurements (
488
+ fig ,
485
489
dfs ,
486
490
title ,
487
491
time_col_name = "Epoch:Gregorian UTC" ,
488
492
html_out = None ,
489
493
copyright = None ,
490
- fig = None ,
491
494
show = True ,
492
495
):
496
+ """
497
+ Given a plotly figure, overlay the measurements as shaded regions on top of the existing plot.
498
+ For a plot of measurements only, use `plot_measurements`.
499
+ """
493
500
if not isinstance (dfs , list ):
494
501
dfs = [dfs ]
495
502
@@ -571,7 +578,7 @@ def plot_measurements(
571
578
line_width = 0 ,
572
579
)
573
580
574
- finalize_plot (fig , title , x_title , copyright , show )
581
+ finalize_plot (fig , title , x_title , None , copyright )
575
582
576
583
if html_out :
577
584
with open (html_out , "w" ) as f :
@@ -673,7 +680,7 @@ def plot_residuals(
673
680
674
681
if msr_df is not None :
675
682
# Plot the measurements on both plots
676
- fig = plot_measurements (
683
+ fig = overlay_measurements (
677
684
msr_df , title , time_col_name , fig = fig , show = False
678
685
)
679
686
@@ -744,3 +751,65 @@ def plot_residual_histogram(
744
751
745
752
if show :
746
753
fig .show ()
754
+
755
+ def plot_measurements (
756
+ df ,
757
+ msr_type ,
758
+ title = None ,
759
+ time_col_name = "Epoch:Gregorian UTC" ,
760
+ html_out = None ,
761
+ copyright = None ,
762
+ show = True ,
763
+ ):
764
+ """
765
+ Plot the provided measurement type, fuzzy matching of the column name
766
+ """
767
+
768
+ msr_col_name = [col for col in df .columns if msr_type in col .lower ()]
769
+
770
+ if title is None :
771
+ # Build a title
772
+ station_names = ", " .join ([name for name in df ["Tracking device" ].unique ()])
773
+ start = Epoch (df ["Epoch:Gregorian UTC" ].iloc [0 ])
774
+ end = Epoch (df ["Epoch:Gregorian UTC" ].iloc [- 1 ])
775
+ arc_duration = end .timedelta (start )
776
+ title = f"Measurements from { station_names } spanning { start } to { end } ({ arc_duration } )"
777
+
778
+ try :
779
+ orig_tim_col = df [time_col_name ]
780
+ except KeyError :
781
+ # Find the time column
782
+ try :
783
+ col_name = [x for x in df .columns if x .startswith ("Epoch" )][0 ]
784
+ except IndexError :
785
+ raise KeyError ("Could not find any Epoch column" )
786
+ print (f"Could not find time column { time_col_name } , using `{ col_name } `" )
787
+ orig_tim_col = df [col_name ]
788
+
789
+ # Build a Python datetime column
790
+ pd_ok_epochs = []
791
+ for epoch in orig_tim_col :
792
+ epoch = epoch .replace ("UTC" , "" ).strip ()
793
+ if "." not in epoch :
794
+ epoch += ".0"
795
+ pd_ok_epochs += [datetime .fromisoformat (str (epoch ).replace ("UTC" , "" ).strip ())]
796
+ df ["time_col" ] = pd .Series (pd_ok_epochs )
797
+ x_title = "Epoch {}" .format (time_col_name [- 3 :])
798
+
799
+ fig = px .scatter (df , x = "time_col" , y = msr_col_name , color = "Tracking device" )
800
+
801
+ finalize_plot (fig , title , x_title , msr_col_name [0 ], copyright )
802
+
803
+ if html_out :
804
+ with open (html_out , "w" ) as f :
805
+ f .write (fig .to_html ())
806
+ print (f"Saved HTML to { html_out } " )
807
+
808
+ if show :
809
+ fig .show ()
810
+ else :
811
+ return fig
812
+
813
+ if __name__ == "__main__" :
814
+ df = pd .read_parquet ("output_data/msr-2023-11-25T06-14-01.parquet" )
815
+ plot_measurements (df , "range" )
0 commit comments