Skip to content

Commit

Permalink
merging relevant changes from VXL https://github.com/vxl/vxl/tree/mas…
Browse files Browse the repository at this point in the history
  • Loading branch information
tvercaut committed Jan 26, 2025
1 parent adfea91 commit c29766b
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 143 deletions.
199 changes: 98 additions & 101 deletions Source/lsmrBase.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
* limitations under the License.
*
*=========================================================================*/

#include "lsmrBase.h"

#include <algorithm>
Expand All @@ -41,7 +40,7 @@ lsmrBase::lsmrBase()
this->btol = 1e-6;
this->conlim = 1.0 / ( 10 * sqrt( this->eps ) );
this->itnlim = 10;
this->nout = NULL;
this->nout = nullptr;
this->istop = 0;
this->itn = 0;
this->normA = 0.0;
Expand All @@ -58,9 +57,7 @@ lsmrBase::lsmrBase()
}


lsmrBase::~lsmrBase()
{
}
lsmrBase::~lsmrBase() = default;


unsigned int
Expand Down Expand Up @@ -249,24 +246,24 @@ lsmrBase::Dnrm2( unsigned int n, const double *x ) const
for ( unsigned int i = 0; i < n; i++ )
{
if ( x[i] != 0.0 )
{
double dx = x[i];
const double absxi = std::abs(dx);

if ( magnitudeOfLargestElement < absxi )
{
// rescale the sum to the range of the new element
dx = magnitudeOfLargestElement / absxi;
sumOfSquaresScaled = sumOfSquaresScaled * (dx * dx) + 1.0;
magnitudeOfLargestElement = absxi;
}
else
{
// rescale the new element to the range of the sum
dx = absxi / magnitudeOfLargestElement;
sumOfSquaresScaled += dx * dx;
}
}
{
double dx = x[i];
const double absxi = std::abs(dx);

if ( magnitudeOfLargestElement < absxi )
{
// rescale the sum to the range of the new element
dx = magnitudeOfLargestElement / absxi;
sumOfSquaresScaled = sumOfSquaresScaled * (dx * dx) + 1.0;
magnitudeOfLargestElement = absxi;
}
else
{
// rescale the new element to the range of the sum
dx = absxi / magnitudeOfLargestElement;
sumOfSquaresScaled += dx * dx;
}
}
}

const double norm = magnitudeOfLargestElement * sqrt( sumOfSquaresScaled );
Expand All @@ -290,7 +287,7 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )

// Initialize.

unsigned int localVecs = std::min( localSize, std::min( m,n ) );
unsigned int const localVecs = std::min( localSize, std::min( m,n ) );

if( this->nout )
{
Expand All @@ -302,7 +299,7 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
(*this->nout) << " localSize (no. of vectors for local reorthogonalization) = " << this->localSize << std::endl;
}

int pfreq = 20;
int const pfreq = 20;
int pcount = 0;
this->damped = ( this->damp > zero );

Expand Down Expand Up @@ -400,15 +397,15 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
if ( this->nout )
{
if ( damped )
{
(*this->nout) << " Itn x(1) norm rbar Abar'rbar"
" Compatible LS norm Abar cond Abar\n";
}
{
(*this->nout) << " Itn x(1) norm rbar Abar'rbar"
" Compatible LS norm Abar cond Abar\n";
}
else
{
(*this->nout) << " Itn x(1) norm r A'r "
" Compatible LS norm A cond A\n";
}
{
(*this->nout) << " Itn x(1) norm r A'r "
" Compatible LS norm A cond A\n";
}

test1 = one;
test2 = alpha / beta;
Expand All @@ -434,34 +431,34 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )

