Skip to content

Commit

Permalink
Merge pull request #96 from arcaneframework/dev/gg-add-timer-for-solving
Browse files Browse the repository at this point in the history
Add timers for linear system solving
  • Loading branch information
grospelliergilles authored Jan 23, 2024
2 parents 1b30bb0 + 0c46e60 commit 072d9d0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 46 deletions.
95 changes: 55 additions & 40 deletions femutils/HypreDoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <arcane/core/ServiceFactory.h>
#include <arcane/core/IParallelMng.h>
#include <arcane/core/ItemPrinter.h>
#include <arcane/core/Timer.h>

#include <arcane/accelerator/core/Runner.h>

Expand Down Expand Up @@ -243,6 +244,7 @@ solve()

// Récupère le communicateur MPI associé
IParallelMng* pm = m_dof_family->parallelMng();
ITimeStats* tstat = pm->timeStats();
Parallel::Communicator arcane_comm = pm->communicator();
MPI_Comm mpi_comm = MPI_COMM_WORLD;
if (arcane_comm.isValid())
Expand Down Expand Up @@ -386,18 +388,21 @@ solve()
}
}

/* GPU pointers; efficient in large chunks */
HYPRE_IJMatrixSetValues(ij_A,
nb_local_row,
rows_nb_column_data,
rows_index_span.data(),
columns_index_span.data(),
matrix_values.data());

HYPRE_IJMatrixAssemble(ij_A);
HYPRE_IJMatrixGetObject(ij_A, (void**)&parcsr_A);
Real m2 = platform::getRealTime();
info() << "Time to create matrix=" << (m2 - m1);
{
Timer::Action ta1(tstat, "HypreLinearSystemBuildMatrix");
/* GPU pointers; efficient in large chunks */
HYPRE_IJMatrixSetValues(ij_A,
nb_local_row,
rows_nb_column_data,
rows_index_span.data(),
columns_index_span.data(),
matrix_values.data());

HYPRE_IJMatrixAssemble(ij_A);
HYPRE_IJMatrixGetObject(ij_A, (void**)&parcsr_A);
Real m2 = platform::getRealTime();
info() << "Time to create matrix=" << (m2 - m1);
}

