diff --git a/demo/thermomechanical_dlrbnicsx_multigpu/dlrbnicsx_thermal_multigpu.py b/demo/thermomechanical_dlrbnicsx_multigpu/dlrbnicsx_thermal_multigpu.py index 40388c9..45f050d 100644 --- a/demo/thermomechanical_dlrbnicsx_multigpu/dlrbnicsx_thermal_multigpu.py +++ b/demo/thermomechanical_dlrbnicsx_multigpu/dlrbnicsx_thermal_multigpu.py @@ -714,7 +714,7 @@ def generate_ann_output_set(problem, reduced_problem, input_set, thermal_solution_file.write_function(thermal_projection_error_function_plot) - if thermal_cpu_group0_comm != MPI.COMM_NULL: + if thermal_gpu_group0_comm != MPI.COMM_NULL: print(f"Training time (Thermal): {thermal_elapsed_time}") if world_comm.rank == 0: