1
1
import os
2
2
import os .path
3
- from typing import List , Tuple
3
+ from typing import List , Optional , Tuple
4
4
5
5
import matplotlib
6
6
from matplotlib import pyplot as plt
@@ -23,6 +23,7 @@ def generate_simple_trajectories(
23
23
traj_path : str ,
24
24
colors : List [str ],
25
25
size : Tuple [int , int ],
26
+ weights : Optional [List [int ]] = None ,
26
27
):
27
28
28
29
# Set the seed
@@ -32,7 +33,7 @@ def generate_simple_trajectories(
32
33
avg_distances = []
33
34
for i in range (num_transitions ):
34
35
data , avg_distance , distances = generate_transition (
35
- number_pairs , circle_radius , size , colors
36
+ number_pairs , circle_radius , size , colors , weights
36
37
)
37
38
obs_list .append (data )
38
39
infos_list .append ({"distances" : distances })
@@ -68,7 +69,13 @@ def generate_transition(
68
69
circle_radius : float ,
69
70
size : Tuple [float , float ],
70
71
colors : List [str ],
72
+ weights : Optional [List [int ]],
71
73
):
74
+ if weights is not None :
75
+ if len (weights ) < number_pairs :
76
+ raise ValueError ("Not every pair has a weight" )
77
+ norm = sum (weights [:number_pairs ])
78
+
72
79
if number_pairs > len (colors ):
73
80
raise ValueError ("Not enough colors for the number of pairs" )
74
81
@@ -112,7 +119,13 @@ def random_coordinate():
112
119
data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
113
120
plt .close ()
114
121
115
- avg_distance = np .mean (distances )
122
+ if weights is None :
123
+ avg_distance = np .mean (distances )
124
+ else :
125
+ weighted_distances = [
126
+ dist * weight for (dist , weight ) in zip (distances , weights )
127
+ ]
128
+ avg_distance = sum (weighted_distances ) / norm
116
129
return data , avg_distance , distances
117
130
118
131
0 commit comments