Skip to content

Commit 1de178f

Browse files
committed
Weighted!
1 parent 9520859 commit 1de178f

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

src/reward_preprocessing/scripts/gen_dots_and_dists.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import os.path
3-
from typing import List, Tuple
3+
from typing import List, Optional, Tuple
44

55
import matplotlib
66
from matplotlib import pyplot as plt
@@ -23,6 +23,7 @@ def generate_simple_trajectories(
2323
traj_path: str,
2424
colors: List[str],
2525
size: Tuple[int, int],
26+
weights: Optional[List[int]] = None,
2627
):
2728

2829
# Set the seed
@@ -32,7 +33,7 @@ def generate_simple_trajectories(
3233
avg_distances = []
3334
for i in range(num_transitions):
3435
data, avg_distance, distances = generate_transition(
35-
number_pairs, circle_radius, size, colors
36+
number_pairs, circle_radius, size, colors, weights
3637
)
3738
obs_list.append(data)
3839
infos_list.append({"distances": distances})
@@ -68,7 +69,13 @@ def generate_transition(
6869
circle_radius: float,
6970
size: Tuple[float, float],
7071
colors: List[str],
72+
weights: Optional[List[int]],
7173
):
74+
if weights is not None:
75+
if len(weights) < number_pairs:
76+
raise ValueError("Not every pair has a weight")
77+
norm = sum(weights[:number_pairs])
78+
7279
if number_pairs > len(colors):
7380
raise ValueError("Not enough colors for the number of pairs")
7481

@@ -112,7 +119,13 @@ def random_coordinate():
112119
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
113120
plt.close()
114121

115-
avg_distance = np.mean(distances)
122+
if weights is None:
123+
avg_distance = np.mean(distances)
124+
else:
125+
weighted_distances = [
126+
dist * weight for (dist, weight) in zip(distances, weights)
127+
]
128+
avg_distance = sum(weighted_distances) / norm
116129
return data, avg_distance, distances
117130

118131

0 commit comments

Comments
 (0)