17
17
"""
18
18
19
19
import plotly .graph_objects as go
20
+ import plotly .express as px
21
+ from datetime import datetime
20
22
21
23
from .utils import plot_with_error , plot_line , finalize_plot , colors
22
24
28
30
import numpy as np
29
31
from scipy .stats import norm
30
32
33
+ from nyx_space .time import Epoch
31
34
32
35
def plot_estimates (
33
36
dfs ,
@@ -93,8 +96,8 @@ def plot_estimates(
93
96
epoch = epoch .replace ("UTC" , "" ).strip ()
94
97
if "." not in epoch :
95
98
epoch += ".0"
96
- pd_ok_epochs += [epoch ]
97
- time_col = pd .to_datetime (pd_ok_epochs )
99
+ pd_ok_epochs += [datetime . fromisoformat ( str ( epoch ). replace ( "UTC" , "" ). strip ()) ]
100
+ time_col = pd .Series (pd_ok_epochs )
98
101
x_title = "Epoch {}" .format (time_col_name [- 3 :])
99
102
100
103
# Check that the requested covariance frame exists
@@ -247,12 +250,12 @@ def plot_estimates(
247
250
248
251
if msr_df is not None :
249
252
# Plot the measurements on both plots
250
- pos_fig = plot_measurements (
251
- msr_df , title , time_col_name , fig = pos_fig , show = False
253
+ pos_fig = overlay_measurements (
254
+ pos_fig , msr_df , title , time_col_name , show = False
252
255
)
253
256
254
- vel_fig = plot_measurements (
255
- msr_df , title , time_col_name , fig = vel_fig , show = False
257
+ vel_fig = overlay_measurements (
258
+ vel_fig , msr_df , title , time_col_name , show = False
256
259
)
257
260
258
261
if html_out :
@@ -333,8 +336,8 @@ def plot_covar(
333
336
epoch = epoch .replace ("UTC" , "" ).strip ()
334
337
if "." not in epoch :
335
338
epoch += ".0"
336
- pd_ok_epochs += [epoch ]
337
- time_col = pd .to_datetime (pd_ok_epochs )
339
+ pd_ok_epochs += [datetime . fromisoformat ( str ( epoch ). replace ( "UTC" , "" ). strip ()) ]
340
+ time_col = pd .Series (pd_ok_epochs )
338
341
x_title = "Epoch {}" .format (time_col_name [- 3 :])
339
342
340
343
# Check that the requested covariance frame exists
@@ -454,12 +457,12 @@ def plot_covar(
454
457
455
458
if msr_df is not None :
456
459
# Plot the measurements on both plots
457
- pos_fig = plot_measurements (
458
- msr_df , title , time_col_name , fig = pos_fig , show = False
460
+ pos_fig = overlay_measurements (
461
+ pos_fig , msr_df , title , time_col_name , show = False
459
462
)
460
463
461
- vel_fig = plot_measurements (
462
- msr_df , title , time_col_name , fig = vel_fig , show = False
464
+ vel_fig = overlay_measurements (
465
+ vel_fig , msr_df , title , time_col_name , show = False
463
466
)
464
467
465
468
if html_out :
@@ -481,21 +484,22 @@ def plot_covar(
481
484
return pos_fig , vel_fig
482
485
483
486
484
- def plot_measurements (
487
+ def overlay_measurements (
488
+ fig ,
485
489
dfs ,
486
490
title ,
487
491
time_col_name = "Epoch:Gregorian UTC" ,
488
492
html_out = None ,
489
493
copyright = None ,
490
- fig = None ,
491
494
show = True ,
492
495
):
496
+ """
497
+ Given a plotly figure, overlay the measurements as shaded regions on top of the existing plot.
498
+ For a plot of measurements only, use `plot_measurements`.
499
+ """
493
500
if not isinstance (dfs , list ):
494
501
dfs = [dfs ]
495
502
496
- if fig is None :
497
- fig = go .Figure ()
498
-
499
503
color_values = list (colors .values ())
500
504
501
505
station_colors = {}
@@ -518,8 +522,8 @@ def plot_measurements(
518
522
epoch = epoch .replace ("UTC" , "" ).strip ()
519
523
if "." not in epoch :
520
524
epoch += ".0"
521
- pd_ok_epochs += [epoch ]
522
- time_col = pd .to_datetime (pd_ok_epochs )
525
+ pd_ok_epochs += [datetime . fromisoformat ( str ( epoch ). replace ( "UTC" , "" ). strip ()) ]
526
+ time_col = pd .Series (pd_ok_epochs )
523
527
x_title = "Epoch {}" .format (time_col_name [- 3 :])
524
528
525
529
# Diff the epochs of the measurements to find when there is a start and end.
@@ -571,7 +575,7 @@ def plot_measurements(
571
575
line_width = 0 ,
572
576
)
573
577
574
- finalize_plot (fig , title , x_title , copyright , show )
578
+ finalize_plot (fig , title , x_title , None , copyright )
575
579
576
580
if html_out :
577
581
with open (html_out , "w" ) as f :
@@ -595,7 +599,7 @@ def plot_residuals(
595
599
show = True ,
596
600
):
597
601
"""
598
- Plot of residuals, with 3-σ lines
602
+ Plot of residuals, with 3-σ lines. Returns a tuple of the plots if show=False.
599
603
"""
600
604
601
605
try :
@@ -615,12 +619,14 @@ def plot_residuals(
615
619
epoch = epoch .replace ("UTC" , "" ).strip ()
616
620
if "." not in epoch :
617
621
epoch += ".0"
618
- pd_ok_epochs += [epoch ]
619
- time_col = pd .to_datetime (pd_ok_epochs )
622
+ pd_ok_epochs += [datetime . fromisoformat ( str ( epoch ). replace ( "UTC" , "" ). strip ()) ]
623
+ time_col = pd .Series (pd_ok_epochs )
620
624
x_title = "Epoch {}" .format (time_col_name [- 3 :])
621
625
622
626
plt_any = False
623
627
628
+ rtn_plots = []
629
+
624
630
for col in df .columns :
625
631
if col .startswith (kind ):
626
632
fig = go .Figure ()
@@ -671,8 +677,8 @@ def plot_residuals(
671
677
672
678
if msr_df is not None :
673
679
# Plot the measurements on both plots
674
- fig = plot_measurements (
675
- msr_df , title , time_col_name , fig = fig , show = False
680
+ fig = overlay_measurements (
681
+ fig , msr_df , title , time_col_name , show = False
676
682
)
677
683
678
684
finalize_plot (
@@ -689,10 +695,15 @@ def plot_residuals(
689
695
690
696
if show :
691
697
fig .show ()
698
+ else :
699
+ rtn_plots += [fig ]
692
700
693
701
if not plt_any :
694
702
raise ValueError (f"No columns ending with { kind } found -- nothing plotted" )
695
703
704
+ if not show :
705
+ return rtn_plots
706
+
696
707
697
708
def plot_residual_histogram (
698
709
df , title , kind = "Prefit" , copyright = None , html_out = None , show = True
@@ -737,3 +748,64 @@ def plot_residual_histogram(
737
748
738
749
if show :
739
750
fig .show ()
751
+
752
+ def plot_measurements (
753
+ df ,
754
+ msr_type = None ,
755
+ title = None ,
756
+ time_col_name = "Epoch:Gregorian UTC" ,
757
+ html_out = None ,
758
+ copyright = None ,
759
+ show = True ,
760
+ ):
761
+ """
762
+ Plot the provided measurement type, fuzzy matching of the column name, or plot all as a strip
763
+ """
764
+
765
+ if title is None :
766
+ # Build a title
767
+ station_names = ", " .join ([name for name in df ["Tracking device" ].unique ()])
768
+ start = Epoch (df ["Epoch:Gregorian UTC" ].iloc [0 ])
769
+ end = Epoch (df ["Epoch:Gregorian UTC" ].iloc [- 1 ])
770
+ arc_duration = end .timedelta (start )
771
+ title = f"Measurements from { station_names } spanning { start } to { end } ({ arc_duration } )"
772
+
773
+ try :
774
+ orig_tim_col = df [time_col_name ]
775
+ except KeyError :
776
+ # Find the time column
777
+ try :
778
+ col_name = [x for x in df .columns if x .startswith ("Epoch" )][0 ]
779
+ except IndexError :
780
+ raise KeyError ("Could not find any Epoch column" )
781
+ print (f"Could not find time column { time_col_name } , using `{ col_name } `" )
782
+ orig_tim_col = df [col_name ]
783
+
784
+ # Build a Python datetime column
785
+ pd_ok_epochs = []
786
+ for epoch in orig_tim_col :
787
+ epoch = epoch .replace ("UTC" , "" ).strip ()
788
+ if "." not in epoch :
789
+ epoch += ".0"
790
+ pd_ok_epochs += [datetime .fromisoformat (str (epoch ).replace ("UTC" , "" ).strip ())]
791
+ df ["time_col" ] = pd .Series (pd_ok_epochs )
792
+ x_title = "Epoch {}" .format (time_col_name [- 3 :])
793
+
794
+ if msr_type is None :
795
+ fig = px .strip (df , x = "time_col" , y = "Tracking device" , color = "Tracking device" )
796
+ finalize_plot (fig , title , x_title , "All tracking data" , copyright )
797
+ else :
798
+ msr_col_name = [col for col in df .columns if msr_type in col .lower ()]
799
+
800
+ fig = px .scatter (df , x = "time_col" , y = msr_col_name , color = "Tracking device" )
801
+ finalize_plot (fig , title , x_title , msr_col_name [0 ], copyright )
802
+
803
+ if html_out :
804
+ with open (html_out , "w" ) as f :
805
+ f .write (fig .to_html ())
806
+ print (f"Saved HTML to { html_out } " )
807
+
808
+ if show :
809
+ fig .show ()
810
+ else :
811
+ return fig
0 commit comments