From 392e8e856401836334a670a11a320abbbd94ef57 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 10 Oct 2024 12:58:20 +0200 Subject: [PATCH 01/10] Allow passing additional kwargs when contact models are used --- src/jaxsim/api/contact.py | 10 +++++++++- src/jaxsim/api/model.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 95655be0f..1e8651814 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -98,6 +98,7 @@ def collidable_point_forces( data: js.data.JaxSimModelData, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + **kwargs, ) -> jtp.Matrix: """ Compute the 6D forces applied to each collidable point. @@ -110,6 +111,7 @@ def collidable_point_forces( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model. Returns: The 6D forces applied to each collidable point expressed in the frame @@ -121,6 +123,7 @@ def collidable_point_forces( data=data, link_forces=link_forces, joint_force_references=joint_force_references, + **kwargs, ) return f_Ci @@ -132,6 +135,7 @@ def collidable_point_dynamics( data: js.data.JaxSimModelData, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + **kwargs, ) -> tuple[jtp.Matrix, dict[str, jtp.Array]]: r""" Compute the 6D force applied to each collidable point. @@ -144,6 +148,7 @@ def collidable_point_dynamics( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model. Returns: The 6D force applied to each collidable point and additional data based @@ -169,7 +174,7 @@ def collidable_point_dynamics( # Note that the material deformation rate is always returned in the mixed frame # C[W] = (W_p_C, [W]). This is convenient for integration purpose. W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces( - model=model, data=data + model=model, data=data, **kwargs ) # Create the dictionary of auxiliary data. @@ -187,6 +192,7 @@ def collidable_point_dynamics( data=data, link_forces=link_forces, joint_force_references=joint_force_references, + **kwargs, ) aux_data = dict() @@ -201,6 +207,7 @@ def collidable_point_dynamics( data=data, link_forces=link_forces, joint_force_references=joint_force_references, + **kwargs, ) aux_data = dict() @@ -226,6 +233,7 @@ def collidable_point_dynamics( dt=None, # TODO link_forces=link_forces, joint_force_references=joint_force_references, + **kwargs, ) aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 0c9e63f62..414d7aabf 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1772,6 +1772,7 @@ def link_contact_forces( data: js.data.JaxSimModelData, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + **kwargs, ) -> jtp.Matrix: """ Compute the 6D contact forces of all links of the model. @@ -1784,6 +1785,7 @@ def link_contact_forces( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model.. Returns: A `(nL, 6)` array containing the stacked 6D contact forces of the links, @@ -1831,6 +1833,7 @@ def link_contact_forces( data=data, link_forces=input_references.link_forces(), joint_force_references=input_references.joint_force_references(), + **kwargs, ) # Construct the vector defining the parent link index of each collidable point. From f7aa5426225fee371fe02ea9d06d693ff0e75957 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 14 Oct 2024 15:17:57 +0200 Subject: [PATCH 02/10] Update how contact forces are converted to equivalent link forces --- src/jaxsim/api/contact.py | 93 +++++--------- src/jaxsim/api/model.py | 54 +++------ src/jaxsim/api/ode.py | 51 ++++---- src/jaxsim/rbda/contacts/common.py | 140 +++++++++++++++++++++- src/jaxsim/rbda/contacts/relaxed_rigid.py | 4 +- src/jaxsim/rbda/contacts/rigid.py | 4 +- src/jaxsim/rbda/contacts/soft.py | 6 +- src/jaxsim/rbda/contacts/visco_elastic.py | 56 +++++---- 8 files changed, 238 insertions(+), 170 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 1e8651814..889aa8740 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -136,7 +136,7 @@ def collidable_point_dynamics( link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, **kwargs, -) -> tuple[jtp.Matrix, dict[str, jtp.Array]]: +) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: r""" Compute the 6D force applied to each collidable point. @@ -163,89 +163,58 @@ def collidable_point_dynamics( Instead, the 6D forces are returned in the active representation. """ - # Build the soft contact model. + # Build the additional kwargs to pass to the computation of the contact forces. match model.contact_model: case contacts.SoftContacts(): - assert isinstance(model.contact_model, contacts.SoftContacts) - - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point, and the corresponding material deformation rate. - # Note that the material deformation rate is always returned in the mixed frame - # C[W] = (W_p_C, [W]). This is convenient for integration purpose. - W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces( - model=model, data=data, **kwargs - ) - # Create the dictionary of auxiliary data. - # This contact model considers the material deformation as additional state - # of the ODE system. We need to pass its dynamics to the integrator. - aux_data = dict(m_dot=CW_ṁ) + kwargs_contact_model = kwargs case contacts.RigidContacts(): - assert isinstance(model.contact_model, contacts.RigidContacts) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, _ = model.contact_model.compute_contact_forces( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - **kwargs, + kwargs_contact_model = ( + dict( + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + | kwargs ) - aux_data = dict() - case contacts.RelaxedRigidContacts(): - assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, _ = model.contact_model.compute_contact_forces( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - **kwargs, + kwargs_contact_model = ( + dict( + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + | kwargs ) - aux_data = dict() - case contacts.ViscoElasticContacts(): - assert isinstance(model.contact_model, contacts.ViscoElasticContacts) - - # It is not yet clear how to pass the time step to this stage. - # A possibility is to restrict the integrator to only forward Euler - # and store the Δt inside the model. - module = jaxsim.rbda.contacts.visco_elastic.step.__module__ - name = jaxsim.rbda.contacts.visco_elastic.step.__name__ - msg = "You need to use the custom '{}.{}' function with this contact model." - jaxsim.exceptions.raise_runtime_error_if( - condition=True, msg=msg.format(module, name) - ) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces( - model=model, - data=data, - dt=None, # TODO - link_forces=link_forces, - joint_force_references=joint_force_references, - **kwargs, + kwargs_contact_model = ( + dict( + dt=model.time_step, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + | kwargs ) - aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf) - case _: - raise ValueError(f"Invalid contact model {model.contact_model}") + raise ValueError(f"Invalid contact model: {model.contact_model}") + + W_f_C, aux_data = model.contact_model.compute_contact_forces( + model=model, + data=data, + **kwargs_contact_model, + ) # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])` # associated to each collidable point. # In inertial-fixed representation, the computation of these transforms # is not necessary and the conversion below becomes a no-op. - W_H_Ci = ( + W_H_C = ( js.contact.transforms(model=model, data=data) if data.velocity_representation is not VelRepr.Inertial else jnp.zeros( @@ -261,7 +230,7 @@ def collidable_point_dynamics( transform=W_H_C, is_force=True, ) - )(W_f_Ci, W_H_Ci) + )(W_f_C, W_H_C) return f_Ci, aux_data diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 414d7aabf..10e26c8aa 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1770,6 +1770,7 @@ def body_to_other_representation( def link_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, + *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, **kwargs, @@ -1822,48 +1823,16 @@ def link_contact_forces( joint_force_references=joint_force_references, ) - # Compute the 6D forces applied to each collidable point expressed in the - # inertial frame. - with ( - data.switch_velocity_representation(VelRepr.Inertial), - input_references.switch_velocity_representation(VelRepr.Inertial), - ): - W_f_C = js.contact.collidable_point_forces( - model=model, - data=data, - link_forces=input_references.link_forces(), - joint_force_references=input_references.joint_force_references(), - **kwargs, - ) - - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - ) - - # Create the mask that associate each collidable point to their parent link. - # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() - ) - - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_C are expressed in the world frame, - # we don't need any coordinate transformation. - W_f_L = mask.T @ W_f_C - - # Create a references object to store the link forces. - references = js.references.JaxSimModelReferences.build( - model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + f_L, _ = model.contact_model.compute_link_contact_forces( + model=model, + data=data, + link_forces=input_references.link_forces(model=model, data=data), + joint_force_references=input_references.joint_force_references(), + **kwargs, ) - # Use the references object to convert the link forces to the velocity - # representation of data. - with references.switch_velocity_representation(data.velocity_representation): - f_L = references.link_forces(model=model, data=data) - return f_L @@ -1970,6 +1939,11 @@ def step( Returns: A tuple containing the new data of the model and the new state of the integrator. + + Note: + In order to reduce the occurrences of frame conversions performed internally, + it is recommended to use inertial-fixed velocity representation. This can be + particularly useful for automatically differentiated logic. """ # Extract the integrator kwargs. diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 5dc0da8bd..56229464a 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -131,7 +131,7 @@ def system_velocity_dynamics( # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) + W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float) # Initialize a dictionary of auxiliary data. # This dictionary is used to store additional data computed by the contact model. @@ -139,66 +139,59 @@ def system_velocity_dynamics( if len(model.kin_dyn_parameters.contact_parameters.body) > 0: - # Note: the following code should be kept in sync with the function - # `jaxsim.api.model.link_contact_forces`. We cannot merge them since - # here we need to get also aux_data. - - # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point - # along with contact-specific auxiliary states. with ( data.switch_velocity_representation(VelRepr.Inertial), references.switch_velocity_representation(VelRepr.Inertial), ): - W_f_Ci, aux_data = js.contact.collidable_point_dynamics( + + # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point + # along with contact-specific auxiliary states. + W_f_C, aux_data = js.contact.collidable_point_dynamics( model=model, data=data, link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), ) - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - ) - - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_Ci are expressed in the world frame, - # we don't need any coordinate transformation. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() - ) - - W_f_Li_terrain = mask.T @ W_f_Ci + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + W_f_L_terrain = model.contact_model.link_forces_from_contact_forces( + model=model, + data=data, + contact_forces=W_f_C, + ) # =========================== # Compute system acceleration # =========================== - # Compute the total link forces + # Compute the total link forces. with ( data.switch_velocity_representation(VelRepr.Inertial), references.switch_velocity_representation(VelRepr.Inertial), ): + + # Sum the contact forces just computed with the link forces applied by the user. references = references.apply_link_forces( model=model, data=data, - forces=W_f_Li_terrain, + forces=W_f_L_terrain, additive=True, ) - # Get the link forces in inertial representation + # Get the link forces in inertial-fixed representation. f_L_total = references.link_forces(model=model, data=data) - v̇_WB, s̈ = system_acceleration( + # Compute the system acceleration in inertial-fixed representation. + # This representation is useful for integration purpose. + W_v̇_WB, s̈ = system_acceleration( model=model, data=data, joint_force_references=joint_force_references, link_forces=f_L_total, ) - return v̇_WB, s̈, aux_data + return W_v̇_WB, s̈, aux_data def system_acceleration( @@ -400,7 +393,7 @@ def system_dynamics( pass case _: - raise ValueError(f"Invalid contact model {model.contact_model}") + raise ValueError(f"Invalid contact model: {model.contact_model}") # Extract the velocities. W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index e6892704a..517ecb483 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -2,7 +2,6 @@ import abc import functools -from typing import Any import jax import jax.numpy as jnp @@ -10,6 +9,7 @@ import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim.api.common import ModelDataWithVelocityRepresentation from jaxsim.utils import JaxsimDataclass try: @@ -131,7 +131,7 @@ def compute_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, **kwargs, - ) -> tuple[jtp.Matrix, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -142,11 +142,145 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed 6D contact force applied to the contact points and expressed in the world frame, and as second element - a tuple of optional additional information. + a dictionary of optional additional information. """ pass + def compute_link_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + **kwargs, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the link contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + + Returns: + A tuple containing as first element the 6D contact force applied to the + links and expressed in the frame of the velocity representation of data, + and as second element a dictionary of optional additional information. + """ + + # Compute the contact forces expressed in the inertial frame. + # This function, contrarily to `compute_contact_forces`, already handles how + # the optional kwargs should be passed to the specific contact models. + W_f_C, aux_dict = js.contact.collidable_point_dynamics( + model=model, data=data, **kwargs + ) + + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + + W_f_L = self.link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C + ) + + # Store the link forces in the references object for easy conversion. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=W_f_L, + velocity_representation=jaxsim.VelRepr.Inertial, + ) + + # Convert the link forces to the frame corresponding to the velocity + # representation of data. + with references.switch_velocity_representation(data.velocity_representation): + f_L = references.link_forces(model=model, data=data) + + return f_L, aux_dict + + @staticmethod + def link_forces_from_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + contact_forces: jtp.MatrixLike, + ) -> jtp.Matrix: + """ + Compute the link forces from the contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + contact_forces: The contact forces computed by the contact model. + + Returns: + The 6D contact forces applied to the links and expressed in the frame of + the velocity representation of data. + """ + + # Convert the contact forces to a JAX array. + f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) + + # Get the pose of the enabled collidable points. + W_H_C = js.contact.transforms(model=model, data=data) + + # Convert the contact forces to inertial-fixed representation. + W_f_C = jax.vmap( + lambda f_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=f_C, + other_representation=data.velocity_representation, + transform=W_H_C, + is_force=True, + ) + ) + )(f_C, W_H_C) + + # Get the object storing the contact parameters of the model. + contact_parameters = model.kin_dyn_parameters.contact_parameters + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + contact_parameters.indices_of_enabled_collidable_points + ) + + # Construct the vector defining the parent link index of each collidable point. + # We use this vector to sum the 6D forces of all collidable points rigidly + # attached to the same link. + parent_link_index_of_collidable_points = jnp.array( + contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + # Create the mask that associate each collidable point to their parent link. + # We use this mask to sum the collidable points to the right link. + mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + model.number_of_links() + ) + + # Sum the forces of all collidable points rigidly attached to a body. + # Since the contact forces W_f_C are expressed in the world frame, + # we don't need any coordinate transformation. + W_f_L = mask.T @ W_f_C + + # Compute the link transforms. + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Inertial + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) + ) + + # Convert the inertial-fixed link forces to the velocity representation of data. + f_L = jax.vmap( + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=data.velocity_representation, + transform=W_H_L, + is_force=True, + ) + ) + )(W_f_L, W_H_L) + + return f_L + @classmethod def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: """ diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 25fb35b5e..df913dab9 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -238,7 +238,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -458,7 +458,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: ), )(CW_fl_C, W_H_C) - return W_f_C, () + return W_f_C, {} @staticmethod def _regularizers( diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index bdbbb2937..99078ea6d 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -242,7 +242,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -402,7 +402,7 @@ def compute_contact_forces( ), )(CW_fl_C, W_H_C) - return W_f_C, () + return W_f_C, {} @staticmethod def _delassus_matrix( diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 4af693527..41d744453 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -423,7 +423,7 @@ def compute_contact_forces( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, - ) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -433,7 +433,7 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed contact forces, and as - second element the derivative of the material deformation. + second element a dictionary with derivative of the material deformation. """ # Initialize the model and data this contact model is operating on. @@ -460,4 +460,4 @@ def compute_contact_forces( ) )(W_p_C, W_ṗ_C, m) - return W_f, (ṁ,) + return W_f, dict(m_dot=ṁ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 25019fcfc..184a7bb4e 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -266,7 +266,7 @@ def compute_contact_forces( dt: jtp.FloatLike | None = None, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -291,7 +291,7 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame, and as second element - a tuple of optional additional information. + a dictionary of optional additional information. """ # Initialize the model and data this contact model is operating on. @@ -347,7 +347,7 @@ def compute_contact_forces( lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C) )(jnp.stack([CW_f̅l, CW_fl̿])) - return W_f̅_C, (W_f̿_C, m_tf) + return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf) @staticmethod @functools.partial(jax.jit, static_argnames=("max_squarings",)) @@ -973,7 +973,7 @@ def step( dt = dt if dt is not None else model.time_step # Compute the contact forces with the exponential integrator. - W_f̅_C, (W_f̿_C, m_tf) = model.contact_model.compute_contact_forces( + W_f̅_C, aux_data = model.contact_model.compute_contact_forces( model=model, data=data, dt=jnp.array(dt).astype(float), @@ -981,39 +981,31 @@ def step( joint_force_references=joint_force_references, ) + # Extract the final material deformation and the average of average forces + # from the dictionary containing auxiliary data. + m_tf = aux_data["m_tf"] + W_f̿_C = aux_data["W_f_avg2_C"] + # =============================== # Compute the link contact forces # =============================== - # Extract the indices corresponding to the enabled collidable points. - # The visco-elastic contact model computed only their contact forces. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) + # Get the link contact forces by summing the forces of contact points belonging + # to the same link. + W_f̅_L, W_f̿_L = jax.vmap( + lambda W_f_C: model.contact_model.link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C + ) + )(jnp.stack([W_f̅_C, W_f̿_C])) # Compute the link transforms. - W_H_L = js.model.forward_kinematics(model=model, data=data) - - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - # Create the mask that associate each collidable point to their parent link. - # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Mixed + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_C are expressed in the world frame, - # we don't need any coordinate transformation. - W_f̅_L = mask.T @ W_f̅_C - W_f̿_L = mask.T @ W_f̿_C - - # For integration purpose, we need these average of averages expressed in + # For integration purpose, we need the average of average forces expressed in # mixed representation. LW_f̿_L = jax.vmap( lambda W_f_L, W_H_L: data.inertial_to_other_representation( @@ -1046,6 +1038,12 @@ def step( # be much more accurate than the one computed with the discrete soft contacts. with data_tf.mutable_context(): + # Extract the indices corresponding to the enabled collidable points. + # The visco-elastic contact model computed only their contact forces. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + data_tf.state.extended |= { "tangential_deformation": data_tf.state.extended["tangential_deformation"] .at[indices_of_enabled_collidable_points] From 271b76c1aeb25f979f75a3e3c7767b0200d54bf4 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 11:01:47 +0200 Subject: [PATCH 03/10] Minor updates to contact models --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 59 ++++++++++++++++++----- src/jaxsim/rbda/contacts/rigid.py | 31 +++++++++--- src/jaxsim/rbda/contacts/soft.py | 8 ++- src/jaxsim/rbda/contacts/visco_elastic.py | 58 +++++++++++++++------- 4 files changed, 117 insertions(+), 39 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index df913dab9..ee58790d6 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -120,19 +120,44 @@ def default(name: str): return cls( time_constant=jnp.array( - time_constant or default("time_constant"), dtype=float + ( + time_constant + if time_constant is not None + else default("time_constant") + ), + dtype=float, ), damping_coefficient=jnp.array( - damping_coefficient or default("damping_coefficient"), dtype=float + ( + damping_coefficient + if damping_coefficient is not None + else default("damping_coefficient") + ), + dtype=float, + ), + d_min=jnp.array( + d_min if d_min is not None else default("d_min"), dtype=float + ), + d_max=jnp.array( + d_max if d_max is not None else default("d_max"), dtype=float + ), + width=jnp.array( + width if width is not None else default("width"), dtype=float + ), + midpoint=jnp.array( + midpoint if midpoint is not None else default("midpoint"), dtype=float ), - d_min=jnp.array(d_min or default("d_min"), dtype=float), - d_max=jnp.array(d_max or default("d_max"), dtype=float), - width=jnp.array(width or default("width"), dtype=float), - midpoint=jnp.array(midpoint or default("midpoint"), dtype=float), - power=jnp.array(power or default("power"), dtype=float), - stiffness=jnp.array(stiffness or default("stiffness"), dtype=float), - damping=jnp.array(damping or default("damping"), dtype=float), - mu=jnp.array(mu or default("mu"), dtype=float), + power=jnp.array( + power if power is not None else default("power"), dtype=float + ), + stiffness=jnp.array( + stiffness if stiffness is not None else default("stiffness"), + dtype=float, + ), + damping=jnp.array( + damping if damping is not None else default("damping"), dtype=float + ), + mu=jnp.array(mu if mu is not None else default("mu"), dtype=float), ) def valid(self) -> jtp.BoolLike: @@ -210,7 +235,9 @@ def build( # Create the solver options to set by combining the default solver options # with the user-provided solver options. - solver_options = default_solver_options | (solver_options or {}) + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. @@ -223,9 +250,15 @@ def build( return cls( parameters=( - parameters or cls.__dataclass_fields__["parameters"].default_factory() + parameters + if parameters is not None + else cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() ), - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), ) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 99078ea6d..220d65722 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -66,9 +66,17 @@ def build( """Create a `RigidContactParams` instance""" return cls( - mu=mu or cls.__dataclass_fields__["mu"].default, - K=K or cls.__dataclass_fields__["K"].default, - D=D or cls.__dataclass_fields__["D"].default, + mu=jnp.array( + mu + if mu is not None + else cls.__dataclass_fields__["mu"].default_factory() + ).astype(float), + K=jnp.array( + K if K is not None else cls.__dataclass_fields__["K"].default_factory() + ).astype(float), + D=jnp.array( + D if D is not None else cls.__dataclass_fields__["D"].default_factory() + ).astype(float), ) def valid(self) -> jtp.BoolLike: @@ -147,7 +155,9 @@ def build( # Create the solver options to set by combining the default solver options # with the user-provided solver options. - solver_options = default_solver_options | (solver_options or {}) + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. @@ -160,12 +170,19 @@ def build( return cls( parameters=( - parameters or cls.__dataclass_fields__["parameters"].default_factory() + parameters + if parameters is not None + else cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() ), - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), regularization_delassus=float( regularization_delassus - or cls.__dataclass_fields__["regularization_delassus"].default + if regularization_delassus is not None + else cls.__dataclass_fields__["regularization_delassus"].default ), _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 41d744453..e726df379 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -237,9 +237,13 @@ def build( else cls.__dataclass_fields__["parameters"].default_factory() ) - return SoftContacts( + return cls( parameters=parameters, - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() + ), ) @classmethod diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 184a7bb4e..1ffda6338 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -13,6 +13,7 @@ import jaxsim.exceptions import jaxsim.typing as jtp from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation from jaxsim.math import StandardGravity from jaxsim.terrain import FlatTerrain, Terrain @@ -235,11 +236,17 @@ def build( else cls.__dataclass_fields__["parameters"].default_factory() ) - return ViscoElasticContacts( + return cls( parameters=parameters, - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() + ), max_squarings=int( - max_squarings or cls.__dataclass_fields__["max_squarings"].default + max_squarings + if max_squarings is not None + else cls.__dataclass_fields__["max_squarings"].default ), ) @@ -315,8 +322,8 @@ def compute_contact_forces( model=model, data=data, dt=jnp.array(dt).astype(float), - joint_force_references=joint_force_references, link_forces=link_forces, + joint_force_references=joint_force_references, indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, max_squarings=self.max_squarings, ) @@ -334,11 +341,13 @@ def compute_contact_forces( # Vmapped transformation from mixed to inertial-fixed representation. compute_forces_inertial_fixed_vmap = jax.vmap( - lambda CW_fl_C, W_H_C: data.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_C, - is_force=True, + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_C, + is_force=True, + ) ) ) @@ -407,8 +416,8 @@ def _compute_contact_forces_with_exponential_integration( A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics( model=model, data=data, - joint_force_references=joint_force_references, link_forces=link_forces, + joint_force_references=joint_force_references, indices_of_enabled_collidable_points=indices, p_t0=p_t0, v_t0=v_t0, @@ -657,8 +666,8 @@ def _contact_points_dynamics( BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( model=model, data=data, - joint_force_references=references.joint_force_references(model=model), link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), ) # Pack the free system acceleration in mixed representation. @@ -688,7 +697,20 @@ def _linearize_contact_model( parameters: ViscoElasticContactsParams, terrain: Terrain, ) -> tuple[jtp.Matrix, jtp.Vector]: - """""" + """ + Linearize the Hunt/Crossley contact model at the initial state. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing the `A` matrix and the `b` vector of the linear system + corresponding to the contact dynamics linearized at the initial state. + """ # Initialize the state at which the model is linearized. p0 = jnp.array(position, dtype=float).squeeze() @@ -1008,11 +1030,13 @@ def step( # For integration purpose, we need the average of average forces expressed in # mixed representation. LW_f̿_L = jax.vmap( - lambda W_f_L, W_H_L: data.inertial_to_other_representation( - array=W_f_L, - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_L, - is_force=True, + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_L, + is_force=True, + ) ) )(W_f̿_L, W_H_L) From 4266df07c97cfbce687a5f8a6e7b4d547b1b03d2 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 08:56:35 +0200 Subject: [PATCH 04/10] Update how extra arguments are passed to contact params builder methods --- src/jaxsim/api/contact.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 889aa8740..95c848737 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -369,11 +369,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - **dict( - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, - ) - | kwargs, + **( + dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs + ), ) case contacts.ViscoElasticContacts(): @@ -387,11 +389,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - **dict( - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, - ) - | kwargs, + **( + dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs + ), ) ) @@ -404,11 +408,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: parameters = contacts.RigidContactsParams.build( mu=static_friction_coefficient, - **dict( - K=K, - D=2 * jnp.sqrt(K), - ) - | kwargs, + **( + dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs + ), ) case contacts.RelaxedRigidContacts(): From a9ccd15dd34dba3edbd7f20062e8cc628ff2a69f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 09:34:58 +0200 Subject: [PATCH 05/10] Fix dummy integration of material deformation of visco-elastic contacts --- src/jaxsim/api/ode.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 56229464a..ea7616f18 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -383,11 +383,9 @@ def system_dynamics( case contacts.ViscoElasticContacts(): - extended_ode_state["contacts_state"] = { - "tangential_deformation": jnp.zeros_like( - data.state.extended["tangential_deformation"] - ) - } + extended_ode_state["tangential_deformation"] = jnp.zeros_like( + data.state.extended["tangential_deformation"] + ) case contacts.RigidContacts() | contacts.RelaxedRigidContacts(): pass From 2a04195a53b7b8eed704837e1c86ceb02608527b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 09:45:12 +0200 Subject: [PATCH 06/10] Fix transform computation in visco-elastic contacts --- src/jaxsim/rbda/contacts/visco_elastic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 1ffda6338..14313abef 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -1023,7 +1023,7 @@ def step( # Compute the link transforms. W_H_L = ( js.model.forward_kinematics(model=model, data=data) - if data.velocity_representation is not jaxsim.VelRepr.Mixed + if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) From 8052cd62a19afdc550fafe9669253095c4d30d32 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 10:01:30 +0200 Subject: [PATCH 07/10] Update model step --- src/jaxsim/api/model.py | 97 ++++++++++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 20 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 10e26c8aa..5255bac39 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1953,15 +1953,67 @@ def step( integrator_kwargs = kwargs.pop("integrator_kwargs", {}) integrator_kwargs = kwargs | integrator_kwargs - integrator_state = integrator_state if integrator_state is not None else dict() + # Initialize the integrator state. + integrator_state_t0 = integrator_state if integrator_state is not None else dict() # Initialize the time-related variables. state_t0 = data.state t0 = jnp.array(t0, dtype=float) dt = jnp.array(dt if dt is not None else model.time_step).astype(float) - # Rename the integrator state. - integrator_state_t0 = integrator_state + # The visco-elastic contacts operate at best with their own integrator. + # They can be used with Euler-like integrators, paying the price of ignoring + # some of the benefits of continuous-time integration on the system position. + # Furthermore, the requirement to know the Δt used by the integrator is not + # compatible with high-order integrators, that use advanced RK stages to evaluate + # the dynamics at intermediate times. + module = jaxsim.rbda.contacts.visco_elastic.step.__module__ + name = jaxsim.rbda.contacts.visco_elastic.step.__name__ + msg = "You need to use the custom '{}.{}' function with this contact model." + jaxsim.exceptions.raise_runtime_error_if( + condition=jnp.logical_and( + isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts), + jnp.array( + [ + jnp.logical_not(jnp.allclose(dt, model.time_step)), + jnp.logical_not( + isinstance( + integrator, jaxsim.integrators.fixed_step.ForwardEuler + ) + ), + ] + ).any(), + ), + msg=msg.format(module, name), + ) + + # ================= + # Phase 1: pre-step + # ================= + + # TODO: some contact models here may want to perform a dynamic filtering of + # the enabled collidable points. + + # Build the references object. + # We assume that the link forces are expressed in the frame corresponding to the + # velocity representation of the data. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # ============= + # Phase 2: step + # ============= + + # Prepare the references to pass. + with references.switch_velocity_representation(data.velocity_representation): + + f_L = references.link_forces(model=model, data=data) + τ_references = references.joint_force_references(model=model) # Step the dynamics forward. state_tf, integrator_state_tf = integrator.step( @@ -1971,7 +2023,7 @@ def step( params=integrator_state_t0, # Always inject the current (model, data) pair into the system dynamics # considered by the integrator, and include the input variables represented - # by the pair (joint_force_references, link_forces). + # by the pair (f_L, τ_references). # Note that the wrapper of the system dynamics will override (state_x0, t0) # inside the passed data even if it is not strictly needed. This logic is # necessary to re-use the jit-compiled step function of compatible pytrees @@ -1980,8 +2032,8 @@ def step( dict( model=model, data=data, - joint_force_references=joint_force_references, - link_forces=link_forces, + link_forces=f_L, + joint_force_references=τ_references, ) | integrator_kwargs ), @@ -1990,6 +2042,10 @@ def step( # Store the new state of the model. data_tf = data.replace(state=state_tf) + # ================== + # Phase 3: post-step + # ================== + # Post process the simulation state, if needed. match model.contact_model: @@ -2017,17 +2073,18 @@ def step( msg="Baumgarte stabilization is not supported with ForwardEuler integrators", ) + W_p_C = js.contact.collidable_point_positions(model, data_tf) + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + with data_tf.switch_velocity_representation(VelRepr.Mixed): J_WC = js.contact.jacobian(model, data_tf) M = js.model.free_floating_mass_matrix(model, data_tf) - W_p_C = js.contact.collidable_point_positions(model, data_tf) - - # Compute the penetration depth of the collidable points. - δ, *_ = jax.vmap( - jaxsim.rbda.contacts.common.compute_penetration_data, - in_axes=(0, 0, None), - )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) # Compute the impact velocity. # It may be discontinuous in case new contacts are made. @@ -2040,13 +2097,13 @@ def step( ) ) - # Reset the generalized velocity. - data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) - data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) + # Reset the generalized velocity. + data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) + data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) - # Restore the input velocity representation. - data_tf = data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) + # Restore the input velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) return data_tf, integrator_state_tf From 8f202fa7d69cdc78e7767345be4c43d22118f1af Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 22 Oct 2024 11:01:20 +0200 Subject: [PATCH 08/10] Always compute visco-elastic contact models on inertial-fixed data --- src/jaxsim/rbda/contacts/visco_elastic.py | 34 ++++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 14313abef..490122f98 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -991,16 +991,31 @@ def step( assert isinstance(model.contact_model, ViscoElasticContacts) assert isinstance(data.contacts_params, ViscoElasticContactsParams) + # Compute the contact forces in inertial-fixed representation. + # TODO: understand what's wrong in other representations. + data_inertial_fixed = data.replace( + velocity_representation=jaxsim.VelRepr.Inertial, validate=False + ) + + # Create the references object. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + velocity_representation=data.velocity_representation, + ) + # Initialize the time step. dt = dt if dt is not None else model.time_step # Compute the contact forces with the exponential integrator. W_f̅_C, aux_data = model.contact_model.compute_contact_forces( model=model, - data=data, + data=data_inertial_fixed, dt=jnp.array(dt).astype(float), - link_forces=link_forces, - joint_force_references=joint_force_references, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), ) # Extract the final material deformation and the average of average forces @@ -1016,7 +1031,7 @@ def step( # to the same link. W_f̅_L, W_f̿_L = jax.vmap( lambda W_f_C: model.contact_model.link_forces_from_contact_forces( - model=model, data=data, contact_forces=W_f_C + model=model, data=data_inertial_fixed, contact_forces=W_f_C ) )(jnp.stack([W_f̅_C, W_f̿_C])) @@ -1048,10 +1063,10 @@ def step( data_tf: js.data.JaxSimModelData = ( model.contact_model.integrate_data_with_average_contact_forces( model=model, - data=data, + data=data_inertial_fixed, dt=dt, - link_forces=link_forces, - joint_force_references=joint_force_references, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), average_link_contact_forces_inertial=W_f̅_L, average_of_average_link_contact_forces_mixed=LW_f̿_L, ) @@ -1074,4 +1089,9 @@ def step( .set(m_tf) } + # Restore the original velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + return data_tf, {} From 50ea62a6ab6ed686ba540b30cfd0da86996e26fb Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 23 Oct 2024 09:15:23 +0200 Subject: [PATCH 09/10] Unify common kwargs of contact models Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- src/jaxsim/api/contact.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 95c848737..f230fd021 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -163,47 +163,35 @@ def collidable_point_dynamics( Instead, the 6D forces are returned in the active representation. """ + # Build the common kw arguments to pass to the computation of the contact forces. + common_kwargs = dict( + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + # Build the additional kwargs to pass to the computation of the contact forces. match model.contact_model: case contacts.SoftContacts(): - kwargs_contact_model = kwargs + kwargs_contact_model = {} case contacts.RigidContacts(): - kwargs_contact_model = ( - dict( - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - | kwargs - ) + kwargs_contact_model = common_kwargs | kwargs case contacts.RelaxedRigidContacts(): - kwargs_contact_model = ( - dict( - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - | kwargs - ) + kwargs_contact_model = common_kwargs | kwargs case contacts.ViscoElasticContacts(): - kwargs_contact_model = ( - dict( - dt=model.time_step, - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - | kwargs - ) + kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs case _: raise ValueError(f"Invalid contact model: {model.contact_model}") + # Compute the contact forces with the active contact model. W_f_C, aux_data = model.contact_model.compute_contact_forces( model=model, data=data, From 103750e7c5e473a2990960c23e637f1435c63517 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 23 Oct 2024 08:58:27 +0200 Subject: [PATCH 10/10] Update condition with boolean operators Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- src/jaxsim/api/model.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 5255bac39..044a2b4d3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1971,18 +1971,12 @@ def step( name = jaxsim.rbda.contacts.visco_elastic.step.__name__ msg = "You need to use the custom '{}.{}' function with this contact model." jaxsim.exceptions.raise_runtime_error_if( - condition=jnp.logical_and( - isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts), - jnp.array( - [ - jnp.logical_not(jnp.allclose(dt, model.time_step)), - jnp.logical_not( - isinstance( - integrator, jaxsim.integrators.fixed_step.ForwardEuler - ) - ), - ] - ).any(), + condition=( + isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts) + & ( + ~jnp.allclose(dt, model.time_step) + | ~isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler) + ) ), msg=msg.format(module, name), )