if ( beta > zero )
{
this->Scale( m, (one/beta), u );
if ( localOrtho ) {
if (localPointer+1 < localVecs) {
localPointer = localPointer + 1;
} else {
localPointer = 0;
localVQueueFull = true;
}
std::copy( v, v+n, localV+localPointer*n );
}
this->Scale( n, (- beta), v );
this->Aprod2( m, n, v, u ); // v = A'*u
if ( localOrtho ) {
unsigned int localOrthoLimit = localVQueueFull ? localVecs : localPointer+1;

for( unsigned int localOrthoCount =0; localOrthoCount<localOrthoLimit;
++localOrthoCount) {
double d = std::inner_product(v,v+n,localV+n*localOrthoCount,0.0);
daxpy( n, -d, localV+localOrthoCount*n, v );
}
}

alpha = this->Dnrm2( n, v );

if ( alpha > zero )
{
this->Scale( m, (one/beta), u );
if ( localOrtho ) {
if (localPointer+1 < localVecs) {
localPointer = localPointer + 1;
} else {
localPointer = 0;
localVQueueFull = true;
}
std::copy( v, v+n, localV+localPointer*n );
}
this->Scale( n, (- beta), v );
this->Aprod2( m, n, v, u ); // v = A'*u
if ( localOrtho ) {
unsigned int localOrthoLimit = localVQueueFull ? localVecs : localPointer+1;

for( unsigned int localOrthoCount =0; localOrthoCount<localOrthoLimit;
++localOrthoCount) {
double d = std::inner_product(v,v+n,localV+n*localOrthoCount,0.0);
daxpy( n, -d, localV+localOrthoCount*n, v );
}
}

alpha = this->Dnrm2( n, v );

if ( alpha > zero )
{
this->Scale( n, (one/alpha), v );
}
}
}

// At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
Expand All @@ -470,25 +467,25 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
//----------------------------------------------------------------
// Construct rotation Qhat_{k,2k+1}.

double alphahat = this->D2Norm( alphabar, damp );
double chat = alphabar/alphahat;
double shat = damp/alphahat;
double const alphahat = this->D2Norm( alphabar, damp );
double const chat = alphabar/alphahat;
double const shat = damp/alphahat;

// Use a plane rotation (Q_i) to turn B_i to R_i.

double rhoold = rho;
double const rhoold = rho;
rho = D2Norm(alphahat, beta);
double c = alphahat/rho;
double s = beta/rho;
double thetanew = s*alpha;
double const c = alphahat/rho;
double const s = beta/rho;
double const thetanew = s*alpha;
alphabar = c*alpha;

// Use a plane rotation (Qbar_i) to turn R_i^T into R_i^bar.

double rhobarold = rhobar;
double zetaold = zeta;
double thetabar = sbar*rho;
double rhotemp = cbar*rho;
double const rhobarold = rhobar;
double const zetaold = zeta;
double const thetabar = sbar*rho;
double const rhotemp = cbar*rho;
rhobar = this->D2Norm(cbar*rho, thetanew);
cbar = cbar*rho/rhobar;
sbar = thetanew/rhobar;
Expand All @@ -506,20 +503,20 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
// Estimate ||r||.

// Apply rotation Qhat_{k,2k+1}.
double betaacute = chat* betadd;
double betacheck = - shat* betadd;
double const betaacute = chat* betadd;
double const betacheck = - shat* betadd;

// Apply rotation Q_{k,k+1}.
double betahat = c*betaacute;
double const betahat = c*betaacute;
betadd = - s*betaacute;

// Apply rotation Qtilde_{k-1}.
// betad = betad_{k-1} here.

double thetatildeold = thetatilde;
double rhotildeold = this->D2Norm(rhodold, thetabar);
double ctildeold = rhodold/rhotildeold;
double stildeold = thetabar/rhotildeold;
double const thetatildeold = thetatilde;
double const rhotildeold = this->D2Norm(rhodold, thetabar);
double const ctildeold = rhodold/rhotildeold;
double const stildeold = thetabar/rhotildeold;
thetatilde = stildeold* rhobar;
rhodold = ctildeold* rhobar;
betad = - stildeold*betad + ctildeold*betahat;
Expand All @@ -528,7 +525,7 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
// rhodold = rhod_k here.

tautildeold = (zetaold - thetatildeold*tautildeold)/rhotildeold;
double taud = (zeta - thetatilde*tautildeold)/rhodold;
double const taud = (zeta - thetatilde*tautildeold)/rhodold;
d = d + betacheck*betacheck;
this->normr = sqrt(d + Sqr(betad - taud) + Sqr(betadd));

Expand Down Expand Up @@ -557,9 +554,9 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )

