Prev Next lu_solve.hpp Headings

Source: LuSolve
# ifndef CPPAD_LU_SOLVE_INCLUDED
# define CPPAD_LU_SOLVE_INCLUDED
 
# include <complex>
# include <vector>

# ifdef _MSC_VER
// Fix for a problem with the namespace location of the MSC standard math functions
# include <cppad/cppad.hpp>  
# endif

# include <cppad/local/cppad_assert.hpp>
# include <cppad/check_simple_vector.hpp>
# include <cppad/check_numeric_type.hpp>
# include <cppad/lu_factor.hpp>
# include <cppad/lu_invert.hpp>

namespace CppAD { // BEGIN CppAD namespace

// LeqZero
// Generic version: reports whether x is less than or equal to zero.
// The comparison is made against a Float constructed from the integer 0,
// so Float must be constructible from 0 and support operator<=.
template <typename Float>
inline bool LeqZero(const Float &x)
{	const Float zero(0);
	return x <= zero;
}
// Overload for std::complex<double>: complex values have no ordering, so
// this returns true only when x is exactly zero (both the real and the
// imaginary component compare equal to zero).
inline bool LeqZero( const std::complex<double> &x )
{	const bool real_is_zero = ( x.real() == 0.0 );
	const bool imag_is_zero = ( x.imag() == 0.0 );
	return real_is_zero && imag_is_zero;
}
// Overload for std::complex<float>: complex values have no ordering, so
// this returns true only when x is exactly zero (both the real and the
// imaginary component compare equal to zero).
inline bool LeqZero( const std::complex<float> &x )
{	const bool real_is_zero = ( x.real() == 0.0f );
	const bool imag_is_zero = ( x.imag() == 0.0f );
	return real_is_zero && imag_is_zero;
}

// LuSolve
// Solves A * X = B by LU factoring a copy of A, and reports the
// determinant of A as a (sign, log) pair.
//
// n       number of rows and columns in A
// m       number of columns in B and X
// A       the matrix, stored in a vector of size n * n (not modified)
// B       the right hand sides, stored in a vector of size n * m
//         (not modified)
// X       on input must have size n * m; it is overwritten with a copy
//         of B and then passed to LuInvert to obtain the solution
// logdet  on return, the sum over all pivots of log(pivot), where a pivot
//         is negated before taking the log when LeqZero(pivot) is true;
//         set to Float(0) when a zero pivot is found
//
// The return value is the sign returned by LuFactor, negated once for
// every pivot for which LeqZero is true. If any pivot is equal to zero,
// the return value is 0 and LuInvert is not called (X then holds a copy
// of B).
template <typename Float, typename FloatVector>
int LuSolve(
	size_t             n      ,
	size_t             m      , 
	const FloatVector &A      , 
	const FloatVector &B      , 
	FloatVector       &X      , 
	Float        &logdet      )
{
	// make sure Float satisfies the numeric type specifications and
	// FloatVector satisfies the simple vector specifications
	CheckNumericType<Float>();
	CheckSimpleVector<Float, FloatVector>();

	// constant used to detect a zero pivot
	const Float zero(0);

	// row and column pivoting order chosen by the factorization
	std::vector<size_t> row_order(n);
	std::vector<size_t> col_order(n);

	// ------------------- argument size checks -------------------
	CPPAD_ASSERT_KNOWN(
		A.size() == n * n,
		"Error in LuSolve: A must have size equal to n * n"
	);
	CPPAD_ASSERT_KNOWN(
		B.size() == n * m,
		"Error in LuSolve: B must have size equal to n * m"
	);
	CPPAD_ASSERT_KNOWN(
		X.size() == n * m,
		"Error in LuSolve: X must have size equal to n * m"
	);
	// ------------------------------------------------------------

	// work on a copy of A so the caller's matrix is left untouched
	FloatVector factored(A);

	// X starts out as the right hand side; B itself is never changed
	X = B;

	// factor the matrix; the result is the starting sign factor
	int sign = LuFactor(row_order, col_order, factored);

	// accumulate the log of the determinant one pivot at a time
	Float diag;
	logdet = Float(0);
	for(size_t k = 0; k < n; ++k)
	{	// k-th pivot element, located through the pivoting order
		diag = factored[ row_order[k] * n + col_order[k] ];

		// a zero pivot means the determinant is zero: give up
		if( diag == zero )
		{	logdet = Float(0);
			return 0;
		}

		// pivot is not "less than or equal zero": take the log directly
		if( ! LeqZero( diag ) )
		{	logdet += log( diag );
			continue;
		}

		// otherwise take the log of the negated pivot and flip the sign
		logdet += log( - diag );
		sign    = - sign;
	}

	// use the factorization to solve the linear equations in place in X
	LuInvert(row_order, col_order, factored, X);

	// sign factor for the determinant
	return sign;
}
} // END CppAD namespace 
# endif

Input File: omh/lu_solve_hpp.omh