Skip to content

Commit

Permalink
adapt buffer to dataclass DTFactorsNSeq
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoMF committed Feb 10, 2025
1 parent 54006b5 commit 474d267
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions agimus_controller/agimus_controller/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import numpy.typing as npt
from pinocchio import SE3, Force
from agimus_controller.ocp_param_base import DTFactorsNSeq


@dataclass
Expand Down Expand Up @@ -157,10 +158,10 @@ def __eq__(self, other):
class TrajectoryBuffer(object):
"""List of variable size in which the HPP trajectory nodes will be."""

def __init__(self, dt_factor_n_seq: list[tuple[int, int]]):
def __init__(self, dt_factor_n_seq: DTFactorsNSeq):
self._buffer = []
self.dt_factor_n_seq = deepcopy(dt_factor_n_seq)
self.horizon_indexes = self.compute_horizon_indexes(self.dt_factor_n_seq)
self.horizon_indexes = self.compute_horizon_indexes()

def append(self, item):
self._buffer.append(item)
Expand All @@ -172,14 +173,14 @@ def clear_past(self):
if self._buffer:
self._buffer.pop(0)

def compute_horizon_indexes(self, dt_factor_n_seq: list[tuple[int, int]]):
indexes = [0] * sum(sn for _, sn in dt_factor_n_seq)
def compute_horizon_indexes(self):
indexes = [0] * (sum(sn for sn in self.dt_factor_n_seq.dts) + 1)
i = 0
for factor, sn in dt_factor_n_seq:
for factor, sn in zip(self.dt_factor_n_seq.factors, self.dt_factor_n_seq.dts):
for _ in range(sn):
indexes[i] = 0 if i == 0 else factor + indexes[i - 1]
i += 1

indexes[-1] = indexes[-2] + self.dt_factor_n_seq.factors[-1]
assert indexes[0] == 0, "First time step must be 0"
assert all(t0 <= t1 for t0, t1 in zip(indexes[:-1], indexes[1:])), (
"Time steps must be increasing"
Expand Down

0 comments on commit 474d267

Please sign in to comment.