diff --git a/docs/play_examples.py b/docs/play_examples.py index 24459d3..3890f7c 100644 --- a/docs/play_examples.py +++ b/docs/play_examples.py @@ -694,3 +694,24 @@ def fshard(x): jax.debug.visualize_array_sharding(y[:,0]) #print(y) + +################################# + +import jax +from jax import jit +from functools import partial + + +def inner_function(x, static_val): + # The behavior of `static_val` in JIT context depends on how it was passed + return x + static_val + +def outer_function(x, static_val): + return inner_function(x, static_val) + +# JIT-compile the outer function where the second argument is treated as static +compiled_function = jit(outer_function, static_argnums=(1,)) + +# Call the compiled function with a static argument +result = compiled_function(3, 10) +print(result) # Expected to print 13 diff --git a/feniax/intrinsic/dq_dynamic.py b/feniax/intrinsic/dq_dynamic.py index 94ae034..59d1b69 100644 --- a/feniax/intrinsic/dq_dynamic.py +++ b/feniax/intrinsic/dq_dynamic.py @@ -210,7 +210,7 @@ def dq_20g242(t, q, *args): # @jax.jit # @partial(jax.jit, static_argnames=['q']) def dq_20g21(t, q, *args): - """Gust response.""" + """Gust response, clamped model""" ( eta_0, diff --git a/feniax/intrinsic/dynamicFast.py b/feniax/intrinsic/dynamicFast.py index c513382..e3b525d 100644 --- a/feniax/intrinsic/dynamicFast.py +++ b/feniax/intrinsic/dynamicFast.py @@ -8,17 +8,8 @@ import feniax.intrinsic.dq_dynamic as dq_dynamic import feniax.systems.intrinsic_system as isys -@partial(jax.jit, static_argnames=["config"]) -def main_20g1( - q0, - config, - *args, - **kwargs, -): - """ - Dynamic response free vibrations - """ - +def _get_inputs(config, **kwargs): + kwargs_list = list(kwargs.keys()) if "Ka" in kwargs_list: Ka = kwargs.get("Ka") @@ -40,108 +31,15 @@ def main_20g1( alpha = kwargs.get("alpha") else: alpha = 1. - - config.system.build_states(config.fem.num_modes, config.fem.num_nodes) - q2_index = config.system.states["q2"] - q1_index = config.system.states["q1"] - eigenvals = jnp.load(config.fem.folder / config.fem.eig_names[0]) - eigenvecs = jnp.load(config.fem.folder / config.fem.eig_names[1]) - eigenvals = eigenvals[: config.fem.num_modes] - eigenvecs = eigenvecs[:, : config.fem.num_modes] - # solver_args = config.system.solver_settings - X = config.fem.X - ( - phi1, - psi1, - phi2, - phi1l, - phi1ml, - psi1l, - phi2l, - psi2l, - omega, - X_xdelta, - C0ab, - C06ab, - ) = adcommon._compute_modes(X, Ka, Ma, eigenvals, eigenvecs, config) - - gamma1 = couplings.f_gamma1(phi1, psi1) - gamma2 = couplings.f_gamma2(phi1ml, phi2l, psi2l, X_xdelta) - config.system.xloads.build_point_follower(config.fem.num_nodes, C06ab) - x_forceinterpol = config.system.xloads.x - y_forceinterpol = alpha * config.system.xloads.force_follower - states = config.system.states - eta0 = jnp.zeros(config.fem.num_modes) - dq_args = ( - eta0, - gamma1, - gamma2, - omega, - states, - ) - - states_puller, eqsolver = sollibs.factory( - config.system.solver_library, config.system.solver_function - ) - sol = eqsolver( - dq_dynamic.dq_20g1, - dq_args, - config.system.solver_settings, - q0=q0, - t0=config.system.t0, - t1=config.system.t1, - tn=config.system.tn, - dt=config.system.dt, - t=config.system.t, - ) - q = states_puller(sol) + input_dict = dict(Ka=Ka, Ma=Ma, eigenvals=eigenvals, + eigenvecs=eigenvecs, + alpha=alpha + ) + return input_dict - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - q1 = q[:, q1_index] - q2 = q[:, q2_index] - # X2, X3, ra, Cab = recover_staticfields(q, tn, X, q2_index, - # phi2l, psi2l, X_xdelta, C0ab, config.fem) - tn = len(q) - X1, X2, X3, ra, Cab = isys.recover_fields( - q1, q2, tn, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config - ) +def _get_inputs_aero(config, **kwargs): - return dict(phi1 = phi1, - psi1 = psi1, - phi2 = phi2, - phi1l = phi1l, - phi1ml = phi1ml, - psi1l = psi1l, - phi2l = phi2l, - psi2l = psi2l, - omega = omega, - X_xdelta = X_xdelta, - C0ab = C0ab, - C06ab = C06ab, - gamma1 = gamma1, - gamma2 = gamma2, - q = q, - X1 = X1, - X2 = X2, - X3 = X3, - ra = ra, - Cab = Cab - ) - - -@partial(jax.jit, static_argnames=["config"]) -def main_20g11( - q0, - config, - *args, - **kwargs, -): - """ - Dynamic response to Follower load - """ - kwargs_list = list(kwargs.keys()) if "Ka" in kwargs_list: Ka = kwargs.get("Ka") @@ -163,128 +61,6 @@ def main_20g11( alpha = kwargs.get("alpha") else: alpha = 1. - - config.system.build_states(config.fem.num_modes, config.fem.num_nodes) - q2_index = config.system.states["q2"] - q1_index = config.system.states["q1"] - eigenvals = jnp.load(config.fem.folder / config.fem.eig_names[0]) - eigenvecs = jnp.load(config.fem.folder / config.fem.eig_names[1]) - eigenvals = eigenvals[: config.fem.num_modes] - eigenvecs = eigenvecs[:, : config.fem.num_modes] - # solver_args = config.system.solver_settings - X = config.fem.X - ( - phi1, - psi1, - phi2, - phi1l, - phi1ml, - psi1l, - phi2l, - psi2l, - omega, - X_xdelta, - C0ab, - C06ab, - ) = adcommon._compute_modes(X, Ka, Ma, eigenvals, eigenvecs, config) - - gamma1 = couplings.f_gamma1(phi1, psi1) - gamma2 = couplings.f_gamma2(phi1ml, phi2l, psi2l, X_xdelta) - config.system.xloads.build_point_follower(config.fem.num_nodes, C06ab) - x_forceinterpol = config.system.xloads.x - y_forceinterpol = alpha * config.system.xloads.force_follower - states = config.system.states - eta0 = jnp.zeros(config.fem.num_modes) - dq_args = ( - eta0, - gamma1, - gamma2, - omega, - phi1l, - x_forceinterpol, - y_forceinterpol, - states, - ) - - states_puller, eqsolver = sollibs.factory( - config.system.solver_library, config.system.solver_function - ) - - sol = eqsolver( - dq_dynamic.dq_20g11, - dq_args, - config.system.solver_settings, - q0=q0, - t0=config.system.t0, - t1=config.system.t1, - tn=config.system.tn, - dt=config.system.dt, - t=config.system.t, - ) - q = states_puller(sol) - - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - q1 = q[:, q1_index] - q2 = q[:, q2_index] - # X2, X3, ra, Cab = recover_staticfields(q, tn, X, q2_index, - # phi2l, psi2l, X_xdelta, C0ab, config.fem) - tn = len(q) - X1, X2, X3, ra, Cab = isys.recover_fields( - q1, q2, tn, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config - ) - - return dict(phi1 = phi1, - psi1 = psi1, - phi2 = phi2, - phi1l = phi1l, - phi1ml = phi1ml, - psi1l = psi1l, - phi2l = phi2l, - psi2l = psi2l, - omega = omega, - X_xdelta = X_xdelta, - C0ab = C0ab, - C06ab = C06ab, - gamma1 = gamma1, - gamma2 = gamma2, - q = q, - X1 = X1, - X2 = X2, - X3 = X3, - ra = ra, - Cab = Cab - ) - - -@partial(jax.jit, static_argnames=["config"]) -def main_20g21( - q0, - config, - *args, - **kwargs, -): - """ - Gust response - """ - - kwargs_list = list(kwargs.keys()) - if "Ka" in kwargs_list: - Ka = kwargs.get("Ka") - else: - Ka = config.fem.Ka - if "Ma" in kwargs_list: - Ma = kwargs.get("Ma") - else: - Ma = config.fem.Ma - if "eigenvals" in kwargs_list: - eigenvals = kwargs.get("eigenvals") - else: - eigenvals = config.fem.eigenvals - if "eigenvecs" in kwargs_list: - eigenvecs = kwargs.get("eigenvecs") - else: - eigenvecs = config.fem.eigenvecs if "gust_intensity" in kwargs_list: gust_intensity = kwargs.get("gust_intensity") else: @@ -301,18 +77,18 @@ def main_20g21( rho_inf = kwargs.get("rho_inf") else: rho_inf = config.system.aero.rho_inf - - config.system.build_states(config.fem.num_modes, config.fem.num_nodes) - q2_index = jnp.array( - range(config.fem.num_modes, 2 * config.fem.num_modes) - ) # config.system.states['q2'] - q1_index = jnp.array(range(config.fem.num_modes)) # config.system.states['q1'] - # eigenvals, eigenvecs = scipy.linalg.eigh(Ka, Ma) - eigenvals = jnp.load(config.fem.folder / config.fem.eig_names[0]).T - eigenvecs = jnp.load(config.fem.folder / config.fem.eig_names[1]).T - reduced_eigenvals = eigenvals[: config.fem.num_modes] - reduced_eigenvecs = eigenvecs[:, : config.fem.num_modes] - # solver_args = config.system.solver_settings + + input_dict = dict(Ka=Ka, Ma=Ma, eigenvals=eigenvals, + eigenvecs=eigenvecs, + alpha=alpha, + gust_intensity=gust_intensity, + gust_length=gust_length, u_inf=u_inf, rho_inf=rho_inf + ) + return input_dict + +def _build_intrinsic(get_inputs, config, **kwargs): + + input_dict = get_inputs(config, **kwargs) X = config.fem.X ( phi1, @@ -326,30 +102,123 @@ def main_20g21( omega, X_xdelta, C0ab, - C06ab, - ) = adcommon._compute_modes(X, Ka, Ma, reduced_eigenvals, reduced_eigenvecs, config) - ################# + C06ab + ) = adcommon._compute_modes(X, + input_dict['Ka'], + input_dict['Ma'], + input_dict['eigenvals'], + input_dict['eigenvecs'], + config) gamma1 = couplings.f_gamma1(phi1, psi1) gamma2 = couplings.f_gamma2(phi1ml, phi2l, psi2l, X_xdelta) + output_dict = dict(phi1 = phi1, + psi1 = psi1, + phi2 = phi2, + phi1l = phi1l, + phi1ml = phi1ml, + psi1l = psi1l, + phi2l = phi2l, + psi2l = psi2l, + omega = omega, + X_xdelta = X_xdelta, + C0ab = C0ab, + C06ab = C06ab, + gamma1=gamma1, + gamma2 = gamma2, + ) + return output_dict, input_dict - ################# +def _build_solution(q, output_dict, config): + + X = config.fem.X + tn = len(q) + q1_index = config.system.states["q1"] + q2_index = config.system.states["q2"] + q1 = q[:, q1_index] + q2 = q[:, q2_index] + X1, X2, X3, ra, Cab = isys.recover_fields( + q1, + q2, + tn, + X, + output_dict['phi1l'], + output_dict['phi2l'], + output_dict['psi2l'], + output_dict['X_xdelta'], + output_dict['C0ab'], + config + ) + output_dict['q'] = q + output_dict['X1'] = X1 + output_dict['X2'] = X2 + output_dict['X3'] = X3 + output_dict['ra'] = ra + output_dict['Cab'] = Cab + +def _build_solutionRB(q, output_dict, config): + + X = config.fem.X + tn = config.system.tn #len(q) WARNING: needs to be static for the recover + dt = config.system.dt + q1_index = config.system.states["q1"] + q2_index = config.system.states["q2"] + q1 = q[:, q1_index] + q2 = q[:, q2_index] + X1, X2, X3, ra, Cab = isys.recover_fieldsRB( + q1, + q2, + tn, + dt, + X, + output_dict['phi1l'], + output_dict['phi2l'], + output_dict['psi2l'], + output_dict['X_xdelta'], + output_dict['C0ab'], + config + ) + output_dict['q'] = q + output_dict['X1'] = X1 + output_dict['X2'] = X2 + output_dict['X3'] = X3 + output_dict['ra'] = ra + output_dict['Cab'] = Cab + + +def _get_aero(u_inf, rho_inf, config): + + q_inf = 0.5 * rho_inf * u_inf ** 2 A0 = config.system.aero.A[0] A1 = config.system.aero.A[1] A2 = config.system.aero.A[2] - A3 = config.system.aero.A[3:] - D0 = config.system.aero.D[0] - D1 = config.system.aero.D[1] - D2 = config.system.aero.D[2] - D3 = config.system.aero.D[3:] - # u_inf = config.system.aero.u_inf - # rho_inf = config.system.aero.rho_inf - q_inf = 0.5 * rho_inf * u_inf**2 # config.system.aero.q_inf + A3 = config.system.aero.A[3:] c_ref = config.system.aero.c_ref + poles = config.system.aero.poles A0hat = q_inf * A0 A1hat = c_ref * rho_inf * u_inf / 4 * A1 A2hat = c_ref**2 * rho_inf / 8 * A2 A3hat = q_inf * A3 A2hatinv = jnp.linalg.inv(jnp.eye(len(A2hat)) - A2hat) + return (q_inf, + c_ref, + poles, + A0hat, + A1hat, + A2hatinv, + A3hat + ) + +def _get_gust(input_dict, q_inf, c_ref, config): + + u_inf = input_dict['u_inf'] + rho_inf = input_dict['rho_inf'] + gust_intensity = input_dict['gust_intensity'] + gust_length = input_dict['gust_length'] + + D0 = config.system.aero.D[0] + D1 = config.system.aero.D[1] + D2 = config.system.aero.D[2] + D3 = config.system.aero.D[3:] D0hat = q_inf * D0 D1hat = c_ref * rho_inf * u_inf / 4 * D1 D2hat = c_ref**2 * rho_inf / 8 * D2 @@ -387,16 +256,139 @@ def main_20g21( gust_dot, gust_ddot, ) - poles = config.system.aero.poles + return timegust, Q_wsum, Ql_wdot + +#@partial(jax.jit, static_argnames=["config"]) +def main_20g1( + q0, + config, + *args, + **kwargs, +): + """ + Dynamic response free vibrations + """ + + output, input_dict = _build_intrinsic(_get_inputs, config, **kwargs) + + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) + states = config.system.states + eta0 = jnp.zeros(config.fem.num_modes) + dq_args = ( + eta0, + output['gamma1'], + output['gamma2'], + output['omega'], + states, + ) + + states_puller, eqsolver = sollibs.factory( + config.system.solver_library, config.system.solver_function + ) + + sol = eqsolver( + dq_dynamic.dq_20g1, + dq_args, + config.system.solver_settings, + q0=q0, + t0=config.system.t0, + t1=config.system.t1, + tn=config.system.tn, + dt=config.system.dt, + t=config.system.t, + ) + q = states_puller(sol) + _build_solution(q, output, config) + return output + +#@partial(jax.jit, static_argnames=["config"]) +def main_20g11( + q0, + config, + *args, + **kwargs, +): + """ + Dynamic response to Follower load + """ + + output, input_dict = _build_intrinsic(_get_inputs, config, **kwargs) + + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) + config.system.xloads.build_point_follower(config.fem.num_nodes, output['C06ab']) + + x_forceinterpol = config.system.xloads.x + y_forceinterpol = input_dict['alpha'] * config.system.xloads.force_follower + states = config.system.states + eta0 = jnp.zeros(config.fem.num_modes) + dq_args = ( + eta0, + output['gamma1'], + output['gamma2'], + output['omega'], + output['phi1l'], + x_forceinterpol, + y_forceinterpol, + states, + ) + + states_puller, eqsolver = sollibs.factory( + config.system.solver_library, config.system.solver_function + ) + + sol = eqsolver( + dq_dynamic.dq_20g11, + dq_args, + config.system.solver_settings, + q0=q0, + t0=config.system.t0, + t1=config.system.t1, + tn=config.system.tn, + dt=config.system.dt, + t=config.system.t, + ) + q = states_puller(sol) + _build_solution(q, output, config) + return output + + +#@partial(jax.jit, static_argnames=["config"]) +def main_20g21( + q0, + config, + *args, + **kwargs, +): + """ + Gust response, clamped model + """ + + output, input_dict = _build_intrinsic(_get_inputs_aero, config, **kwargs) + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) + + ################# + (q_inf, + c_ref, + poles, + A0hat, + A1hat, + A2hatinv, + A3hat + ) = _get_aero(input_dict['u_inf'], input_dict['rho_inf'], config) + + timegust, Q_wsum, Ql_wdot = _get_gust(input_dict, + q_inf, + c_ref, + config) num_poles = config.system.aero.num_poles num_modes = config.fem.num_modes states = config.system.states eta0 = jnp.zeros(num_modes) dq_args = ( eta0, - gamma1, - gamma2, - omega, + output['gamma1'], + output['gamma2'], + output['omega'], states, poles, num_modes, @@ -407,7 +399,7 @@ def main_20g21( A1hat, A2hatinv, A3hat, - u_inf, + input_dict['u_inf'], Q_wsum, Ql_wdot, ) @@ -430,36 +422,75 @@ def main_20g21( t=config.system.t, ) q = states_puller(sol) + _build_solution(q, output, config) + return output - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - # q = _solve(dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings) - q1 = q[:, q1_index] - q2 = q[:, q2_index] - tn = len(q) - # X2, X3, ra, Cab = isys.recover_staticfields(q2, tn, X, - # phi2l, psi2l, X_xdelta, C0ab, config.fem) - X1, X2, X3, ra, Cab = isys.recover_fields( - q1, q2, tn, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config +def main_20g546(q0, + config, + *args, + **kwargs + ): + """Gust response free flight, q0 obtained via integrator q1.""" + + output, input_dict = _build_intrinsic(_get_inputs_aero, config, **kwargs) + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) + + ################# + (q_inf, + c_ref, + poles, + A0hat, + A1hat, + A2hatinv, + A3hat + ) = _get_aero(input_dict['u_inf'], input_dict['rho_inf'], config) + + timegust, Q_wsum, Ql_wdot = _get_gust(input_dict, + q_inf, + c_ref, + config) + num_poles = config.system.aero.num_poles + num_modes = config.fem.num_modes + states = config.system.states + eta0 = jnp.zeros(num_modes) + dq_args = ( + eta0, + output['gamma1'], + output['gamma2'], + output['omega'], + output['phi1l'], + states, + poles, + num_modes, + num_poles, + timegust, + c_ref, + A0hat, + A1hat, + A2hatinv, + A3hat, + input_dict['u_inf'], + Q_wsum, + Ql_wdot, + ) + + ################# + states_puller, eqsolver = sollibs.factory( + config.system.solver_library, config.system.solver_function ) - return dict(phi1 = phi1, - psi1 = psi1, - phi2 = phi2, - phi1l = phi1l, - phi1ml = phi1ml, - psi1l = psi1l, - phi2l = phi2l, - psi2l = psi2l, - omega = omega, - X_xdelta = X_xdelta, - C0ab = C0ab, - C06ab = C06ab, - gamma1 = gamma1, - gamma2 = gamma2, - q = q, - X1 = X1, - X2 = X2, - X3 = X3, - ra = ra, - Cab = Cab - ) + sol = eqsolver( + dq_dynamic.dq_20g546, + dq_args, + config.system.solver_settings, + q0=q0, + t0=config.system.t0, + t1=config.system.t1, + tn=config.system.tn, + dt=config.system.dt, + t=config.system.t, + ) + q = states_puller(sol) + _build_solutionRB(q, output, config) + return output + diff --git a/feniax/intrinsic/staticFast.py b/feniax/intrinsic/staticFast.py index 12e65ef..7a92857 100644 --- a/feniax/intrinsic/staticFast.py +++ b/feniax/intrinsic/staticFast.py @@ -11,30 +11,9 @@ newton = partial(jax.jit, static_argnames=["F", "sett"])(libdiffrax.newton) _solve = partial(jax.jit, static_argnames=["eqsolver", "dq", "sett"])(isys._staticSolve) - -@partial(jax.jit, static_argnames=["config"]) -def main_10g11( - q0, - config, - *args, - **kwargs, -): - """ - Static computation with follower forces - """ +def _get_inputs(config, **kwargs): kwargs_list = list(kwargs.keys()) - #print(config.fem.eigenvals) - # Ka = jax.lax.select("Ka" in kwargs_list, kwargs.get("Ka"), config.fem.Ka) - # Ma = jax.lax.select("Ma" in kwargs_list, kwargs.get("Ma"), config.fem.Ma) - # # eigenvals = config.fem.eigenvals #jax.lax.cond("eigenvals" in kwargs_list, - # # #lambda kwargs: kwargs.get("eigenvals"), - # # #lambda kwargs: config.fem.eigenvals, - # # #kwargs) - # # eigenvecs = config.fem.eigenvecs #jax.lax.select("eigenvecs" in kwargs_list, - # # #kwargs.get("eigenvecs"), config.fem.eigenvecs) - # # t_loads = kwargs.get("t_loads") #config.system.t #jax.lax.select("t_loads" in kwargs_list, - # # # kwargs.get("t_loads"), config.system.t) if "Ka" in kwargs_list: Ka = kwargs.get("Ka") else: @@ -55,11 +34,12 @@ def main_10g11( t_loads = kwargs.get("t_loads") else: t_loads = config.system.t - - tn = len(t_loads) - config.system.build_states(config.fem.num_modes, config.fem.num_nodes) - q2_index = config.system.states["q2"] - # solver_args = config.system.solver_settings + + return Ka, Ma, eigenvals, eigenvecs, t_loads + +def _build_intrinsic(config, **kwargs): + + Ka, Ma, eigenvals, eigenvecs, t_loads = _get_inputs(config, **kwargs) X = config.fem.X ( phi1, @@ -75,41 +55,104 @@ def main_10g11( C0ab, C06ab ) = adcommon._compute_modes(X, Ka, Ma, eigenvals, eigenvecs, config) - gamma2 = couplings.f_gamma2(phi1ml, phi2l, psi2l, X_xdelta) + output_dict = dict(phi1 = phi1, + psi1 = psi1, + phi2 = phi2, + phi1l = phi1l, + phi1ml = phi1ml, + psi1l = psi1l, + phi2l = phi2l, + psi2l = psi2l, + omega = omega, + X_xdelta = X_xdelta, + C0ab = C0ab, + C06ab = C06ab, + gamma2 = gamma2, + t_loads = t_loads + ) + return output_dict + +def _build_solution(q, output_dict, config): + + X = config.fem.X + t_loads = output_dict['t_loads'] + tn = len(t_loads) + q2_index = config.system.states["q2"] + q2 = q[:, q2_index] + X2, X3, ra, Cab = isys.recover_staticfields( + q2, + tn, + X, + output_dict['phi2l'], + output_dict['psi2l'], + output_dict['X_xdelta'], + output_dict['C0ab'], + config + ) + X1 = jnp.zeros_like(X2) + output_dict['q'] = q + output_dict['X1'] = X1 + output_dict['X2'] = X2 + output_dict['X3'] = X3 + output_dict['ra'] = ra + output_dict['Cab'] = Cab + +def _get_aeromatrices(config): + + A0 = config.system.aero.A[0] + C0 = config.system.aero.Q0_rigid + A0hat = config.system.aero.q_inf * A0 + C0hat = config.system.aero.q_inf * C0 + return A0hat, C0hat + +# @partial(jax.jit, static_argnames=["config"]) +def main_10g11( + q0, + config, + *args, + **kwargs, +): + """Structural static with follower point forces.""" + + output = _build_intrinsic(config, **kwargs) + + t_loads = output['t_loads'] + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) eta0 = jnp.zeros(config.fem.num_modes) - config.system.xloads.build_point_follower(config.fem.num_nodes, C06ab) + config.system.xloads.build_point_follower(config.fem.num_nodes, output['C06ab']) x_forceinterpol = config.system.xloads.x y_forceinterpol = config.system.xloads.force_follower - dq_args = (eta0, gamma2, omega, phi1l, x_forceinterpol, y_forceinterpol) + dq_args = (eta0, output['gamma2'], output['omega'], output['phi1l'], + x_forceinterpol, y_forceinterpol) q = _solve( newton, dq_static.dq_10g11, t_loads, q0, dq_args, config.system.solver_settings ) - q2 = q[:, q2_index] - X2, X3, ra, Cab = isys.recover_staticfields( - q2, tn, X, phi2l, psi2l, X_xdelta, C0ab, config - ) - X1 = jnp.zeros_like(X2) - return dict(phi1 = phi1, - psi1 = psi1, - phi2 = phi2, - phi1l = phi1l, - phi1ml = phi1ml, - psi1l = psi1l, - phi2l = phi2l, - psi2l = psi2l, - omega = omega, - X_xdelta = X_xdelta, - C0ab = C0ab, - C06ab = C06ab, - gamma2 = gamma2, - q = q, - X1 = X1, - X2 = X2, - X3 = X3, - ra = ra, - Cab = Cab - ) + _build_solution(q, output, config) + return output +#@partial(jax.jit, static_argnames=["config"]) +def main_10g15( + q0, + config, + *args, + **kwargs, +): + """Manoeuvre under qalpha.""" + output = _build_intrinsic(config, **kwargs) + + t_loads = output['t_loads'] + config.system.build_states(config.fem.num_modes, config.fem.num_nodes) + eta0 = jnp.zeros(config.fem.num_modes) + A0hat, C0hat = _get_aeromatrices(config) + dq_args = (eta0, output['gamma2'], output['omega'], output['phi1l'], + config.system.xloads.x, config.system.aero.qalpha, + A0hat, C0hat) + + q = _solve( + newton, dq_static.dq_10g15, t_loads, q0, dq_args, config.system.solver_settings + ) + _build_solution(q, output, config) + return output diff --git a/feniax/systems/intrinsicFast.py b/feniax/systems/intrinsicFast.py index d844da4..18504f0 100644 --- a/feniax/systems/intrinsicFast.py +++ b/feniax/systems/intrinsicFast.py @@ -7,6 +7,7 @@ from feniax.systems.intrinsic_system import IntrinsicSystem import jax +from functools import partial class IntrinsicFastSystem(IntrinsicSystem, cls_name="Fast_intrinsic"): @@ -66,7 +67,8 @@ def set_system(self): label_sys = self.settings.label self.label = f"main_{label_sys}" print(f"***** Setting intrinsic static Fast system with label {self.label}") - self.main = getattr(staticFast, self.label) + self.main = partial(jax.jit, static_argnames=["config"])( + getattr(staticFast, self.label)) def build_solution(self, q, X2, X3, ra, Cab, *args, **kwargs): @@ -90,7 +92,8 @@ def set_system(self): label_sys = self.settings.label self.label = f"main_{label_sys}" print(f"***** Setting intrinsic Dynamic Fast system with label {self.label}") - self.main = getattr(dynamicFast, self.label) + self.main = partial(jax.jit, static_argnames=["config"])( + getattr(dynamicFast, self.label)) def build_solution(self, q, X1, X2, X3, ra, Cab, **kwargs): diff --git a/feniax/systems/intrinsic_system.py b/feniax/systems/intrinsic_system.py index cf97bc3..eac38a2 100644 --- a/feniax/systems/intrinsic_system.py +++ b/feniax/systems/intrinsic_system.py @@ -78,7 +78,6 @@ def recover_staticfields(q2, tn, X, phi2l, psi2l, X_xdelta, C0ab, config): return X2, X3, ra, Cab - class IntrinsicSystem(System, cls_name="intrinsic"): def __init__( self, diff --git a/tests/conftest.py b/tests/conftest.py index c435bb2..91fc3f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,21 +10,26 @@ def pytest_addoption(parser): parser.addoption( "--runprivate", action="store_true", default=False, help="run proprietary tests" ) + parser.addoption( + "--runall", action="store_true", default=False, help="run all tests" + ) def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") config.addinivalue_line("markers", "private: mark test as non-opensource") + config.addinivalue_line("markers", "all: run all tests") def pytest_collection_modifyitems(config, items): - if config.getoption("--runslow") and config.getoption("--runprivate"): - # --runslow given in cli: do not skip slow tests - return - skip_slow = pytest.mark.skip(reason="need --runslow option to run") - skip_private = pytest.mark.skip(reason="need --runprivate option to run") - for item in items: - if "slow" in item.keywords: - item.add_marker(skip_slow) - if "private" in item.keywords: - item.add_marker(skip_private) + if not config.getoption("--runslow") and not config.getoption("--runall"): + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) + + if not config.getoption("--runprivate") and not config.getoption("--runall"): + skip_private = pytest.mark.skip(reason="need --runprivate option to run") + for item in items: + if "private" in item.keywords: + item.add_marker(skip_private) diff --git a/tests/intrinsic/aeroelastic_dynamic/test_BUGgustFast.py b/tests/intrinsic/aeroelastic_dynamic/test_BUGgustFast.py new file mode 100644 index 0000000..09be661 --- /dev/null +++ b/tests/intrinsic/aeroelastic_dynamic/test_BUGgustFast.py @@ -0,0 +1,308 @@ +import pathlib +import time +import jax.numpy as jnp +import numpy as np +import feniax.preprocessor.configuration as configuration # import Config, dump_to_yaml +from feniax.preprocessor.inputs import Inputs +from feniax.preprocessor import solution +import feniax.feniax_main +import pytest + +file_path = pathlib.Path(__file__).parent + +class TestBUGGustclamped: + + @pytest.fixture(scope="class") + def sol(self): + + label_dlm = "d1c7" + sol = "cao" + label_gaf = "Dd1c7F3Scao-50" + num_modes = 50 + c_ref = 3.0 + u_inf = 209.62786434059765 + rho_inf = 0.41275511341689247 + num_poles = 5 + Dhj_file = f"D{label_gaf}p{num_poles}" + Ahh_file = f"A{label_gaf}p{num_poles}" + Poles_file = f"Poles{label_gaf}p{num_poles}" + inp = Inputs() + inp.engine = "intrinsicmodal" + inp.fem.eig_type = "inputs" + + inp.fem.connectivity = dict(# FusWing=['RWing', + # 'LWing'], + FusBack=['FusTail', + 'VTP'], + FusFront=None, + RWing=None, + LWing=None, + FusTail=None, + VTP=['HTP', 'VTPTail'], + HTP=['RHTP', 'LHTP'], + VTPTail=None, + RHTP=None, + LHTP=None, + ) + inp.fem.grid = file_path / f"../../../examples/BUG/FEM/structuralGrid_{sol[:-1]}" + inp.fem.Ka_name = file_path / f"../../../examples/BUG/FEM/Ka_{sol[:-1]}.npy" + inp.fem.Ma_name = file_path / f"../../../examples/BUG/FEM/Ma_{sol[:-1]}.npy" + inp.fem.eig_names = [file_path / f"../../../examples/BUG/FEM/eigenvals_{sol}{num_modes}.npy", + file_path / f"../../../examples/BUG/FEM/eigenvecs_{sol}{num_modes}.npy"] + inp.driver.typeof = "intrinsic" + inp.driver.save_fem = False + inp.driver.sol_path = None + + inp.fem.num_modes = num_modes + + inp.simulation.typeof = "single" + inp.system.operationalmode = "fast" + inp.system.save = False + inp.system.solution = "dynamic" + inp.system.t1 = 1. + inp.system.tn = 1001 + inp.system.solver_library = "runge_kutta" + inp.system.solver_function = "ode" + inp.system.solver_settings = dict(solver_name="rk4") + inp.system.xloads.modalaero_forces = True + inp.system.aero.c_ref = c_ref + inp.system.aero.u_inf = u_inf + inp.system.aero.rho_inf = rho_inf + inp.system.aero.poles = file_path / f"../../../examples/BUG/AERO/{Poles_file}.npy" + inp.system.aero.A = file_path / f"../../../examples/BUG/AERO/{Ahh_file}.npy" + inp.system.aero.D = file_path / f"../../../examples/BUG/AERO/{Dhj_file}.npy" + inp.system.aero.gust_profile = "mc" + inp.system.aero.gust.intensity = 20 + inp.system.aero.gust.length = 150. + inp.system.aero.gust.step = 0.1 + inp.system.aero.gust.shift = 0. + inp.system.aero.gust.panels_dihedral = file_path / f"../../../examples/BUG/AERO/Dihedral_{label_dlm}.npy" + inp.system.aero.gust.collocation_points = file_path / f"../../../examples/BUG/AERO/Collocation_{label_dlm}.npy" + config = configuration.Config(inp) + obj_sol = feniax.feniax_main.main(input_obj=config) + return obj_sol + + @pytest.fixture + def data(self): + sol_path = file_path / "data/BUG/gust2_cao" + sol = solution.IntrinsicSolution(sol_path) + sol.load_container("Modes") + sol.load_container("Couplings") + sol.load_container("DynamicSystem", label="_sys1") + + return sol.data + + def test_phi1(self, sol, data): + + assert jnp.allclose(sol.modes.phi1, data.modes.phi1) + + def test_phi2(self, sol, data): + + assert jnp.allclose(sol.modes.phi2, data.modes.phi2) + + def test_psi1(self, sol, data): + + assert jnp.allclose(sol.modes.psi1, data.modes.psi1) + + def test_phi1l(self, sol, data): + + assert jnp.allclose(sol.modes.phi1l, data.modes.phi1l) + + def test_phi2l(self, sol, data): + + assert jnp.allclose(sol.modes.phi2l, data.modes.phi2l) + + def test_psi1l(self, sol, data): + + assert jnp.allclose(sol.modes.psi1l, data.modes.psi1l) + + def test_psi2l(self, sol, data): + + assert jnp.allclose(sol.modes.psi2l, data.modes.psi2l) + + def test_phi1ml(self, sol, data): + + assert jnp.allclose(sol.modes.phi1ml, data.modes.phi1ml) + + def test_gamma1(self, sol, data): + + assert jnp.allclose(sol.couplings.gamma1, + data.couplings.gamma1) + + def test_gamma2(self, sol, data): + + assert jnp.allclose(sol.couplings.gamma2, + data.couplings.gamma2) + + def test_qs(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.q, + data.dynamicsystem_sys1.q) + + def test_Xs(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.X2, + data.dynamicsystem_sys1.X2, + atol=1e-5) + assert jnp.allclose(sol.dynamicsystem_sys1.X3, + data.dynamicsystem_sys1.X3) + + def test_ra(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.ra, + data.dynamicsystem_sys1.ra) + + def test_Cab(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.Cab, + data.dynamicsystem_sys1.Cab) + +class TestBUGGustfree: + + @pytest.fixture(scope="class") + def sol(self): + + label_dlm = "d1c7" + sol = "eao" + label_gaf = "Dd1c7F3Seao-50" + num_modes = 50 + c_ref = 3.0 + u_inf = 209.62786434059765 + rho_inf = 0.41275511341689247 + num_poles = 5 + Dhj_file = f"D{label_gaf}p{num_poles}" + Ahh_file = f"A{label_gaf}p{num_poles}" + Poles_file = f"Poles{label_gaf}p{num_poles}" + inp = Inputs() + inp.engine = "intrinsicmodal" + inp.fem.eig_type = "inputs" + + inp.fem.connectivity = dict(# FusWing=['RWing', + # 'LWing'], + FusBack=['FusTail', + 'VTP'], + FusFront=None, + RWing=None, + LWing=None, + FusTail=None, + VTP=['HTP', 'VTPTail'], + HTP=['RHTP', 'LHTP'], + VTPTail=None, + RHTP=None, + LHTP=None, + ) + inp.fem.grid = file_path / f"../../../examples/BUG/FEM/structuralGrid_{sol[:-1]}" + inp.fem.Ka_name = file_path / f"../../../examples/BUG/FEM/Ka_{sol[:-1]}.npy" + inp.fem.Ma_name = file_path / f"../../../examples/BUG/FEM/Ma_{sol[:-1]}.npy" + inp.fem.eig_names = [file_path / f"../../../examples/BUG/FEM/eigenvals_{sol}{num_modes}.npy", + file_path / f"../../../examples/BUG/FEM/eigenvecs_{sol}{num_modes}.npy"] + inp.driver.typeof = "intrinsic" + inp.driver.save_fem = False + inp.driver.sol_path = None + + inp.fem.num_modes = num_modes + + inp.simulation.typeof = "single" + inp.system.operationalmode = "fast" + inp.system.save = False + inp.system.solution = "dynamic" + inp.system.bc1 = 'free' + inp.system.q0treatment = 1 + inp.system.t1 = 1. + inp.system.tn = 1001 + inp.system.solver_library = "runge_kutta" + inp.system.solver_function = "ode" + inp.system.solver_settings = dict(solver_name="rk4") + inp.system.xloads.modalaero_forces = True + inp.system.aero.c_ref = c_ref + inp.system.aero.u_inf = u_inf + inp.system.aero.rho_inf = rho_inf + inp.system.aero.poles = file_path / f"../../../examples/BUG/AERO/{Poles_file}.npy" + inp.system.aero.A = file_path / f"../../../examples/BUG/AERO/{Ahh_file}.npy" + inp.system.aero.D = file_path / f"../../../examples/BUG/AERO/{Dhj_file}.npy" + inp.system.aero.gust_profile = "mc" + inp.system.aero.gust.intensity = 20 + inp.system.aero.gust.length = 150. + inp.system.aero.gust.step = 0.1 + inp.system.aero.gust.shift = 0. + inp.system.aero.gust.panels_dihedral = file_path / f"../../../examples/BUG/AERO/Dihedral_{label_dlm}.npy" + inp.system.aero.gust.collocation_points = file_path / f"../../../examples/BUG/AERO/Collocation_{label_dlm}.npy" + config = configuration.Config(inp) + obj_sol = feniax.feniax_main.main(input_obj=config) + return obj_sol + + @pytest.fixture + def data(self): + sol_path = file_path / "data/BUG/gust2_eao" + sol = solution.IntrinsicSolution(sol_path) + sol.load_container("Modes") + sol.load_container("Couplings") + sol.load_container("DynamicSystem", label="_sys1") + + return sol.data + + def test_phi1(self, sol, data): + + assert jnp.allclose(sol.modes.phi1, data.modes.phi1) + + def test_phi2(self, sol, data): + + assert jnp.allclose(sol.modes.phi2, data.modes.phi2) + + def test_psi1(self, sol, data): + + assert jnp.allclose(sol.modes.psi1, data.modes.psi1) + + def test_phi1l(self, sol, data): + + assert jnp.allclose(sol.modes.phi1l, data.modes.phi1l) + + def test_phi2l(self, sol, data): + + assert jnp.allclose(sol.modes.phi2l, data.modes.phi2l) + + def test_psi1l(self, sol, data): + + assert jnp.allclose(sol.modes.psi1l, data.modes.psi1l) + + def test_psi2l(self, sol, data): + + assert jnp.allclose(sol.modes.psi2l, data.modes.psi2l) + + def test_phi1ml(self, sol, data): + + assert jnp.allclose(sol.modes.phi1ml, data.modes.phi1ml) + + def test_gamma1(self, sol, data): + + assert jnp.allclose(sol.couplings.gamma1, + data.couplings.gamma1) + + def test_gamma2(self, sol, data): + + assert jnp.allclose(sol.couplings.gamma2, + data.couplings.gamma2) + + def test_qs(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.q, + data.dynamicsystem_sys1.q) + + def test_Xs(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.X2, + data.dynamicsystem_sys1.X2, + atol=1e-5) + assert jnp.allclose(sol.dynamicsystem_sys1.X3, + data.dynamicsystem_sys1.X3) + + def test_ra(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.ra, + data.dynamicsystem_sys1.ra) + + def test_Cab(self, sol, data): + + assert jnp.allclose(sol.dynamicsystem_sys1.Cab, + data.dynamicsystem_sys1.Cab) +