-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #112 from ArthurH91/ahaffemayer/feature/create-ocp…
…-param-class OCP Croco Class
- Loading branch information
Showing
9 changed files
with
577 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,90 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from agimus_controller.mpc_data import OCPResults, OCPDebugData | ||
from agimus_controller.trajectory import WeightedTrajectoryPoint | ||
|
||
|
||
class OCPBase(ABC): | ||
"""Base class for the Optimal Control Problem (OCP) solver. | ||
This class defines the interface for the OCP solver.""" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def set_reference_horizon( | ||
self, reference_trajectory: list[WeightedTrajectoryPoint] | ||
) -> None: | ||
... | ||
"""Set the reference trajectory and the weights of the costs for the OCP solver. This method should be implemented by the derived class.""" | ||
pass | ||
|
||
@abstractmethod | ||
@property | ||
@abstractmethod | ||
def horizon_size() -> int: | ||
... | ||
"""Returns the horizon size of the OCP. | ||
Returns: | ||
int: size of the horizon. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
@property | ||
def dt() -> int: | ||
... | ||
@abstractmethod | ||
def dt() -> float: | ||
"""Returns the time step of the OCP in seconds. | ||
Returns: | ||
int: time step of the OCP. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def solve( | ||
self, x0: np.ndarray, x_init: list[np.ndarray], u_init: list[np.ndarray] | ||
self, | ||
x0: npt.NDArray[np.float64], | ||
x_warmstart: list[npt.NDArray[np.float64]], | ||
u_warmstart: list[npt.NDArray[np.float64]], | ||
) -> None: | ||
... | ||
"""Solver for the OCP. This method should be implemented by the derived class. | ||
The method should solve the OCP for the given initial state and warmstart values. | ||
Args: | ||
x0 (npt.NDArray[np.float64]): current state of the robot. | ||
x_warmstart (list[npt.NDArray[np.float64]]): Warmstart values for the state. This doesn't include the current state. | ||
u_warmstart (list[npt.NDArray[np.float64]]): Warmstart values for the control inputs. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
@property | ||
@abstractmethod | ||
def ocp_results(self) -> OCPResults: | ||
... | ||
"""Returns the results of the OCP solver. | ||
The solve method should be called before calling this method. | ||
Returns: | ||
OCPResults: Class containing the results of the OCP solver. | ||
""" | ||
pass | ||
|
||
@ocp_results.setter | ||
def ocp_results(self, value: OCPResults) -> None: | ||
"""Set the output data structure of the OCP. | ||
Args: | ||
value (OCPResults): New output data structure of the OCP. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
@property | ||
@abstractmethod | ||
def debug_data(self) -> OCPDebugData: | ||
... | ||
"""Returns the debug data of the OCP solver. | ||
Returns: | ||
OCPDebugData: Class containing the debug data of the OCP solver. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from abc import abstractmethod | ||
|
||
import crocoddyl | ||
import mim_solvers | ||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from agimus_controller.factory.robot_model import RobotModels | ||
from agimus_controller.mpc_data import OCPResults, OCPDebugData | ||
from agimus_controller.ocp_base import OCPBase | ||
from agimus_controller.ocp_param_base import OCPParamsBaseCroco | ||
from agimus_controller.trajectory import TrajectoryPointWeights | ||
|
||
|
||
class OCPBaseCroco(OCPBase): | ||
def __init__( | ||
self, | ||
robot_models: RobotModels, | ||
ocp_params: OCPParamsBaseCroco, | ||
) -> None: | ||
"""Defines common behavior for all OCP using croccodyl. This is an abstract class with some helpers to design OCPs in a more friendly way. | ||
Args: | ||
robot_models (RobotModels): All models of the robot. | ||
ocp_params (OCPParamsBaseCroco): Input data structure of the OCP. | ||
""" | ||
# Setting the robot model | ||
self._robot_models = robot_models | ||
self._collision_model = self._robot_models.collision_model | ||
self._armature = self._robot_models._params.armature | ||
|
||
# Stat and actuation model | ||
self._state = crocoddyl.StateMultibody(self._robot_models.robot_model) | ||
self._actuation = crocoddyl.ActuationModelFull(self._state) | ||
|
||
# Setting the OCP parameters | ||
self._ocp_params = ocp_params | ||
self._ocp = None | ||
self._ocp_results = None | ||
self._running_model_list = [] | ||
self._terminal_model = [] | ||
self._problem = None | ||
|
||
@property | ||
def horizon_size(self) -> int: | ||
"""Number of time steps in the horizon.""" | ||
return self._ocp_params.horizon_size | ||
|
||
@property | ||
def dt(self) -> float: | ||
"""Integration step of the OCP.""" | ||
return self._ocp_params.dt | ||
|
||
@abstractmethod | ||
def create_running_model_list(self) -> list[crocoddyl.ActionModelAbstract]: | ||
"""Create the list of running models.""" | ||
pass | ||
|
||
@abstractmethod | ||
def create_terminal_model(self) -> crocoddyl.ActionModelAbstract: | ||
"""Create the terminal model.""" | ||
pass | ||
|
||
@abstractmethod | ||
def update_crocoddyl_problem( | ||
self, | ||
x0: npt.NDArray[np.float64], | ||
weighted_trajectory_points: list[TrajectoryPointWeights], | ||
) -> None: | ||
"""Update the Crocoddyl problem's references, weights and x0.""" | ||
pass | ||
|
||
def solve( | ||
self, | ||
x0: npt.NDArray[np.float64], | ||
x_warmstart: list[npt.NDArray[np.float64]], | ||
u_warmstart: list[npt.NDArray[np.float64]], | ||
) -> None: | ||
"""Solves the OCP. | ||
The results can be accessed through the ocp_results property. | ||
Args: | ||
x0 (npt.NDArray[np.float64]): Current state of the robot. | ||
x_warmstart (list[npt.NDArray[np.float64]]): Predicted states for the OCP. | ||
u_warmstart (list[npt.NDArray[np.float64]]): Predicted control inputs for the OCP. | ||
""" | ||
### Creation of the state and actuation models | ||
|
||
if self._ocp is None: | ||
# Create the running models | ||
self._running_model_list = self.create_running_model_list() | ||
# Create the terminal model | ||
self._terminal_model = self.create_terminal_model() | ||
# Create the shooting problem | ||
self._problem = crocoddyl.ShootingProblem( | ||
x0, self._running_model_list, self._terminal_model | ||
) | ||
# Create solver + callbacks | ||
self._ocp = mim_solvers.SolverCSQP(self._problem) | ||
|
||
# Merit function | ||
self._ocp.use_filter_line_search = self._ocp_params.use_filter_line_search | ||
|
||
# Parameters of the solver | ||
self._ocp.termination_tolerance = self._ocp_params.termination_tolerance | ||
self._ocp.max_qp_iters = self._ocp_params.qp_iters | ||
self._ocp.eps_abs = self._ocp_params.eps_abs | ||
self._ocp.eps_rel = self._ocp_params.eps_rel | ||
self._ocp.with_callbacks = self._ocp_params.callbacks | ||
|
||
# Creating the warmstart lists for the solver | ||
# Solve the OCP | ||
self._ocp.solve([x0] + x_warmstart, u_warmstart, self._ocp_params.solver_iters) | ||
|
||
# Store the results | ||
self._ocp_results = OCPResults( | ||
states=self._ocp.xs, | ||
ricatti_gains=self._ocp.K, | ||
feed_forward_terms=self._ocp.us, | ||
) | ||
|
||
@property | ||
def ocp_results(self) -> OCPResults: | ||
"""Output data structure of the OCP. | ||
Returns: | ||
OCPResults: Output data structure of the OCP. It contains the states, Ricatti gains, and feed-forward terms. | ||
""" | ||
return self._ocp_results | ||
|
||
@ocp_results.setter | ||
def ocp_results(self, value: OCPResults) -> None: | ||
"""Set the output data structure of the OCP. | ||
Args: | ||
value (OCPResults): New output data structure of the OCP. | ||
""" | ||
self._ocp_results = value | ||
|
||
@property | ||
def debug_data(self) -> OCPDebugData: | ||
pass | ||
|
||
@debug_data.setter | ||
def debug_data(self, value: OCPDebugData) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
|
||
|
||
@dataclass | ||
class OCPParamsBaseCroco: | ||
"""Input data structure of the OCP.""" | ||
|
||
# Data relevant to solve the OCP | ||
dt: float # Integration step of the OCP | ||
horizon_size: int # Number of time steps in the horizon | ||
solver_iters: int # Number of solver iterations | ||
qp_iters: int = 200 # Number of QP iterations (must be a multiple of 25). | ||
termination_tolerance: float = ( | ||
1e-3 # Termination tolerance (norm of the KKT conditions) | ||
) | ||
eps_abs: float = 1e-6 # Absolute tolerance of the solver | ||
eps_rel: float = 0.0 # Relative tolerance of the solver | ||
callbacks: bool = False # Flag to enable/disable callbacks | ||
use_filter_line_search = False # Flag to enable/disable the filter line searchs |
Binary file not shown.
Oops, something went wrong.