test1 = this->normr / this->normb;
test2 = this->normAr/(this->normA*this->normr);
double test3 = one/this->condA;
double t1 = test1/(one + this->normA*this->normx/this->normb);
double rtol = this->btol + this->atol*this->normA*normx/this->normb;
double const test3 = one/this->condA;
double const t1 = test1/(one + this->normA*this->normx/this->normb);
double const rtol = this->btol + this->atol*this->normA*normx/this->normb;

// The following tests guard against extremely small values of
// atol, btol or ctol. (The user may have set any or all of
Expand Down Expand Up @@ -593,21 +590,21 @@ Solve( unsigned int m, unsigned int n, const double * b, double * x )
if ( this->istop!=0 ) prnt = true;

if ( prnt ) { // Print a line for this iteration
if ( pcount >= pfreq ) { // Print a heading first
pcount = 0;
if ( damped )
{
(*this->nout) << " Itn x(1) norm rbar Abar'rbar"
" Compatible LS norm Abar cond Abar\n";
} else {
(*this->nout) << " Itn x(1) norm r A'r "
" Compatible LS norm A cond A\n";
}
}
pcount = pcount + 1;
(*this->nout)
<< this->itn << ", " << x[0] << ", " <<this->normr << ", " << this->normAr << ", " << test1 << ", " << test2
<< ", " << this->normA << ", " << this->condA << std::endl;
if ( pcount >= pfreq ) { // Print a heading first
pcount = 0;
if ( damped )
{
(*this->nout) << " Itn x(1) norm rbar Abar'rbar"
" Compatible LS norm Abar cond Abar\n";
} else {
(*this->nout) << " Itn x(1) norm r A'r "
" Compatible LS norm A cond A\n";
}
}
pcount = pcount + 1;
(*this->nout)
<< this->itn << ", " << x[0] << ", " <<this->normr << ", " << this->normAr << ", " << test1 << ", " << test2
<< ", " << this->normA << ", " << this->condA << std::endl;
}
}

Expand All @@ -622,9 +619,9 @@ TerminationPrintOut()
{
if ( this->nout ) {
(*this->nout) << " Exit LSMR. istop = " << this->istop << " ,itn = " << this->itn << std::endl
<< " Exit LSMR. normA = " << this->normA << " ,condA = " << this->condA << std::endl
<< " Exit LSMR. normb = " << this->normb << " ,normx = " << this->normx << std::endl
<< " Exit LSMR. normr = " << this->normr << " ,normAr = " << this->normAr << std::endl
<< " Exit LSMR. " << this->GetStoppingReasonMessage() << std::endl;
<< " Exit LSMR. normA = " << this->normA << " ,condA = " << this->condA << std::endl
<< " Exit LSMR. normb = " << this->normb << " ,normx = " << this->normx << std::endl
<< " Exit LSMR. normr = " << this->normr << " ,normAr = " << this->normAr << std::endl
<< " Exit LSMR. " << this->GetStoppingReasonMessage() << std::endl;
}
}
7 changes: 2 additions & 5 deletions Source/lsmrDense.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@
* limitations under the License.
*
*=========================================================================*/

#include "lsmrDense.h"

lsmrDense::lsmrDense()
{
this->A = 0;
this->A = nullptr;
}


lsmrDense::~lsmrDense()
{
}
lsmrDense::~lsmrDense() = default;


void
Expand Down
6 changes: 3 additions & 3 deletions Source/lsmrDense.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ class lsmrDense : public lsmrBase
public:

lsmrDense();
virtual ~lsmrDense();
~lsmrDense() override;

/**
* computes y = y + A*x without altering x,
* where A is a matrix of dimensions A[m][n].
* The size of the vector x is n.
* The size of the vector y is m.
*/
void Aprod1(unsigned int m, unsigned int n, const double * x, double * y ) const;
void Aprod1(unsigned int m, unsigned int n, const double * x, double * y ) const override;

/**
* computes x = x + A'*y without altering y,
* where A is a matrix of dimensions A[m][n].
* The size of the vector x is n.
* The size of the vector y is m.
*/
void Aprod2(unsigned int m, unsigned int n, double * x, const double * y ) const;
void Aprod2(unsigned int m, unsigned int n, double * x, const double * y ) const override;

/** Set the matrix A of the equation to be solved A*x = b. */
void SetMatrix( double ** A );
Expand Down
Loading

0 comments on commit c29766b

Please sign in to comment.