if (do_dump_matrix) {
String file_name = String("dumpA.") + String::fromNumber(my_rank) + ".txt";
Expand Down Expand Up @@ -449,36 +454,46 @@ solve()

HYPRE_Solver solver = nullptr;
HYPRE_Solver precond = nullptr;
/* setup AMG */
HYPRE_ParCSRPCGCreate(mpi_comm, &solver);

/* Set some parameters (See Reference Manual for more parameters) */
HYPRE_PCGSetMaxIter(solver, 1000); /* max iterations */
HYPRE_PCGSetTol(solver, 1e-7); /* conv. tolerance */
HYPRE_PCGSetTwoNorm(solver, 1); /* use the two norm as the stopping criteria */
HYPRE_PCGSetPrintLevel(solver, 2); /* print solve info */
HYPRE_PCGSetLogging(solver, 1); /* needed to get run info later */

hypreCheck("HYPRE_BoomerAMGCreate", HYPRE_BoomerAMGCreate(&precond));

HYPRE_BoomerAMGCreate(&precond);
HYPRE_BoomerAMGSetPrintLevel(precond, 1); /* print amg solution info */
HYPRE_BoomerAMGSetCoarsenType(precond, 6);
HYPRE_BoomerAMGSetOldDefault(precond);
HYPRE_BoomerAMGSetRelaxType(precond, 6); /* Sym G.S./Jacobi hybrid */
HYPRE_BoomerAMGSetNumSweeps(precond, 1);
HYPRE_BoomerAMGSetTol(precond, 0.0); /* conv. tolerance zero */
HYPRE_BoomerAMGSetMaxIter(precond, 1); /* do only one iteration! */

hypreCheck("HYPRE_ParCSRPCGSetPrecond",
HYPRE_ParCSRPCGSetPrecond(solver, HYPRE_BoomerAMGSolve, HYPRE_BoomerAMGSetup, precond));
hypreCheck("HYPRE_PCGSetup",
HYPRE_ParCSRPCGSetup(solver, parcsr_A, parvector_b, parvector_x));
{
Timer::Action ta1(tstat, "HypreSetPrecond");
/* setup AMG */
HYPRE_ParCSRPCGCreate(mpi_comm, &solver);

/* Set some parameters (See Reference Manual for more parameters) */
HYPRE_PCGSetMaxIter(solver, 1000); /* max iterations */
HYPRE_PCGSetTol(solver, 1e-7); /* conv. tolerance */
HYPRE_PCGSetTwoNorm(solver, 1); /* use the two norm as the stopping criteria */
HYPRE_PCGSetPrintLevel(solver, 2); /* print solve info */
HYPRE_PCGSetLogging(solver, 1); /* needed to get run info later */

hypreCheck("HYPRE_BoomerAMGCreate", HYPRE_BoomerAMGCreate(&precond));

HYPRE_BoomerAMGCreate(&precond);
HYPRE_BoomerAMGSetPrintLevel(precond, 1); /* print amg solution info */
HYPRE_BoomerAMGSetCoarsenType(precond, 6);
HYPRE_BoomerAMGSetOldDefault(precond);
HYPRE_BoomerAMGSetRelaxType(precond, 6); /* Sym G.S./Jacobi hybrid */
HYPRE_BoomerAMGSetNumSweeps(precond, 1);
HYPRE_BoomerAMGSetTol(precond, 0.0); /* conv. tolerance zero */
HYPRE_BoomerAMGSetMaxIter(precond, 1); /* do only one iteration! */

hypreCheck("HYPRE_ParCSRPCGSetPrecond",
HYPRE_ParCSRPCGSetPrecond(solver, HYPRE_BoomerAMGSolve, HYPRE_BoomerAMGSetup, precond));
}
Real a1 = platform::getRealTime();
hypreCheck("HYPRE_PCGSolve",
HYPRE_ParCSRPCGSolve(solver, parcsr_A, parvector_b, parvector_x));
{
Timer::Action ta1(tstat, "HypreSetup");
hypreCheck("HYPRE_PCGSetup",
HYPRE_ParCSRPCGSetup(solver, parcsr_A, parvector_b, parvector_x));
}

{
Timer::Action ta1(tstat, "HypreLinearSystemSolve");
hypreCheck("HYPRE_PCGSolve",
HYPRE_ParCSRPCGSolve(solver, parcsr_A, parvector_b, parvector_x));
}
Real b1 = platform::getRealTime();
info() << "Time to solve=" << (b1 - a1);
info() << "Time to setup and solve=" << (b1 - a1);

if (is_parallel) {
Int32 nb_wanted_row = m_parallel_rows_index.extent0();
Expand Down
19 changes: 13 additions & 6 deletions poisson/FemModule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ _doStationarySolve()
_assembleLinearOperator();
}

// # T=linalg.solve(K,RHS)
_solve();

// Check results
Expand Down Expand Up @@ -2133,21 +2132,28 @@ _assembleCooGPUBilinearOperatorTRIA3()
void FemModule::
_solve()
{

Timer::Action timer_action(m_time_stats, "Solving");
ITimeStats* tstat = m_time_stats;
Timer::Action timer_action(tstat, "Solving");

std::chrono::_V2::system_clock::time_point solve_start;
if (m_register_time) {
solve_start = std::chrono::high_resolution_clock::now();
}

m_linear_system.solve();
{
Timer::Action ta1(tstat, "LinearSystemSolve");
// # T=linalg.solve(K,RHS)
m_linear_system.solve();
}

// Re-Apply boundary conditions because the solver has modified the value
// of u on all nodes
_applyDirichletBoundaryConditions();

{
Timer::Action ta1(tstat, "ApplyBoundaryConditions");
_applyDirichletBoundaryConditions();
}
{
Timer::Action ta1(tstat, "CopySolution");
VariableDoFReal& dof_u(m_linear_system.solutionVariable());
// Copy RHS DoF to Node u
auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());
Expand All @@ -2160,6 +2166,7 @@ _solve()

//test
m_u.synchronize();

// def update_T(self,T):
// """Update u value on nodes after the FE resolution"""
// for i in range(0,len(self.mesh.nodes)):
Expand Down

0 comments on commit 072d9d0

Please sign in to comment.